mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-16 23:36:39 +00:00
Compare commits
23 Commits
fix/gettin
...
fix/macos-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
8f6e916517 | ||
|
|
7a1c2f08b8 | ||
|
|
099c493b18 | ||
|
|
c1d1229ae0 | ||
|
|
94a36cb53e | ||
|
|
c7ba931466 | ||
|
|
413d95b740 | ||
|
|
332c624c55 | ||
|
|
dc160aff36 | ||
|
|
96806bf55f | ||
|
|
d33cd4c95b | ||
|
|
e2c2f64be7 | ||
|
|
cb73b94ffb | ||
|
|
1d920d700c | ||
|
|
bb85eee40a | ||
|
|
aba5d6f0d2 | ||
|
|
0588d2dbe1 | ||
|
|
14b3b77bda | ||
|
|
6da34e483c | ||
|
|
0efef671d7 | ||
|
|
435203b13b | ||
|
|
decb5dd3af | ||
|
|
28fbf96b2a |
@@ -199,9 +199,11 @@ func runForDuration(cmd *cobra.Command, args []string) error {
|
||||
cmd.Println("Log level set to trace.")
|
||||
}
|
||||
|
||||
needsRestoreUp := false
|
||||
if _, err := client.Down(cmd.Context(), &proto.DownRequest{}); err != nil {
|
||||
cmd.PrintErrf("Failed to bring service down: %v\n", status.Convert(err).Message())
|
||||
} else {
|
||||
needsRestoreUp = !stateWasDown
|
||||
cmd.Println("netbird down")
|
||||
}
|
||||
|
||||
@@ -217,6 +219,7 @@ func runForDuration(cmd *cobra.Command, args []string) error {
|
||||
if _, err := client.Up(cmd.Context(), &proto.UpRequest{}); err != nil {
|
||||
cmd.PrintErrf("Failed to bring service up: %v\n", status.Convert(err).Message())
|
||||
} else {
|
||||
needsRestoreUp = false
|
||||
cmd.Println("netbird up")
|
||||
}
|
||||
|
||||
@@ -264,6 +267,14 @@ func runForDuration(cmd *cobra.Command, args []string) error {
|
||||
return fmt.Errorf("failed to bundle debug: %v", status.Convert(err).Message())
|
||||
}
|
||||
|
||||
if needsRestoreUp {
|
||||
if _, err := client.Up(cmd.Context(), &proto.UpRequest{}); err != nil {
|
||||
cmd.PrintErrf("Failed to restore service up state: %v\n", status.Convert(err).Message())
|
||||
} else {
|
||||
cmd.Println("netbird up (restored)")
|
||||
}
|
||||
}
|
||||
|
||||
if stateWasDown {
|
||||
if _, err := client.Down(cmd.Context(), &proto.DownRequest{}); err != nil {
|
||||
cmd.PrintErrf("Failed to restore service down state: %v\n", status.Convert(err).Message())
|
||||
|
||||
@@ -14,6 +14,7 @@ import (
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/spf13/cobra"
|
||||
"google.golang.org/grpc/status"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/expose"
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
@@ -201,7 +202,7 @@ func exposeFn(cmd *cobra.Command, args []string) error {
|
||||
|
||||
stream, err := client.ExposeService(ctx, req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("expose service: %w", err)
|
||||
return fmt.Errorf("expose service: %v", status.Convert(err).Message())
|
||||
}
|
||||
|
||||
if err := handleExposeReady(cmd, stream, port); err != nil {
|
||||
@@ -236,7 +237,7 @@ func toExposeProtocol(exposeProtocol string) (proto.ExposeProtocol, error) {
|
||||
func handleExposeReady(cmd *cobra.Command, stream proto.DaemonService_ExposeServiceClient, port uint64) error {
|
||||
event, err := stream.Recv()
|
||||
if err != nil {
|
||||
return fmt.Errorf("receive expose event: %w", err)
|
||||
return fmt.Errorf("receive expose event: %v", status.Convert(err).Message())
|
||||
}
|
||||
|
||||
ready, ok := event.Event.(*proto.ExposeServiceEvent_Ready)
|
||||
|
||||
@@ -4,7 +4,9 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/signal"
|
||||
"runtime"
|
||||
"syscall"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -13,6 +15,22 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestMain intercepts when this test binary is run as a daemon subprocess.
|
||||
// On FreeBSD, the rc.d service script runs the binary via daemon(8) -r with
|
||||
// "service run ..." arguments. Since the test binary can't handle cobra CLI
|
||||
// args, it exits immediately, causing daemon -r to respawn rapidly until
|
||||
// hitting the rate limit and exiting. This makes service restart unreliable.
|
||||
// Blocking here keeps the subprocess alive until the init system sends SIGTERM.
|
||||
func TestMain(m *testing.M) {
|
||||
if len(os.Args) > 2 && os.Args[1] == "service" && os.Args[2] == "run" {
|
||||
sig := make(chan os.Signal, 1)
|
||||
signal.Notify(sig, syscall.SIGTERM, os.Interrupt)
|
||||
<-sig
|
||||
return
|
||||
}
|
||||
os.Exit(m.Run())
|
||||
}
|
||||
|
||||
const (
|
||||
serviceStartTimeout = 10 * time.Second
|
||||
serviceStopTimeout = 5 * time.Second
|
||||
@@ -79,6 +97,34 @@ func TestServiceLifecycle(t *testing.T) {
|
||||
logLevel = "info"
|
||||
daemonAddr = fmt.Sprintf("unix://%s/netbird-test.sock", tempDir)
|
||||
|
||||
// Ensure cleanup even if a subtest fails and Stop/Uninstall subtests don't run.
|
||||
t.Cleanup(func() {
|
||||
cfg, err := newSVCConfig()
|
||||
if err != nil {
|
||||
t.Errorf("cleanup: create service config: %v", err)
|
||||
return
|
||||
}
|
||||
ctxSvc, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
s, err := newSVC(newProgram(ctxSvc, cancel), cfg)
|
||||
if err != nil {
|
||||
t.Errorf("cleanup: create service: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// If the subtests already cleaned up, there's nothing to do.
|
||||
if _, err := s.Status(); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if err := s.Stop(); err != nil {
|
||||
t.Errorf("cleanup: stop service: %v", err)
|
||||
}
|
||||
if err := s.Uninstall(); err != nil {
|
||||
t.Errorf("cleanup: uninstall service: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("Install", func(t *testing.T) {
|
||||
|
||||
@@ -286,6 +286,22 @@ func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Prot
|
||||
return m.router.RemoveInboundDNAT(localAddr, protocol, sourcePort, targetPort)
|
||||
}
|
||||
|
||||
// AddOutputDNAT adds an OUTPUT chain DNAT rule for locally-generated traffic.
|
||||
func (m *Manager) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
return m.router.AddOutputDNAT(localAddr, protocol, sourcePort, targetPort)
|
||||
}
|
||||
|
||||
// RemoveOutputDNAT removes an OUTPUT chain DNAT rule.
|
||||
func (m *Manager) RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
return m.router.RemoveOutputDNAT(localAddr, protocol, sourcePort, targetPort)
|
||||
}
|
||||
|
||||
const (
|
||||
chainNameRaw = "NETBIRD-RAW"
|
||||
chainOUTPUT = "OUTPUT"
|
||||
|
||||
@@ -36,6 +36,7 @@ const (
|
||||
chainRTFWDOUT = "NETBIRD-RT-FWD-OUT"
|
||||
chainRTPRE = "NETBIRD-RT-PRE"
|
||||
chainRTRDR = "NETBIRD-RT-RDR"
|
||||
chainNATOutput = "NETBIRD-NAT-OUTPUT"
|
||||
chainRTMSSCLAMP = "NETBIRD-RT-MSSCLAMP"
|
||||
routingFinalForwardJump = "ACCEPT"
|
||||
routingFinalNatJump = "MASQUERADE"
|
||||
@@ -43,6 +44,7 @@ const (
|
||||
jumpManglePre = "jump-mangle-pre"
|
||||
jumpNatPre = "jump-nat-pre"
|
||||
jumpNatPost = "jump-nat-post"
|
||||
jumpNatOutput = "jump-nat-output"
|
||||
jumpMSSClamp = "jump-mss-clamp"
|
||||
markManglePre = "mark-mangle-pre"
|
||||
markManglePost = "mark-mangle-post"
|
||||
@@ -387,6 +389,14 @@ func (r *router) cleanUpDefaultForwardRules() error {
|
||||
}
|
||||
|
||||
log.Debug("flushing routing related tables")
|
||||
|
||||
// Remove jump rules from built-in chains before deleting custom chains,
|
||||
// otherwise the chain deletion fails with "device or resource busy".
|
||||
jumpRule := []string{"-j", chainNATOutput}
|
||||
if err := r.iptablesClient.Delete(tableNat, "OUTPUT", jumpRule...); err != nil {
|
||||
log.Debugf("clean OUTPUT jump rule: %v", err)
|
||||
}
|
||||
|
||||
for _, chainInfo := range []struct {
|
||||
chain string
|
||||
table string
|
||||
@@ -396,6 +406,7 @@ func (r *router) cleanUpDefaultForwardRules() error {
|
||||
{chainRTPRE, tableMangle},
|
||||
{chainRTNAT, tableNat},
|
||||
{chainRTRDR, tableNat},
|
||||
{chainNATOutput, tableNat},
|
||||
{chainRTMSSCLAMP, tableMangle},
|
||||
} {
|
||||
ok, err := r.iptablesClient.ChainExists(chainInfo.table, chainInfo.chain)
|
||||
@@ -970,6 +981,81 @@ func (r *router) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Proto
|
||||
return nil
|
||||
}
|
||||
|
||||
// ensureNATOutputChain lazily creates the OUTPUT NAT chain and jump rule on first use.
|
||||
func (r *router) ensureNATOutputChain() error {
|
||||
if _, exists := r.rules[jumpNatOutput]; exists {
|
||||
return nil
|
||||
}
|
||||
|
||||
chainExists, err := r.iptablesClient.ChainExists(tableNat, chainNATOutput)
|
||||
if err != nil {
|
||||
return fmt.Errorf("check chain %s: %w", chainNATOutput, err)
|
||||
}
|
||||
if !chainExists {
|
||||
if err := r.iptablesClient.NewChain(tableNat, chainNATOutput); err != nil {
|
||||
return fmt.Errorf("create chain %s: %w", chainNATOutput, err)
|
||||
}
|
||||
}
|
||||
|
||||
jumpRule := []string{"-j", chainNATOutput}
|
||||
if err := r.iptablesClient.Insert(tableNat, "OUTPUT", 1, jumpRule...); err != nil {
|
||||
if !chainExists {
|
||||
if delErr := r.iptablesClient.ClearAndDeleteChain(tableNat, chainNATOutput); delErr != nil {
|
||||
log.Warnf("failed to rollback chain %s: %v", chainNATOutput, delErr)
|
||||
}
|
||||
}
|
||||
return fmt.Errorf("add OUTPUT jump rule: %w", err)
|
||||
}
|
||||
r.rules[jumpNatOutput] = jumpRule
|
||||
|
||||
r.updateState()
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddOutputDNAT adds an OUTPUT chain DNAT rule for locally-generated traffic.
|
||||
func (r *router) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
|
||||
ruleID := fmt.Sprintf("output-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort)
|
||||
|
||||
if _, exists := r.rules[ruleID]; exists {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := r.ensureNATOutputChain(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
dnatRule := []string{
|
||||
"-p", strings.ToLower(string(protocol)),
|
||||
"--dport", strconv.Itoa(int(sourcePort)),
|
||||
"-d", localAddr.String(),
|
||||
"-j", "DNAT",
|
||||
"--to-destination", ":" + strconv.Itoa(int(targetPort)),
|
||||
}
|
||||
|
||||
if err := r.iptablesClient.Append(tableNat, chainNATOutput, dnatRule...); err != nil {
|
||||
return fmt.Errorf("add output DNAT rule: %w", err)
|
||||
}
|
||||
r.rules[ruleID] = dnatRule
|
||||
|
||||
r.updateState()
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemoveOutputDNAT removes an OUTPUT chain DNAT rule.
|
||||
func (r *router) RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
|
||||
ruleID := fmt.Sprintf("output-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort)
|
||||
|
||||
if dnatRule, exists := r.rules[ruleID]; exists {
|
||||
if err := r.iptablesClient.Delete(tableNat, chainNATOutput, dnatRule...); err != nil {
|
||||
return fmt.Errorf("delete output DNAT rule: %w", err)
|
||||
}
|
||||
delete(r.rules, ruleID)
|
||||
}
|
||||
|
||||
r.updateState()
|
||||
return nil
|
||||
}
|
||||
|
||||
func applyPort(flag string, port *firewall.Port) []string {
|
||||
if port == nil {
|
||||
return nil
|
||||
|
||||
@@ -169,6 +169,14 @@ type Manager interface {
|
||||
// RemoveInboundDNAT removes inbound DNAT rule
|
||||
RemoveInboundDNAT(localAddr netip.Addr, protocol Protocol, sourcePort, targetPort uint16) error
|
||||
|
||||
// AddOutputDNAT adds an OUTPUT chain DNAT rule for locally-generated traffic.
|
||||
// localAddr must be IPv4; the underlying iptables/nftables backends are IPv4-only.
|
||||
AddOutputDNAT(localAddr netip.Addr, protocol Protocol, sourcePort, targetPort uint16) error
|
||||
|
||||
// RemoveOutputDNAT removes an OUTPUT chain DNAT rule.
|
||||
// localAddr must be IPv4; the underlying iptables/nftables backends are IPv4-only.
|
||||
RemoveOutputDNAT(localAddr netip.Addr, protocol Protocol, sourcePort, targetPort uint16) error
|
||||
|
||||
// SetupEBPFProxyNoTrack creates static notrack rules for eBPF proxy loopback traffic.
|
||||
// This prevents conntrack from interfering with WireGuard proxy communication.
|
||||
SetupEBPFProxyNoTrack(proxyPort, wgPort uint16) error
|
||||
|
||||
@@ -346,6 +346,22 @@ func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Prot
|
||||
return m.router.RemoveInboundDNAT(localAddr, protocol, sourcePort, targetPort)
|
||||
}
|
||||
|
||||
// AddOutputDNAT adds an OUTPUT chain DNAT rule for locally-generated traffic.
|
||||
func (m *Manager) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
return m.router.AddOutputDNAT(localAddr, protocol, sourcePort, targetPort)
|
||||
}
|
||||
|
||||
// RemoveOutputDNAT removes an OUTPUT chain DNAT rule.
|
||||
func (m *Manager) RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
return m.router.RemoveOutputDNAT(localAddr, protocol, sourcePort, targetPort)
|
||||
}
|
||||
|
||||
const (
|
||||
chainNameRawOutput = "netbird-raw-out"
|
||||
chainNameRawPrerouting = "netbird-raw-pre"
|
||||
|
||||
@@ -36,6 +36,7 @@ const (
|
||||
chainNameRoutingFw = "netbird-rt-fwd"
|
||||
chainNameRoutingNat = "netbird-rt-postrouting"
|
||||
chainNameRoutingRdr = "netbird-rt-redirect"
|
||||
chainNameNATOutput = "netbird-nat-output"
|
||||
chainNameForward = "FORWARD"
|
||||
chainNameMangleForward = "netbird-mangle-forward"
|
||||
|
||||
@@ -1853,6 +1854,130 @@ func (r *router) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Proto
|
||||
return nil
|
||||
}
|
||||
|
||||
// ensureNATOutputChain lazily creates the OUTPUT NAT chain on first use.
|
||||
func (r *router) ensureNATOutputChain() error {
|
||||
if _, exists := r.chains[chainNameNATOutput]; exists {
|
||||
return nil
|
||||
}
|
||||
|
||||
r.chains[chainNameNATOutput] = r.conn.AddChain(&nftables.Chain{
|
||||
Name: chainNameNATOutput,
|
||||
Table: r.workTable,
|
||||
Hooknum: nftables.ChainHookOutput,
|
||||
Priority: nftables.ChainPriorityNATDest,
|
||||
Type: nftables.ChainTypeNAT,
|
||||
})
|
||||
|
||||
if err := r.conn.Flush(); err != nil {
|
||||
delete(r.chains, chainNameNATOutput)
|
||||
return fmt.Errorf("create NAT output chain: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddOutputDNAT adds an OUTPUT chain DNAT rule for locally-generated traffic.
|
||||
func (r *router) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
|
||||
ruleID := fmt.Sprintf("output-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort)
|
||||
|
||||
if _, exists := r.rules[ruleID]; exists {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := r.ensureNATOutputChain(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
protoNum, err := protoToInt(protocol)
|
||||
if err != nil {
|
||||
return fmt.Errorf("convert protocol to number: %w", err)
|
||||
}
|
||||
|
||||
exprs := []expr.Any{
|
||||
&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: []byte{protoNum},
|
||||
},
|
||||
&expr.Payload{
|
||||
DestRegister: 2,
|
||||
Base: expr.PayloadBaseTransportHeader,
|
||||
Offset: 2,
|
||||
Len: 2,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 2,
|
||||
Data: binaryutil.BigEndian.PutUint16(sourcePort),
|
||||
},
|
||||
}
|
||||
|
||||
exprs = append(exprs, applyPrefix(netip.PrefixFrom(localAddr, 32), false)...)
|
||||
|
||||
exprs = append(exprs,
|
||||
&expr.Immediate{
|
||||
Register: 1,
|
||||
Data: localAddr.AsSlice(),
|
||||
},
|
||||
&expr.Immediate{
|
||||
Register: 2,
|
||||
Data: binaryutil.BigEndian.PutUint16(targetPort),
|
||||
},
|
||||
&expr.NAT{
|
||||
Type: expr.NATTypeDestNAT,
|
||||
Family: uint32(nftables.TableFamilyIPv4),
|
||||
RegAddrMin: 1,
|
||||
RegProtoMin: 2,
|
||||
},
|
||||
)
|
||||
|
||||
dnatRule := &nftables.Rule{
|
||||
Table: r.workTable,
|
||||
Chain: r.chains[chainNameNATOutput],
|
||||
Exprs: exprs,
|
||||
UserData: []byte(ruleID),
|
||||
}
|
||||
r.conn.AddRule(dnatRule)
|
||||
|
||||
if err := r.conn.Flush(); err != nil {
|
||||
return fmt.Errorf("add output DNAT rule: %w", err)
|
||||
}
|
||||
|
||||
r.rules[ruleID] = dnatRule
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemoveOutputDNAT removes an OUTPUT chain DNAT rule.
|
||||
func (r *router) RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
|
||||
if err := r.refreshRulesMap(); err != nil {
|
||||
return fmt.Errorf(refreshRulesMapError, err)
|
||||
}
|
||||
|
||||
ruleID := fmt.Sprintf("output-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort)
|
||||
|
||||
rule, exists := r.rules[ruleID]
|
||||
if !exists {
|
||||
return nil
|
||||
}
|
||||
|
||||
if rule.Handle == 0 {
|
||||
log.Warnf("output DNAT rule %s has no handle, removing stale entry", ruleID)
|
||||
delete(r.rules, ruleID)
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := r.conn.DelRule(rule); err != nil {
|
||||
return fmt.Errorf("delete output DNAT rule %s: %w", ruleID, err)
|
||||
}
|
||||
if err := r.conn.Flush(); err != nil {
|
||||
return fmt.Errorf("flush delete output DNAT rule: %w", err)
|
||||
}
|
||||
delete(r.rules, ruleID)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// applyNetwork generates nftables expressions for networks (CIDR) or sets
|
||||
func (r *router) applyNetwork(
|
||||
network firewall.Network,
|
||||
|
||||
@@ -140,6 +140,17 @@ type Manager struct {
|
||||
mtu uint16
|
||||
mssClampValue uint16
|
||||
mssClampEnabled bool
|
||||
|
||||
// Only one hook per protocol is supported. Outbound direction only.
|
||||
udpHookOut atomic.Pointer[packetHook]
|
||||
tcpHookOut atomic.Pointer[packetHook]
|
||||
}
|
||||
|
||||
// packetHook stores a registered hook for a specific IP:port.
|
||||
type packetHook struct {
|
||||
ip netip.Addr
|
||||
port uint16
|
||||
fn func([]byte) bool
|
||||
}
|
||||
|
||||
// decoder for packages
|
||||
@@ -594,6 +605,8 @@ func (m *Manager) resetState() {
|
||||
maps.Clear(m.incomingRules)
|
||||
maps.Clear(m.routeRulesMap)
|
||||
m.routeRules = m.routeRules[:0]
|
||||
m.udpHookOut.Store(nil)
|
||||
m.tcpHookOut.Store(nil)
|
||||
|
||||
if m.udpTracker != nil {
|
||||
m.udpTracker.Close()
|
||||
@@ -713,6 +726,9 @@ func (m *Manager) filterOutbound(packetData []byte, size int) bool {
|
||||
return true
|
||||
}
|
||||
case layers.LayerTypeTCP:
|
||||
if m.tcpHooksDrop(uint16(d.tcp.DstPort), dstIP, packetData) {
|
||||
return true
|
||||
}
|
||||
// Clamp MSS on all TCP SYN packets, including those from local IPs.
|
||||
// SNATed routed traffic may appear as local IP but still requires clamping.
|
||||
if m.mssClampEnabled {
|
||||
@@ -895,38 +911,21 @@ func (m *Manager) trackInbound(d *decoder, srcIP, dstIP netip.Addr, ruleID []byt
|
||||
d.dnatOrigPort = 0
|
||||
}
|
||||
|
||||
// udpHooksDrop checks if any UDP hooks should drop the packet
|
||||
func (m *Manager) udpHooksDrop(dport uint16, dstIP netip.Addr, packetData []byte) bool {
|
||||
m.mutex.RLock()
|
||||
defer m.mutex.RUnlock()
|
||||
return hookMatches(m.udpHookOut.Load(), dstIP, dport, packetData)
|
||||
}
|
||||
|
||||
// Check specific destination IP first
|
||||
if rules, exists := m.outgoingRules[dstIP]; exists {
|
||||
for _, rule := range rules {
|
||||
if rule.udpHook != nil && portsMatch(rule.dPort, dport) {
|
||||
return rule.udpHook(packetData)
|
||||
}
|
||||
}
|
||||
func (m *Manager) tcpHooksDrop(dport uint16, dstIP netip.Addr, packetData []byte) bool {
|
||||
return hookMatches(m.tcpHookOut.Load(), dstIP, dport, packetData)
|
||||
}
|
||||
|
||||
func hookMatches(h *packetHook, dstIP netip.Addr, dport uint16, packetData []byte) bool {
|
||||
if h == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// Check IPv4 unspecified address
|
||||
if rules, exists := m.outgoingRules[netip.IPv4Unspecified()]; exists {
|
||||
for _, rule := range rules {
|
||||
if rule.udpHook != nil && portsMatch(rule.dPort, dport) {
|
||||
return rule.udpHook(packetData)
|
||||
}
|
||||
}
|
||||
if h.ip == dstIP && h.port == dport {
|
||||
return h.fn(packetData)
|
||||
}
|
||||
|
||||
// Check IPv6 unspecified address
|
||||
if rules, exists := m.outgoingRules[netip.IPv6Unspecified()]; exists {
|
||||
for _, rule := range rules {
|
||||
if rule.udpHook != nil && portsMatch(rule.dPort, dport) {
|
||||
return rule.udpHook(packetData)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -1278,12 +1277,6 @@ func validateRule(ip netip.Addr, packetData []byte, rules map[string]PeerRule, d
|
||||
return rule.mgmtId, rule.drop, true
|
||||
}
|
||||
case layers.LayerTypeUDP:
|
||||
// if rule has UDP hook (and if we are here we match this rule)
|
||||
// we ignore rule.drop and call this hook
|
||||
if rule.udpHook != nil {
|
||||
return rule.mgmtId, rule.udpHook(packetData), true
|
||||
}
|
||||
|
||||
if portsMatch(rule.sPort, uint16(d.udp.SrcPort)) && portsMatch(rule.dPort, uint16(d.udp.DstPort)) {
|
||||
return rule.mgmtId, rule.drop, true
|
||||
}
|
||||
@@ -1342,65 +1335,30 @@ func (m *Manager) ruleMatches(rule *RouteRule, srcAddr, dstAddr netip.Addr, prot
|
||||
return sourceMatched
|
||||
}
|
||||
|
||||
// AddUDPPacketHook calls hook when UDP packet from given direction matched
|
||||
//
|
||||
// Hook function returns flag which indicates should be the matched package dropped or not
|
||||
func (m *Manager) AddUDPPacketHook(in bool, ip netip.Addr, dPort uint16, hook func(packet []byte) bool) string {
|
||||
r := PeerRule{
|
||||
id: uuid.New().String(),
|
||||
ip: ip,
|
||||
protoLayer: layers.LayerTypeUDP,
|
||||
dPort: &firewall.Port{Values: []uint16{dPort}},
|
||||
ipLayer: layers.LayerTypeIPv6,
|
||||
udpHook: hook,
|
||||
// SetUDPPacketHook sets the outbound UDP packet hook. Pass nil hook to remove.
|
||||
func (m *Manager) SetUDPPacketHook(ip netip.Addr, dPort uint16, hook func(packet []byte) bool) {
|
||||
if hook == nil {
|
||||
m.udpHookOut.Store(nil)
|
||||
return
|
||||
}
|
||||
|
||||
if ip.Is4() {
|
||||
r.ipLayer = layers.LayerTypeIPv4
|
||||
}
|
||||
|
||||
m.mutex.Lock()
|
||||
if in {
|
||||
// Incoming UDP hooks are stored in allow rules map
|
||||
if _, ok := m.incomingRules[r.ip]; !ok {
|
||||
m.incomingRules[r.ip] = make(map[string]PeerRule)
|
||||
}
|
||||
m.incomingRules[r.ip][r.id] = r
|
||||
} else {
|
||||
if _, ok := m.outgoingRules[r.ip]; !ok {
|
||||
m.outgoingRules[r.ip] = make(map[string]PeerRule)
|
||||
}
|
||||
m.outgoingRules[r.ip][r.id] = r
|
||||
}
|
||||
m.mutex.Unlock()
|
||||
|
||||
return r.id
|
||||
m.udpHookOut.Store(&packetHook{
|
||||
ip: ip,
|
||||
port: dPort,
|
||||
fn: hook,
|
||||
})
|
||||
}
|
||||
|
||||
// RemovePacketHook removes packet hook by given ID
|
||||
func (m *Manager) RemovePacketHook(hookID string) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
// Check incoming hooks (stored in allow rules)
|
||||
for _, arr := range m.incomingRules {
|
||||
for _, r := range arr {
|
||||
if r.id == hookID {
|
||||
delete(arr, r.id)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
// SetTCPPacketHook sets the outbound TCP packet hook. Pass nil hook to remove.
|
||||
func (m *Manager) SetTCPPacketHook(ip netip.Addr, dPort uint16, hook func(packet []byte) bool) {
|
||||
if hook == nil {
|
||||
m.tcpHookOut.Store(nil)
|
||||
return
|
||||
}
|
||||
// Check outgoing hooks
|
||||
for _, arr := range m.outgoingRules {
|
||||
for _, r := range arr {
|
||||
if r.id == hookID {
|
||||
delete(arr, r.id)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
return fmt.Errorf("hook with given id not found")
|
||||
m.tcpHookOut.Store(&packetHook{
|
||||
ip: ip,
|
||||
port: dPort,
|
||||
fn: hook,
|
||||
})
|
||||
}
|
||||
|
||||
// SetLogLevel sets the log level for the firewall manager
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
"github.com/google/gopacket"
|
||||
"github.com/google/gopacket/layers"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
wgdevice "golang.zx2c4.com/wireguard/device"
|
||||
|
||||
@@ -186,81 +187,52 @@ func TestManagerDeleteRule(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestAddUDPPacketHook(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
in bool
|
||||
expDir fw.RuleDirection
|
||||
ip netip.Addr
|
||||
dPort uint16
|
||||
hook func([]byte) bool
|
||||
expectedID string
|
||||
}{
|
||||
{
|
||||
name: "Test Outgoing UDP Packet Hook",
|
||||
in: false,
|
||||
expDir: fw.RuleDirectionOUT,
|
||||
ip: netip.MustParseAddr("10.168.0.1"),
|
||||
dPort: 8000,
|
||||
hook: func([]byte) bool { return true },
|
||||
},
|
||||
{
|
||||
name: "Test Incoming UDP Packet Hook",
|
||||
in: true,
|
||||
expDir: fw.RuleDirectionIN,
|
||||
ip: netip.MustParseAddr("::1"),
|
||||
dPort: 9000,
|
||||
hook: func([]byte) bool { return false },
|
||||
},
|
||||
}
|
||||
func TestSetUDPPacketHook(t *testing.T) {
|
||||
manager, err := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false, flowLogger, nbiface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() { require.NoError(t, manager.Close(nil)) })
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
manager, err := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false, flowLogger, nbiface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
var called bool
|
||||
manager.SetUDPPacketHook(netip.MustParseAddr("10.168.0.1"), 8000, func([]byte) bool {
|
||||
called = true
|
||||
return true
|
||||
})
|
||||
|
||||
manager.AddUDPPacketHook(tt.in, tt.ip, tt.dPort, tt.hook)
|
||||
h := manager.udpHookOut.Load()
|
||||
require.NotNil(t, h)
|
||||
assert.Equal(t, netip.MustParseAddr("10.168.0.1"), h.ip)
|
||||
assert.Equal(t, uint16(8000), h.port)
|
||||
assert.True(t, h.fn(nil))
|
||||
assert.True(t, called)
|
||||
|
||||
var addedRule PeerRule
|
||||
if tt.in {
|
||||
// Incoming UDP hooks are stored in allow rules map
|
||||
if len(manager.incomingRules[tt.ip]) != 1 {
|
||||
t.Errorf("expected 1 incoming rule, got %d", len(manager.incomingRules[tt.ip]))
|
||||
return
|
||||
}
|
||||
for _, rule := range manager.incomingRules[tt.ip] {
|
||||
addedRule = rule
|
||||
}
|
||||
} else {
|
||||
if len(manager.outgoingRules[tt.ip]) != 1 {
|
||||
t.Errorf("expected 1 outgoing rule, got %d", len(manager.outgoingRules[tt.ip]))
|
||||
return
|
||||
}
|
||||
for _, rule := range manager.outgoingRules[tt.ip] {
|
||||
addedRule = rule
|
||||
}
|
||||
}
|
||||
manager.SetUDPPacketHook(netip.MustParseAddr("10.168.0.1"), 8000, nil)
|
||||
assert.Nil(t, manager.udpHookOut.Load())
|
||||
}
|
||||
|
||||
if tt.ip.Compare(addedRule.ip) != 0 {
|
||||
t.Errorf("expected ip %s, got %s", tt.ip, addedRule.ip)
|
||||
return
|
||||
}
|
||||
if tt.dPort != addedRule.dPort.Values[0] {
|
||||
t.Errorf("expected dPort %d, got %d", tt.dPort, addedRule.dPort.Values[0])
|
||||
return
|
||||
}
|
||||
if layers.LayerTypeUDP != addedRule.protoLayer {
|
||||
t.Errorf("expected protoLayer %s, got %s", layers.LayerTypeUDP, addedRule.protoLayer)
|
||||
return
|
||||
}
|
||||
if addedRule.udpHook == nil {
|
||||
t.Errorf("expected udpHook to be set")
|
||||
return
|
||||
}
|
||||
})
|
||||
}
|
||||
func TestSetTCPPacketHook(t *testing.T) {
|
||||
manager, err := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false, flowLogger, nbiface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() { require.NoError(t, manager.Close(nil)) })
|
||||
|
||||
var called bool
|
||||
manager.SetTCPPacketHook(netip.MustParseAddr("10.168.0.1"), 53, func([]byte) bool {
|
||||
called = true
|
||||
return true
|
||||
})
|
||||
|
||||
h := manager.tcpHookOut.Load()
|
||||
require.NotNil(t, h)
|
||||
assert.Equal(t, netip.MustParseAddr("10.168.0.1"), h.ip)
|
||||
assert.Equal(t, uint16(53), h.port)
|
||||
assert.True(t, h.fn(nil))
|
||||
assert.True(t, called)
|
||||
|
||||
manager.SetTCPPacketHook(netip.MustParseAddr("10.168.0.1"), 53, nil)
|
||||
assert.Nil(t, manager.tcpHookOut.Load())
|
||||
}
|
||||
|
||||
// TestPeerRuleLifecycleDenyRules verifies that deny rules are correctly added
|
||||
@@ -530,39 +502,12 @@ func TestRemovePacketHook(t *testing.T) {
|
||||
require.NoError(t, manager.Close(nil))
|
||||
}()
|
||||
|
||||
// Add a UDP packet hook
|
||||
hookFunc := func(data []byte) bool { return true }
|
||||
hookID := manager.AddUDPPacketHook(false, netip.MustParseAddr("192.168.0.1"), 8080, hookFunc)
|
||||
manager.SetUDPPacketHook(netip.MustParseAddr("192.168.0.1"), 8080, func([]byte) bool { return true })
|
||||
|
||||
// Assert the hook is added by finding it in the manager's outgoing rules
|
||||
found := false
|
||||
for _, arr := range manager.outgoingRules {
|
||||
for _, rule := range arr {
|
||||
if rule.id == hookID {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
require.NotNil(t, manager.udpHookOut.Load(), "hook should be registered")
|
||||
|
||||
if !found {
|
||||
t.Fatalf("The hook was not added properly.")
|
||||
}
|
||||
|
||||
// Now remove the packet hook
|
||||
err = manager.RemovePacketHook(hookID)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to remove hook: %s", err)
|
||||
}
|
||||
|
||||
// Assert the hook is removed by checking it in the manager's outgoing rules
|
||||
for _, arr := range manager.outgoingRules {
|
||||
for _, rule := range arr {
|
||||
if rule.id == hookID {
|
||||
t.Fatalf("The hook was not removed properly.")
|
||||
}
|
||||
}
|
||||
}
|
||||
manager.SetUDPPacketHook(netip.MustParseAddr("192.168.0.1"), 8080, nil)
|
||||
assert.Nil(t, manager.udpHookOut.Load(), "hook should be removed")
|
||||
}
|
||||
|
||||
func TestProcessOutgoingHooks(t *testing.T) {
|
||||
@@ -592,8 +537,7 @@ func TestProcessOutgoingHooks(t *testing.T) {
|
||||
}
|
||||
|
||||
hookCalled := false
|
||||
hookID := manager.AddUDPPacketHook(
|
||||
false,
|
||||
manager.SetUDPPacketHook(
|
||||
netip.MustParseAddr("100.10.0.100"),
|
||||
53,
|
||||
func([]byte) bool {
|
||||
@@ -601,7 +545,6 @@ func TestProcessOutgoingHooks(t *testing.T) {
|
||||
return true
|
||||
},
|
||||
)
|
||||
require.NotEmpty(t, hookID)
|
||||
|
||||
// Create test UDP packet
|
||||
ipv4 := &layers.IPv4{
|
||||
|
||||
@@ -144,6 +144,8 @@ func (m *localIPManager) UpdateLocalIPs(iface common.IFaceMapper) (err error) {
|
||||
if err != nil {
|
||||
log.Warnf("failed to get interfaces: %v", err)
|
||||
} else {
|
||||
// TODO: filter out down interfaces (net.FlagUp). Also handle the reverse
|
||||
// case where an interface comes up between refreshes.
|
||||
for _, intf := range interfaces {
|
||||
m.processInterface(intf, &newIPv4Bitmap, ipv4Set, &ipv4Addresses)
|
||||
}
|
||||
|
||||
@@ -421,6 +421,7 @@ func (m *Manager) addPortRedirection(targetIP netip.Addr, protocol gopacket.Laye
|
||||
}
|
||||
|
||||
// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services.
|
||||
// TODO: also delegate to nativeFirewall when available for kernel WG mode
|
||||
func (m *Manager) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
|
||||
var layerType gopacket.LayerType
|
||||
switch protocol {
|
||||
@@ -466,6 +467,22 @@ func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Prot
|
||||
return m.removePortRedirection(localAddr, layerType, sourcePort, targetPort)
|
||||
}
|
||||
|
||||
// AddOutputDNAT delegates to the native firewall if available.
|
||||
func (m *Manager) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
|
||||
if m.nativeFirewall == nil {
|
||||
return fmt.Errorf("output DNAT not supported without native firewall")
|
||||
}
|
||||
return m.nativeFirewall.AddOutputDNAT(localAddr, protocol, sourcePort, targetPort)
|
||||
}
|
||||
|
||||
// RemoveOutputDNAT delegates to the native firewall if available.
|
||||
func (m *Manager) RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
|
||||
if m.nativeFirewall == nil {
|
||||
return nil
|
||||
}
|
||||
return m.nativeFirewall.RemoveOutputDNAT(localAddr, protocol, sourcePort, targetPort)
|
||||
}
|
||||
|
||||
// translateInboundPortDNAT applies port-specific DNAT translation to inbound packets.
|
||||
func (m *Manager) translateInboundPortDNAT(packetData []byte, d *decoder, srcIP, dstIP netip.Addr) bool {
|
||||
if !m.portDNATEnabled.Load() {
|
||||
|
||||
@@ -18,9 +18,7 @@ type PeerRule struct {
|
||||
protoLayer gopacket.LayerType
|
||||
sPort *firewall.Port
|
||||
dPort *firewall.Port
|
||||
drop bool
|
||||
|
||||
udpHook func([]byte) bool
|
||||
drop bool
|
||||
}
|
||||
|
||||
// ID returns the rule id
|
||||
|
||||
@@ -399,21 +399,17 @@ func TestTracePacket(t *testing.T) {
|
||||
{
|
||||
name: "UDPTraffic_WithHook",
|
||||
setup: func(m *Manager) {
|
||||
hookFunc := func([]byte) bool {
|
||||
return true
|
||||
}
|
||||
m.AddUDPPacketHook(true, netip.MustParseAddr("1.1.1.1"), 53, hookFunc)
|
||||
m.SetUDPPacketHook(netip.MustParseAddr("100.10.255.254"), 53, func([]byte) bool {
|
||||
return true // drop (intercepted by hook)
|
||||
})
|
||||
},
|
||||
packetBuilder: func() *PacketBuilder {
|
||||
return createPacketBuilder("1.1.1.1", "100.10.0.100", "udp", 12345, 53, fw.RuleDirectionIN)
|
||||
return createPacketBuilder("100.10.0.100", "100.10.255.254", "udp", 12345, 53, fw.RuleDirectionOUT)
|
||||
},
|
||||
expectedStages: []PacketStage{
|
||||
StageReceived,
|
||||
StageInboundPortDNAT,
|
||||
StageInbound1to1NAT,
|
||||
StageConntrack,
|
||||
StageRouting,
|
||||
StagePeerACL,
|
||||
StageOutbound1to1NAT,
|
||||
StageOutboundPortReverse,
|
||||
StageCompleted,
|
||||
},
|
||||
expectedAllow: false,
|
||||
|
||||
@@ -15,14 +15,17 @@ type PacketFilter interface {
|
||||
// FilterInbound filter incoming packets from external sources to host
|
||||
FilterInbound(packetData []byte, size int) bool
|
||||
|
||||
// AddUDPPacketHook calls hook when UDP packet from given direction matched
|
||||
//
|
||||
// Hook function returns flag which indicates should be the matched package dropped or not.
|
||||
// Hook function receives raw network packet data as argument.
|
||||
AddUDPPacketHook(in bool, ip netip.Addr, dPort uint16, hook func(packet []byte) bool) string
|
||||
// SetUDPPacketHook registers a hook for outbound UDP packets matching the given IP and port.
|
||||
// Hook function returns true if the packet should be dropped.
|
||||
// Only one UDP hook is supported; calling again replaces the previous hook.
|
||||
// Pass nil hook to remove.
|
||||
SetUDPPacketHook(ip netip.Addr, dPort uint16, hook func(packet []byte) bool)
|
||||
|
||||
// RemovePacketHook removes hook by ID
|
||||
RemovePacketHook(hookID string) error
|
||||
// SetTCPPacketHook registers a hook for outbound TCP packets matching the given IP and port.
|
||||
// Hook function returns true if the packet should be dropped.
|
||||
// Only one TCP hook is supported; calling again replaces the previous hook.
|
||||
// Pass nil hook to remove.
|
||||
SetTCPPacketHook(ip netip.Addr, dPort uint16, hook func(packet []byte) bool)
|
||||
}
|
||||
|
||||
// FilteredDevice to override Read or Write of packets
|
||||
|
||||
@@ -34,18 +34,28 @@ func (m *MockPacketFilter) EXPECT() *MockPacketFilterMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// AddUDPPacketHook mocks base method.
|
||||
func (m *MockPacketFilter) AddUDPPacketHook(arg0 bool, arg1 netip.Addr, arg2 uint16, arg3 func([]byte) bool) string {
|
||||
// SetUDPPacketHook mocks base method.
|
||||
func (m *MockPacketFilter) SetUDPPacketHook(arg0 netip.Addr, arg1 uint16, arg2 func([]byte) bool) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "AddUDPPacketHook", arg0, arg1, arg2, arg3)
|
||||
ret0, _ := ret[0].(string)
|
||||
return ret0
|
||||
m.ctrl.Call(m, "SetUDPPacketHook", arg0, arg1, arg2)
|
||||
}
|
||||
|
||||
// AddUDPPacketHook indicates an expected call of AddUDPPacketHook.
|
||||
func (mr *MockPacketFilterMockRecorder) AddUDPPacketHook(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
|
||||
// SetUDPPacketHook indicates an expected call of SetUDPPacketHook.
|
||||
func (mr *MockPacketFilterMockRecorder) SetUDPPacketHook(arg0, arg1, arg2 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddUDPPacketHook", reflect.TypeOf((*MockPacketFilter)(nil).AddUDPPacketHook), arg0, arg1, arg2, arg3)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetUDPPacketHook", reflect.TypeOf((*MockPacketFilter)(nil).SetUDPPacketHook), arg0, arg1, arg2)
|
||||
}
|
||||
|
||||
// SetTCPPacketHook mocks base method.
|
||||
func (m *MockPacketFilter) SetTCPPacketHook(arg0 netip.Addr, arg1 uint16, arg2 func([]byte) bool) {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "SetTCPPacketHook", arg0, arg1, arg2)
|
||||
}
|
||||
|
||||
// SetTCPPacketHook indicates an expected call of SetTCPPacketHook.
|
||||
func (mr *MockPacketFilterMockRecorder) SetTCPPacketHook(arg0, arg1, arg2 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetTCPPacketHook", reflect.TypeOf((*MockPacketFilter)(nil).SetTCPPacketHook), arg0, arg1, arg2)
|
||||
}
|
||||
|
||||
// FilterInbound mocks base method.
|
||||
@@ -75,17 +85,3 @@ func (mr *MockPacketFilterMockRecorder) FilterOutbound(arg0 interface{}, arg1 an
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FilterOutbound", reflect.TypeOf((*MockPacketFilter)(nil).FilterOutbound), arg0, arg1)
|
||||
}
|
||||
|
||||
// RemovePacketHook mocks base method.
|
||||
func (m *MockPacketFilter) RemovePacketHook(arg0 string) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "RemovePacketHook", arg0)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// RemovePacketHook indicates an expected call of RemovePacketHook.
|
||||
func (mr *MockPacketFilterMockRecorder) RemovePacketHook(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemovePacketHook", reflect.TypeOf((*MockPacketFilter)(nil).RemovePacketHook), arg0)
|
||||
}
|
||||
|
||||
@@ -1,87 +0,0 @@
|
||||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: github.com/netbirdio/netbird/client/iface (interfaces: PacketFilter)
|
||||
|
||||
// Package mocks is a generated GoMock package.
|
||||
package mocks
|
||||
|
||||
import (
|
||||
net "net"
|
||||
reflect "reflect"
|
||||
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
)
|
||||
|
||||
// MockPacketFilter is a mock of PacketFilter interface.
|
||||
type MockPacketFilter struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockPacketFilterMockRecorder
|
||||
}
|
||||
|
||||
// MockPacketFilterMockRecorder is the mock recorder for MockPacketFilter.
|
||||
type MockPacketFilterMockRecorder struct {
|
||||
mock *MockPacketFilter
|
||||
}
|
||||
|
||||
// NewMockPacketFilter creates a new mock instance.
|
||||
func NewMockPacketFilter(ctrl *gomock.Controller) *MockPacketFilter {
|
||||
mock := &MockPacketFilter{ctrl: ctrl}
|
||||
mock.recorder = &MockPacketFilterMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use.
|
||||
func (m *MockPacketFilter) EXPECT() *MockPacketFilterMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// AddUDPPacketHook mocks base method.
|
||||
func (m *MockPacketFilter) AddUDPPacketHook(arg0 bool, arg1 net.IP, arg2 uint16, arg3 func(*net.UDPAddr, []byte) bool) {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "AddUDPPacketHook", arg0, arg1, arg2, arg3)
|
||||
}
|
||||
|
||||
// AddUDPPacketHook indicates an expected call of AddUDPPacketHook.
|
||||
func (mr *MockPacketFilterMockRecorder) AddUDPPacketHook(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddUDPPacketHook", reflect.TypeOf((*MockPacketFilter)(nil).AddUDPPacketHook), arg0, arg1, arg2, arg3)
|
||||
}
|
||||
|
||||
// FilterInbound mocks base method.
|
||||
func (m *MockPacketFilter) FilterInbound(arg0 []byte) bool {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "FilterInbound", arg0)
|
||||
ret0, _ := ret[0].(bool)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// FilterInbound indicates an expected call of FilterInbound.
|
||||
func (mr *MockPacketFilterMockRecorder) FilterInbound(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FilterInbound", reflect.TypeOf((*MockPacketFilter)(nil).FilterInbound), arg0)
|
||||
}
|
||||
|
||||
// FilterOutbound mocks base method.
|
||||
func (m *MockPacketFilter) FilterOutbound(arg0 []byte) bool {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "FilterOutbound", arg0)
|
||||
ret0, _ := ret[0].(bool)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// FilterOutbound indicates an expected call of FilterOutbound.
|
||||
func (mr *MockPacketFilterMockRecorder) FilterOutbound(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FilterOutbound", reflect.TypeOf((*MockPacketFilter)(nil).FilterOutbound), arg0)
|
||||
}
|
||||
|
||||
// SetNetwork mocks base method.
|
||||
func (m *MockPacketFilter) SetNetwork(arg0 *net.IPNet) {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "SetNetwork", arg0)
|
||||
}
|
||||
|
||||
// SetNetwork indicates an expected call of SetNetwork.
|
||||
func (mr *MockPacketFilterMockRecorder) SetNetwork(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetNetwork", reflect.TypeOf((*MockPacketFilter)(nil).SetNetwork), arg0)
|
||||
}
|
||||
@@ -155,7 +155,7 @@ func (a *Auth) IsLoginRequired(ctx context.Context) (bool, error) {
|
||||
var needsLogin bool
|
||||
|
||||
err = a.withRetry(ctx, func(client *mgm.GrpcClient) error {
|
||||
_, _, err := a.doMgmLogin(client, ctx, pubSSHKey)
|
||||
err := a.doMgmLogin(client, ctx, pubSSHKey)
|
||||
if isLoginNeeded(err) {
|
||||
needsLogin = true
|
||||
return nil
|
||||
@@ -179,8 +179,8 @@ func (a *Auth) Login(ctx context.Context, setupKey string, jwtToken string) (err
|
||||
var isAuthError bool
|
||||
|
||||
err = a.withRetry(ctx, func(client *mgm.GrpcClient) error {
|
||||
serverKey, _, err := a.doMgmLogin(client, ctx, pubSSHKey)
|
||||
if serverKey != nil && isRegistrationNeeded(err) {
|
||||
err := a.doMgmLogin(client, ctx, pubSSHKey)
|
||||
if isRegistrationNeeded(err) {
|
||||
log.Debugf("peer registration required")
|
||||
_, err = a.registerPeer(client, ctx, setupKey, jwtToken, pubSSHKey)
|
||||
if err != nil {
|
||||
@@ -201,13 +201,7 @@ func (a *Auth) Login(ctx context.Context, setupKey string, jwtToken string) (err
|
||||
|
||||
// getPKCEFlow retrieves PKCE authorization flow configuration and creates a flow instance
|
||||
func (a *Auth) getPKCEFlow(client *mgm.GrpcClient) (*PKCEAuthorizationFlow, error) {
|
||||
serverKey, err := client.GetServerPublicKey()
|
||||
if err != nil {
|
||||
log.Errorf("failed while getting Management Service public key: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
protoFlow, err := client.GetPKCEAuthorizationFlow(*serverKey)
|
||||
protoFlow, err := client.GetPKCEAuthorizationFlow()
|
||||
if err != nil {
|
||||
if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound {
|
||||
log.Warnf("server couldn't find pkce flow, contact admin: %v", err)
|
||||
@@ -246,13 +240,7 @@ func (a *Auth) getPKCEFlow(client *mgm.GrpcClient) (*PKCEAuthorizationFlow, erro
|
||||
|
||||
// getDeviceFlow retrieves device authorization flow configuration and creates a flow instance
|
||||
func (a *Auth) getDeviceFlow(client *mgm.GrpcClient) (*DeviceAuthorizationFlow, error) {
|
||||
serverKey, err := client.GetServerPublicKey()
|
||||
if err != nil {
|
||||
log.Errorf("failed while getting Management Service public key: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
protoFlow, err := client.GetDeviceAuthorizationFlow(*serverKey)
|
||||
protoFlow, err := client.GetDeviceAuthorizationFlow()
|
||||
if err != nil {
|
||||
if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound {
|
||||
log.Warnf("server couldn't find device flow, contact admin: %v", err)
|
||||
@@ -292,28 +280,16 @@ func (a *Auth) getDeviceFlow(client *mgm.GrpcClient) (*DeviceAuthorizationFlow,
|
||||
}
|
||||
|
||||
// doMgmLogin performs the actual login operation with the management service
|
||||
func (a *Auth) doMgmLogin(client *mgm.GrpcClient, ctx context.Context, pubSSHKey []byte) (*wgtypes.Key, *mgmProto.LoginResponse, error) {
|
||||
serverKey, err := client.GetServerPublicKey()
|
||||
if err != nil {
|
||||
log.Errorf("failed while getting Management Service public key: %v", err)
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
func (a *Auth) doMgmLogin(client *mgm.GrpcClient, ctx context.Context, pubSSHKey []byte) error {
|
||||
sysInfo := system.GetInfo(ctx)
|
||||
a.setSystemInfoFlags(sysInfo)
|
||||
loginResp, err := client.Login(*serverKey, sysInfo, pubSSHKey, a.config.DNSLabels)
|
||||
return serverKey, loginResp, err
|
||||
_, err := client.Login(sysInfo, pubSSHKey, a.config.DNSLabels)
|
||||
return err
|
||||
}
|
||||
|
||||
// registerPeer checks whether setupKey was provided via cmd line and if not then it prompts user to enter a key.
|
||||
// Otherwise tries to register with the provided setupKey via command line.
|
||||
func (a *Auth) registerPeer(client *mgm.GrpcClient, ctx context.Context, setupKey string, jwtToken string, pubSSHKey []byte) (*mgmProto.LoginResponse, error) {
|
||||
serverPublicKey, err := client.GetServerPublicKey()
|
||||
if err != nil {
|
||||
log.Errorf("failed while getting Management Service public key: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
validSetupKey, err := uuid.Parse(setupKey)
|
||||
if err != nil && jwtToken == "" {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid setup-key or no sso information provided, err: %v", err)
|
||||
@@ -322,7 +298,7 @@ func (a *Auth) registerPeer(client *mgm.GrpcClient, ctx context.Context, setupKe
|
||||
log.Debugf("sending peer registration request to Management Service")
|
||||
info := system.GetInfo(ctx)
|
||||
a.setSystemInfoFlags(info)
|
||||
loginResp, err := client.Register(*serverPublicKey, validSetupKey.String(), jwtToken, info, pubSSHKey, a.config.DNSLabels)
|
||||
loginResp, err := client.Register(validSetupKey.String(), jwtToken, info, pubSSHKey, a.config.DNSLabels)
|
||||
if err != nil {
|
||||
log.Errorf("failed registering peer %v", err)
|
||||
return nil, err
|
||||
|
||||
@@ -111,6 +111,7 @@ func (c *ConnectClient) RunOniOS(
|
||||
fileDescriptor int32,
|
||||
networkChangeListener listener.NetworkChangeListener,
|
||||
dnsManager dns.IosDnsManager,
|
||||
dnsAddresses []netip.AddrPort,
|
||||
stateFilePath string,
|
||||
) error {
|
||||
// Set GC percent to 5% to reduce memory usage as iOS only allows 50MB of memory for the extension.
|
||||
@@ -120,6 +121,7 @@ func (c *ConnectClient) RunOniOS(
|
||||
FileDescriptor: fileDescriptor,
|
||||
NetworkChangeListener: networkChangeListener,
|
||||
DnsManager: dnsManager,
|
||||
HostDNSAddresses: dnsAddresses,
|
||||
StateFilePath: stateFilePath,
|
||||
}
|
||||
return c.run(mobileDependency, nil, "")
|
||||
@@ -617,12 +619,6 @@ func connectToSignal(ctx context.Context, wtConfig *mgmProto.NetbirdConfig, ourP
|
||||
|
||||
// loginToManagement creates Management ServiceDependencies client, establishes a connection, logs-in and gets a global Netbird config (signal, turn, stun hosts, etc)
|
||||
func loginToManagement(ctx context.Context, client mgm.Client, pubSSHKey []byte, config *profilemanager.Config) (*mgmProto.LoginResponse, error) {
|
||||
|
||||
serverPublicKey, err := client.GetServerPublicKey()
|
||||
if err != nil {
|
||||
return nil, gstatus.Errorf(codes.FailedPrecondition, "failed while getting Management Service public key: %s", err)
|
||||
}
|
||||
|
||||
sysInfo := system.GetInfo(ctx)
|
||||
sysInfo.SetFlags(
|
||||
config.RosenpassEnabled,
|
||||
@@ -641,12 +637,7 @@ func loginToManagement(ctx context.Context, client mgm.Client, pubSSHKey []byte,
|
||||
config.EnableSSHRemotePortForwarding,
|
||||
config.DisableSSHAuth,
|
||||
)
|
||||
loginResp, err := client.Login(*serverPublicKey, sysInfo, pubSSHKey, config.DNSLabels)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return loginResp, nil
|
||||
return client.Login(sysInfo, pubSSHKey, config.DNSLabels)
|
||||
}
|
||||
|
||||
func statusRecorderToMgmConnStateNotifier(statusRecorder *peer.Status) mgm.ConnStateNotifier {
|
||||
|
||||
@@ -25,6 +25,7 @@ import (
|
||||
"google.golang.org/protobuf/encoding/protojson"
|
||||
|
||||
"github.com/netbirdio/netbird/client/anonymize"
|
||||
"github.com/netbirdio/netbird/client/configs"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
"github.com/netbirdio/netbird/client/internal/updater/installer"
|
||||
@@ -52,6 +53,7 @@ resolved_domains.txt: Anonymized resolved domain IP addresses from the status re
|
||||
config.txt: Anonymized configuration information of the NetBird client.
|
||||
network_map.json: Anonymized sync response containing peer configurations, routes, DNS settings, and firewall rules.
|
||||
state.json: Anonymized client state dump containing netbird states for the active profile.
|
||||
service_params.json: Sanitized service install parameters (service.json). Sensitive environment variable values are masked. Only present when service.json exists.
|
||||
metrics.txt: Buffered client metrics in InfluxDB line protocol format. Only present when metrics collection is enabled. Peer identifiers are anonymized.
|
||||
mutex.prof: Mutex profiling information.
|
||||
goroutine.prof: Goroutine profiling information.
|
||||
@@ -359,6 +361,10 @@ func (g *BundleGenerator) createArchive() error {
|
||||
log.Errorf("failed to add corrupted state files to debug bundle: %v", err)
|
||||
}
|
||||
|
||||
if err := g.addServiceParams(); err != nil {
|
||||
log.Errorf("failed to add service params to debug bundle: %v", err)
|
||||
}
|
||||
|
||||
if err := g.addMetrics(); err != nil {
|
||||
log.Errorf("failed to add metrics to debug bundle: %v", err)
|
||||
}
|
||||
@@ -488,6 +494,90 @@ func (g *BundleGenerator) addConfig() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
const (
|
||||
serviceParamsFile = "service.json"
|
||||
serviceParamsBundle = "service_params.json"
|
||||
maskedValue = "***"
|
||||
envVarPrefix = "NB_"
|
||||
jsonKeyManagementURL = "management_url"
|
||||
jsonKeyServiceEnv = "service_env_vars"
|
||||
)
|
||||
|
||||
var sensitiveEnvSubstrings = []string{"key", "token", "secret", "password", "credential"}
|
||||
|
||||
// addServiceParams reads the service.json file and adds a sanitized version to the bundle.
|
||||
// Non-NB_ env vars and vars with sensitive names are masked. Other NB_ values are anonymized.
|
||||
func (g *BundleGenerator) addServiceParams() error {
|
||||
path := filepath.Join(configs.StateDir, serviceParamsFile)
|
||||
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("read service params: %w", err)
|
||||
}
|
||||
|
||||
var params map[string]any
|
||||
if err := json.Unmarshal(data, ¶ms); err != nil {
|
||||
return fmt.Errorf("parse service params: %w", err)
|
||||
}
|
||||
|
||||
if g.anonymize {
|
||||
if mgmtURL, ok := params[jsonKeyManagementURL].(string); ok && mgmtURL != "" {
|
||||
params[jsonKeyManagementURL] = g.anonymizer.AnonymizeURI(mgmtURL)
|
||||
}
|
||||
}
|
||||
|
||||
g.sanitizeServiceEnvVars(params)
|
||||
|
||||
sanitizedData, err := json.MarshalIndent(params, "", " ")
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal sanitized service params: %w", err)
|
||||
}
|
||||
|
||||
if err := g.addFileToZip(bytes.NewReader(sanitizedData), serviceParamsBundle); err != nil {
|
||||
return fmt.Errorf("add service params to zip: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// sanitizeServiceEnvVars masks or anonymizes env var values in service params.
|
||||
// Non-NB_ vars and vars with sensitive names (key, token, etc.) are fully masked.
|
||||
// Other NB_ var values are passed through the anonymizer when anonymization is enabled.
|
||||
func (g *BundleGenerator) sanitizeServiceEnvVars(params map[string]any) {
|
||||
envVars, ok := params[jsonKeyServiceEnv].(map[string]any)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
sanitized := make(map[string]any, len(envVars))
|
||||
for k, v := range envVars {
|
||||
val, _ := v.(string)
|
||||
switch {
|
||||
case !strings.HasPrefix(k, envVarPrefix) || isSensitiveEnvVar(k):
|
||||
sanitized[k] = maskedValue
|
||||
case g.anonymize:
|
||||
sanitized[k] = g.anonymizer.AnonymizeString(val)
|
||||
default:
|
||||
sanitized[k] = val
|
||||
}
|
||||
}
|
||||
params[jsonKeyServiceEnv] = sanitized
|
||||
}
|
||||
|
||||
// isSensitiveEnvVar returns true for env var names that may contain secrets.
|
||||
func isSensitiveEnvVar(key string) bool {
|
||||
lower := strings.ToLower(key)
|
||||
for _, s := range sensitiveEnvSubstrings {
|
||||
if strings.Contains(lower, s) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (g *BundleGenerator) addCommonConfigFields(configContent *strings.Builder) {
|
||||
configContent.WriteString("NetBird Client Configuration:\n\n")
|
||||
|
||||
|
||||
@@ -1,8 +1,12 @@
|
||||
package debug
|
||||
|
||||
import (
|
||||
"archive/zip"
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"net"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
@@ -10,6 +14,7 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/client/anonymize"
|
||||
"github.com/netbirdio/netbird/client/configs"
|
||||
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
@@ -420,6 +425,226 @@ func TestAnonymizeNetworkMap(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsSensitiveEnvVar(t *testing.T) {
|
||||
tests := []struct {
|
||||
key string
|
||||
sensitive bool
|
||||
}{
|
||||
{"NB_SETUP_KEY", true},
|
||||
{"NB_API_TOKEN", true},
|
||||
{"NB_CLIENT_SECRET", true},
|
||||
{"NB_PASSWORD", true},
|
||||
{"NB_CREDENTIAL", true},
|
||||
{"NB_LOG_LEVEL", false},
|
||||
{"NB_MANAGEMENT_URL", false},
|
||||
{"NB_HOSTNAME", false},
|
||||
{"HOME", false},
|
||||
{"PATH", false},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.key, func(t *testing.T) {
|
||||
assert.Equal(t, tt.sensitive, isSensitiveEnvVar(tt.key))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSanitizeServiceEnvVars(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
anonymize bool
|
||||
input map[string]any
|
||||
check func(t *testing.T, params map[string]any)
|
||||
}{
|
||||
{
|
||||
name: "no env vars key",
|
||||
anonymize: false,
|
||||
input: map[string]any{"management_url": "https://mgmt.example.com"},
|
||||
check: func(t *testing.T, params map[string]any) {
|
||||
t.Helper()
|
||||
assert.Equal(t, "https://mgmt.example.com", params["management_url"], "non-env fields should be untouched")
|
||||
_, ok := params[jsonKeyServiceEnv]
|
||||
assert.False(t, ok, "service_env_vars should not be added")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "non-NB vars are masked",
|
||||
anonymize: false,
|
||||
input: map[string]any{
|
||||
jsonKeyServiceEnv: map[string]any{
|
||||
"HOME": "/root",
|
||||
"PATH": "/usr/bin",
|
||||
"NB_LOG_LEVEL": "debug",
|
||||
},
|
||||
},
|
||||
check: func(t *testing.T, params map[string]any) {
|
||||
t.Helper()
|
||||
env := params[jsonKeyServiceEnv].(map[string]any)
|
||||
assert.Equal(t, maskedValue, env["HOME"], "non-NB_ var should be masked")
|
||||
assert.Equal(t, maskedValue, env["PATH"], "non-NB_ var should be masked")
|
||||
assert.Equal(t, "debug", env["NB_LOG_LEVEL"], "safe NB_ var should pass through")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "sensitive NB vars are masked",
|
||||
anonymize: false,
|
||||
input: map[string]any{
|
||||
jsonKeyServiceEnv: map[string]any{
|
||||
"NB_SETUP_KEY": "abc123",
|
||||
"NB_API_TOKEN": "tok_xyz",
|
||||
"NB_LOG_LEVEL": "info",
|
||||
},
|
||||
},
|
||||
check: func(t *testing.T, params map[string]any) {
|
||||
t.Helper()
|
||||
env := params[jsonKeyServiceEnv].(map[string]any)
|
||||
assert.Equal(t, maskedValue, env["NB_SETUP_KEY"], "sensitive NB_ var should be masked")
|
||||
assert.Equal(t, maskedValue, env["NB_API_TOKEN"], "sensitive NB_ var should be masked")
|
||||
assert.Equal(t, "info", env["NB_LOG_LEVEL"], "safe NB_ var should pass through")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "safe NB vars anonymized when anonymize is true",
|
||||
anonymize: true,
|
||||
input: map[string]any{
|
||||
jsonKeyServiceEnv: map[string]any{
|
||||
"NB_MANAGEMENT_URL": "https://mgmt.example.com:443",
|
||||
"NB_LOG_LEVEL": "debug",
|
||||
"NB_SETUP_KEY": "secret",
|
||||
"SOME_OTHER": "val",
|
||||
},
|
||||
},
|
||||
check: func(t *testing.T, params map[string]any) {
|
||||
t.Helper()
|
||||
env := params[jsonKeyServiceEnv].(map[string]any)
|
||||
// Safe NB_ values should be anonymized (not the original, not masked)
|
||||
mgmtVal := env["NB_MANAGEMENT_URL"].(string)
|
||||
assert.NotEqual(t, "https://mgmt.example.com:443", mgmtVal, "should be anonymized")
|
||||
assert.NotEqual(t, maskedValue, mgmtVal, "should not be masked")
|
||||
|
||||
logVal := env["NB_LOG_LEVEL"].(string)
|
||||
assert.NotEqual(t, maskedValue, logVal, "safe NB_ var should not be masked")
|
||||
|
||||
// Sensitive and non-NB_ still masked
|
||||
assert.Equal(t, maskedValue, env["NB_SETUP_KEY"])
|
||||
assert.Equal(t, maskedValue, env["SOME_OTHER"])
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
anonymizer := anonymize.NewAnonymizer(anonymize.DefaultAddresses())
|
||||
g := &BundleGenerator{
|
||||
anonymize: tt.anonymize,
|
||||
anonymizer: anonymizer,
|
||||
}
|
||||
g.sanitizeServiceEnvVars(tt.input)
|
||||
tt.check(t, tt.input)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAddServiceParams(t *testing.T) {
|
||||
t.Run("missing service.json returns nil", func(t *testing.T) {
|
||||
g := &BundleGenerator{
|
||||
anonymizer: anonymize.NewAnonymizer(anonymize.DefaultAddresses()),
|
||||
}
|
||||
|
||||
origStateDir := configs.StateDir
|
||||
configs.StateDir = t.TempDir()
|
||||
t.Cleanup(func() { configs.StateDir = origStateDir })
|
||||
|
||||
err := g.addServiceParams()
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("management_url anonymized when anonymize is true", func(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
origStateDir := configs.StateDir
|
||||
configs.StateDir = dir
|
||||
t.Cleanup(func() { configs.StateDir = origStateDir })
|
||||
|
||||
input := map[string]any{
|
||||
jsonKeyManagementURL: "https://api.example.com:443",
|
||||
jsonKeyServiceEnv: map[string]any{
|
||||
"NB_LOG_LEVEL": "trace",
|
||||
},
|
||||
}
|
||||
data, err := json.Marshal(input)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, os.WriteFile(filepath.Join(dir, serviceParamsFile), data, 0600))
|
||||
|
||||
var buf bytes.Buffer
|
||||
zw := zip.NewWriter(&buf)
|
||||
|
||||
g := &BundleGenerator{
|
||||
anonymize: true,
|
||||
anonymizer: anonymize.NewAnonymizer(anonymize.DefaultAddresses()),
|
||||
archive: zw,
|
||||
}
|
||||
|
||||
require.NoError(t, g.addServiceParams())
|
||||
require.NoError(t, zw.Close())
|
||||
|
||||
zr, err := zip.NewReader(bytes.NewReader(buf.Bytes()), int64(buf.Len()))
|
||||
require.NoError(t, err)
|
||||
require.Len(t, zr.File, 1)
|
||||
assert.Equal(t, serviceParamsBundle, zr.File[0].Name)
|
||||
|
||||
rc, err := zr.File[0].Open()
|
||||
require.NoError(t, err)
|
||||
defer rc.Close()
|
||||
|
||||
var result map[string]any
|
||||
require.NoError(t, json.NewDecoder(rc).Decode(&result))
|
||||
|
||||
mgmt := result[jsonKeyManagementURL].(string)
|
||||
assert.NotEqual(t, "https://api.example.com:443", mgmt, "management_url should be anonymized")
|
||||
assert.NotEmpty(t, mgmt)
|
||||
|
||||
env := result[jsonKeyServiceEnv].(map[string]any)
|
||||
assert.NotEqual(t, maskedValue, env["NB_LOG_LEVEL"], "safe NB_ var should not be masked")
|
||||
})
|
||||
|
||||
t.Run("management_url preserved when anonymize is false", func(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
origStateDir := configs.StateDir
|
||||
configs.StateDir = dir
|
||||
t.Cleanup(func() { configs.StateDir = origStateDir })
|
||||
|
||||
input := map[string]any{
|
||||
jsonKeyManagementURL: "https://api.example.com:443",
|
||||
}
|
||||
data, err := json.Marshal(input)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, os.WriteFile(filepath.Join(dir, serviceParamsFile), data, 0600))
|
||||
|
||||
var buf bytes.Buffer
|
||||
zw := zip.NewWriter(&buf)
|
||||
|
||||
g := &BundleGenerator{
|
||||
anonymize: false,
|
||||
anonymizer: anonymize.NewAnonymizer(anonymize.DefaultAddresses()),
|
||||
archive: zw,
|
||||
}
|
||||
|
||||
require.NoError(t, g.addServiceParams())
|
||||
require.NoError(t, zw.Close())
|
||||
|
||||
zr, err := zip.NewReader(bytes.NewReader(buf.Bytes()), int64(buf.Len()))
|
||||
require.NoError(t, err)
|
||||
|
||||
rc, err := zr.File[0].Open()
|
||||
require.NoError(t, err)
|
||||
defer rc.Close()
|
||||
|
||||
var result map[string]any
|
||||
require.NoError(t, json.NewDecoder(rc).Decode(&result))
|
||||
|
||||
assert.Equal(t, "https://api.example.com:443", result[jsonKeyManagementURL], "management_url should be preserved")
|
||||
})
|
||||
}
|
||||
|
||||
// Helper function to check if IP is in CGNAT range
|
||||
func isInCGNATRange(ip net.IP) bool {
|
||||
cgnat := net.IPNet{
|
||||
|
||||
@@ -73,6 +73,9 @@ func (w *ResponseWriterChain) WriteMsg(m *dns.Msg) error {
|
||||
return nil
|
||||
}
|
||||
w.response = m
|
||||
if m.MsgHdr.Truncated {
|
||||
w.SetMeta("truncated", "true")
|
||||
}
|
||||
return w.ResponseWriter.WriteMsg(m)
|
||||
}
|
||||
|
||||
@@ -195,10 +198,14 @@ func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
|
||||
startTime := time.Now()
|
||||
requestID := resutil.GenerateRequestID()
|
||||
logger := log.WithFields(log.Fields{
|
||||
fields := log.Fields{
|
||||
"request_id": requestID,
|
||||
"dns_id": fmt.Sprintf("%04x", r.Id),
|
||||
})
|
||||
}
|
||||
if addr := w.RemoteAddr(); addr != nil {
|
||||
fields["client"] = addr.String()
|
||||
}
|
||||
logger := log.WithFields(fields)
|
||||
|
||||
question := r.Question[0]
|
||||
qname := strings.ToLower(question.Name)
|
||||
@@ -261,9 +268,9 @@ func (c *HandlerChain) logResponse(logger *log.Entry, cw *ResponseWriterChain, q
|
||||
meta += " " + k + "=" + v
|
||||
}
|
||||
|
||||
logger.Tracef("response: domain=%s rcode=%s answers=%s%s took=%s",
|
||||
logger.Tracef("response: domain=%s rcode=%s answers=%s size=%dB%s took=%s",
|
||||
qname, dns.RcodeToString[cw.response.Rcode], resutil.FormatAnswers(cw.response.Answer),
|
||||
meta, time.Since(startTime))
|
||||
cw.response.Len(), meta, time.Since(startTime))
|
||||
}
|
||||
|
||||
func (c *HandlerChain) isHandlerMatch(qname string, entry HandlerEntry) bool {
|
||||
|
||||
@@ -1263,9 +1263,9 @@ func TestLocalResolver_AuthoritativeFlag(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
// TestLocalResolver_Stop tests cleanup on Stop
|
||||
// TestLocalResolver_Stop tests cleanup on GracefullyStop
|
||||
func TestLocalResolver_Stop(t *testing.T) {
|
||||
t.Run("Stop clears all state", func(t *testing.T) {
|
||||
t.Run("GracefullyStop clears all state", func(t *testing.T) {
|
||||
resolver := NewResolver()
|
||||
resolver.Update([]nbdns.CustomZone{{
|
||||
Domain: "example.com.",
|
||||
@@ -1285,7 +1285,7 @@ func TestLocalResolver_Stop(t *testing.T) {
|
||||
assert.False(t, resolver.isInManagedZone("host.example.com."))
|
||||
})
|
||||
|
||||
t.Run("Stop is safe to call multiple times", func(t *testing.T) {
|
||||
t.Run("GracefullyStop is safe to call multiple times", func(t *testing.T) {
|
||||
resolver := NewResolver()
|
||||
resolver.Update([]nbdns.CustomZone{{
|
||||
Domain: "example.com.",
|
||||
@@ -1299,7 +1299,7 @@ func TestLocalResolver_Stop(t *testing.T) {
|
||||
resolver.Stop()
|
||||
})
|
||||
|
||||
t.Run("Stop cancels in-flight external resolution", func(t *testing.T) {
|
||||
t.Run("GracefullyStop cancels in-flight external resolution", func(t *testing.T) {
|
||||
resolver := NewResolver()
|
||||
|
||||
lookupStarted := make(chan struct{})
|
||||
|
||||
@@ -90,6 +90,11 @@ func (m *MockServer) SetRouteChecker(func(netip.Addr) bool) {
|
||||
// Mock implementation - no-op
|
||||
}
|
||||
|
||||
// SetFirewall mock implementation of SetFirewall from Server interface
|
||||
func (m *MockServer) SetFirewall(Firewall) {
|
||||
// Mock implementation - no-op
|
||||
}
|
||||
|
||||
// BeginBatch mock implementation of BeginBatch from Server interface
|
||||
func (m *MockServer) BeginBatch() {
|
||||
// Mock implementation - no-op
|
||||
|
||||
@@ -104,3 +104,23 @@ func (r *responseWriter) TsigTimersOnly(bool) {
|
||||
// After a call to Hijack(), the DNS package will not do anything with the connection.
|
||||
func (r *responseWriter) Hijack() {
|
||||
}
|
||||
|
||||
// remoteAddrFromPacket extracts the source IP:port from a decoded packet for logging.
|
||||
func remoteAddrFromPacket(packet gopacket.Packet) *net.UDPAddr {
|
||||
var srcIP net.IP
|
||||
if ipv4 := packet.Layer(layers.LayerTypeIPv4); ipv4 != nil {
|
||||
srcIP = ipv4.(*layers.IPv4).SrcIP
|
||||
} else if ipv6 := packet.Layer(layers.LayerTypeIPv6); ipv6 != nil {
|
||||
srcIP = ipv6.(*layers.IPv6).SrcIP
|
||||
}
|
||||
|
||||
var srcPort int
|
||||
if udp := packet.Layer(layers.LayerTypeUDP); udp != nil {
|
||||
srcPort = int(udp.(*layers.UDP).SrcPort)
|
||||
}
|
||||
|
||||
if srcIP == nil {
|
||||
return nil
|
||||
}
|
||||
return &net.UDPAddr{IP: srcIP, Port: srcPort}
|
||||
}
|
||||
|
||||
@@ -58,6 +58,7 @@ type Server interface {
|
||||
UpdateServerConfig(domains dnsconfig.ServerDomains) error
|
||||
PopulateManagementDomain(mgmtURL *url.URL) error
|
||||
SetRouteChecker(func(netip.Addr) bool)
|
||||
SetFirewall(Firewall)
|
||||
}
|
||||
|
||||
type nsGroupsByDomain struct {
|
||||
@@ -151,7 +152,7 @@ func NewDefaultServer(ctx context.Context, config DefaultServerConfig) (*Default
|
||||
if config.WgInterface.IsUserspaceBind() {
|
||||
dnsService = NewServiceViaMemory(config.WgInterface)
|
||||
} else {
|
||||
dnsService = newServiceViaListener(config.WgInterface, addrPort)
|
||||
dnsService = newServiceViaListener(config.WgInterface, addrPort, nil)
|
||||
}
|
||||
|
||||
server := newDefaultServer(ctx, config.WgInterface, dnsService, config.StatusRecorder, config.StateManager, config.DisableSys)
|
||||
@@ -186,11 +187,16 @@ func NewDefaultServerIos(
|
||||
ctx context.Context,
|
||||
wgInterface WGIface,
|
||||
iosDnsManager IosDnsManager,
|
||||
hostsDnsList []netip.AddrPort,
|
||||
statusRecorder *peer.Status,
|
||||
disableSys bool,
|
||||
) *DefaultServer {
|
||||
log.Debugf("iOS host dns address list is: %v", hostsDnsList)
|
||||
ds := newDefaultServer(ctx, wgInterface, NewServiceViaMemory(wgInterface), statusRecorder, nil, disableSys)
|
||||
ds.iosDnsManager = iosDnsManager
|
||||
ds.hostsDNSHolder.set(hostsDnsList)
|
||||
ds.permanent = true
|
||||
ds.addHostRootZone()
|
||||
return ds
|
||||
}
|
||||
|
||||
@@ -374,6 +380,17 @@ func (s *DefaultServer) DnsIP() netip.Addr {
|
||||
return s.service.RuntimeIP()
|
||||
}
|
||||
|
||||
// SetFirewall sets the firewall used for DNS port DNAT rules.
|
||||
// This must be called before Initialize when using the listener-based service,
|
||||
// because the firewall is typically not available at construction time.
|
||||
func (s *DefaultServer) SetFirewall(fw Firewall) {
|
||||
if svc, ok := s.service.(*serviceViaListener); ok {
|
||||
svc.listenerFlagLock.Lock()
|
||||
svc.firewall = fw
|
||||
svc.listenerFlagLock.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
// Stop stops the server
|
||||
func (s *DefaultServer) Stop() {
|
||||
s.probeMu.Lock()
|
||||
@@ -395,8 +412,12 @@ func (s *DefaultServer) Stop() {
|
||||
maps.Clear(s.extraDomains)
|
||||
}
|
||||
|
||||
func (s *DefaultServer) disableDNS() error {
|
||||
defer s.service.Stop()
|
||||
func (s *DefaultServer) disableDNS() (retErr error) {
|
||||
defer func() {
|
||||
if err := s.service.Stop(); err != nil {
|
||||
retErr = errors.Join(retErr, fmt.Errorf("stop DNS service: %w", err))
|
||||
}
|
||||
}()
|
||||
|
||||
if s.isUsingNoopHostManager() {
|
||||
return nil
|
||||
|
||||
@@ -476,8 +476,8 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) {
|
||||
|
||||
packetfilter := pfmock.NewMockPacketFilter(ctrl)
|
||||
packetfilter.EXPECT().FilterOutbound(gomock.Any(), gomock.Any()).AnyTimes()
|
||||
packetfilter.EXPECT().AddUDPPacketHook(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
|
||||
packetfilter.EXPECT().RemovePacketHook(gomock.Any())
|
||||
packetfilter.EXPECT().SetUDPPacketHook(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
|
||||
packetfilter.EXPECT().SetTCPPacketHook(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
|
||||
|
||||
if err := wgIface.SetFilter(packetfilter); err != nil {
|
||||
t.Errorf("set packet filter: %v", err)
|
||||
@@ -1071,7 +1071,7 @@ func (m *mockHandler) ID() types.HandlerID { return types.Hand
|
||||
type mockService struct{}
|
||||
|
||||
func (m *mockService) Listen() error { return nil }
|
||||
func (m *mockService) Stop() {}
|
||||
func (m *mockService) Stop() error { return nil }
|
||||
func (m *mockService) RuntimeIP() netip.Addr { return netip.MustParseAddr("127.0.0.1") }
|
||||
func (m *mockService) RuntimePort() int { return 53 }
|
||||
func (m *mockService) RegisterMux(string, dns.Handler) {}
|
||||
|
||||
@@ -4,15 +4,25 @@ import (
|
||||
"net/netip"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
)
|
||||
|
||||
const (
|
||||
DefaultPort = 53
|
||||
)
|
||||
|
||||
// Firewall provides DNAT capabilities for DNS port redirection.
|
||||
// This is used when the DNS server cannot bind port 53 directly
|
||||
// and needs firewall rules to redirect traffic.
|
||||
type Firewall interface {
|
||||
AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error
|
||||
RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error
|
||||
}
|
||||
|
||||
type service interface {
|
||||
Listen() error
|
||||
Stop()
|
||||
Stop() error
|
||||
RegisterMux(domain string, handler dns.Handler)
|
||||
DeregisterMux(key string)
|
||||
RuntimePort() int
|
||||
|
||||
@@ -10,9 +10,13 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/go-multierror"
|
||||
"github.com/miekg/dns"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/internal/ebpf"
|
||||
ebpfMgr "github.com/netbirdio/netbird/client/internal/ebpf/manager"
|
||||
)
|
||||
@@ -31,25 +35,33 @@ type serviceViaListener struct {
|
||||
dnsMux *dns.ServeMux
|
||||
customAddr *netip.AddrPort
|
||||
server *dns.Server
|
||||
tcpServer *dns.Server
|
||||
listenIP netip.Addr
|
||||
listenPort uint16
|
||||
listenerIsRunning bool
|
||||
listenerFlagLock sync.Mutex
|
||||
ebpfService ebpfMgr.Manager
|
||||
firewall Firewall
|
||||
tcpDNATConfigured bool
|
||||
}
|
||||
|
||||
func newServiceViaListener(wgIface WGIface, customAddr *netip.AddrPort) *serviceViaListener {
|
||||
func newServiceViaListener(wgIface WGIface, customAddr *netip.AddrPort, fw Firewall) *serviceViaListener {
|
||||
mux := dns.NewServeMux()
|
||||
|
||||
s := &serviceViaListener{
|
||||
wgInterface: wgIface,
|
||||
dnsMux: mux,
|
||||
customAddr: customAddr,
|
||||
firewall: fw,
|
||||
server: &dns.Server{
|
||||
Net: "udp",
|
||||
Handler: mux,
|
||||
UDPSize: 65535,
|
||||
},
|
||||
tcpServer: &dns.Server{
|
||||
Net: "tcp",
|
||||
Handler: mux,
|
||||
},
|
||||
}
|
||||
|
||||
return s
|
||||
@@ -70,43 +82,86 @@ func (s *serviceViaListener) Listen() error {
|
||||
return fmt.Errorf("eval listen address: %w", err)
|
||||
}
|
||||
s.listenIP = s.listenIP.Unmap()
|
||||
s.server.Addr = net.JoinHostPort(s.listenIP.String(), strconv.Itoa(int(s.listenPort)))
|
||||
log.Debugf("starting dns on %s", s.server.Addr)
|
||||
go func() {
|
||||
s.setListenerStatus(true)
|
||||
defer s.setListenerStatus(false)
|
||||
addr := net.JoinHostPort(s.listenIP.String(), strconv.Itoa(int(s.listenPort)))
|
||||
s.server.Addr = addr
|
||||
s.tcpServer.Addr = addr
|
||||
|
||||
err := s.server.ListenAndServe()
|
||||
if err != nil {
|
||||
log.Errorf("dns server running with %d port returned an error: %v. Will not retry", s.listenPort, err)
|
||||
log.Debugf("starting dns on %s (UDP + TCP)", addr)
|
||||
s.listenerIsRunning = true
|
||||
|
||||
go func() {
|
||||
if err := s.server.ListenAndServe(); err != nil {
|
||||
log.Errorf("failed to run DNS UDP server on port %d: %v", s.listenPort, err)
|
||||
}
|
||||
|
||||
s.listenerFlagLock.Lock()
|
||||
unexpected := s.listenerIsRunning
|
||||
s.listenerIsRunning = false
|
||||
s.listenerFlagLock.Unlock()
|
||||
|
||||
if unexpected {
|
||||
if err := s.tcpServer.Shutdown(); err != nil {
|
||||
log.Debugf("failed to shutdown DNS TCP server: %v", err)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
go func() {
|
||||
if err := s.tcpServer.ListenAndServe(); err != nil {
|
||||
log.Errorf("failed to run DNS TCP server on port %d: %v", s.listenPort, err)
|
||||
}
|
||||
}()
|
||||
|
||||
// When eBPF redirects UDP port 53 to our listen port, TCP still needs
|
||||
// a DNAT rule because eBPF only handles UDP.
|
||||
if s.ebpfService != nil && s.firewall != nil && s.listenPort != DefaultPort {
|
||||
if err := s.firewall.AddOutputDNAT(s.listenIP, firewall.ProtocolTCP, DefaultPort, s.listenPort); err != nil {
|
||||
log.Warnf("failed to add DNS TCP DNAT rule, TCP DNS on port 53 will not work: %v", err)
|
||||
} else {
|
||||
s.tcpDNATConfigured = true
|
||||
log.Infof("added DNS TCP DNAT rule: %s:%d -> %s:%d", s.listenIP, DefaultPort, s.listenIP, s.listenPort)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *serviceViaListener) Stop() {
|
||||
func (s *serviceViaListener) Stop() error {
|
||||
s.listenerFlagLock.Lock()
|
||||
defer s.listenerFlagLock.Unlock()
|
||||
|
||||
if !s.listenerIsRunning {
|
||||
return
|
||||
return nil
|
||||
}
|
||||
s.listenerIsRunning = false
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
err := s.server.ShutdownContext(ctx)
|
||||
if err != nil {
|
||||
log.Errorf("stopping dns server listener returned an error: %v", err)
|
||||
var merr *multierror.Error
|
||||
|
||||
if err := s.server.ShutdownContext(ctx); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("stop DNS UDP server: %w", err))
|
||||
}
|
||||
|
||||
if err := s.tcpServer.ShutdownContext(ctx); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("stop DNS TCP server: %w", err))
|
||||
}
|
||||
|
||||
if s.tcpDNATConfigured && s.firewall != nil {
|
||||
if err := s.firewall.RemoveOutputDNAT(s.listenIP, firewall.ProtocolTCP, DefaultPort, s.listenPort); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove DNS TCP DNAT rule: %w", err))
|
||||
}
|
||||
s.tcpDNATConfigured = false
|
||||
}
|
||||
|
||||
if s.ebpfService != nil {
|
||||
err = s.ebpfService.FreeDNSFwd()
|
||||
if err != nil {
|
||||
log.Errorf("stopping traffic forwarder returned an error: %v", err)
|
||||
if err := s.ebpfService.FreeDNSFwd(); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("stop traffic forwarder: %w", err))
|
||||
}
|
||||
}
|
||||
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
func (s *serviceViaListener) RegisterMux(pattern string, handler dns.Handler) {
|
||||
@@ -133,12 +188,6 @@ func (s *serviceViaListener) RuntimeIP() netip.Addr {
|
||||
return s.listenIP
|
||||
}
|
||||
|
||||
func (s *serviceViaListener) setListenerStatus(running bool) {
|
||||
s.listenerFlagLock.Lock()
|
||||
defer s.listenerFlagLock.Unlock()
|
||||
|
||||
s.listenerIsRunning = running
|
||||
}
|
||||
|
||||
// evalListenAddress figure out the listen address for the DNS server
|
||||
// first check the 53 port availability on WG interface or lo, if not success
|
||||
@@ -187,18 +236,28 @@ func (s *serviceViaListener) testFreePort(port int) (netip.Addr, bool) {
|
||||
}
|
||||
|
||||
func (s *serviceViaListener) tryToBind(ip netip.Addr, port int) bool {
|
||||
addrString := net.JoinHostPort(ip.String(), strconv.Itoa(port))
|
||||
udpAddr := net.UDPAddrFromAddrPort(netip.MustParseAddrPort(addrString))
|
||||
probeListener, err := net.ListenUDP("udp", udpAddr)
|
||||
addrPort := netip.AddrPortFrom(ip, uint16(port))
|
||||
|
||||
udpAddr := net.UDPAddrFromAddrPort(addrPort)
|
||||
udpLn, err := net.ListenUDP("udp", udpAddr)
|
||||
if err != nil {
|
||||
log.Warnf("binding dns on %s is not available, error: %s", addrString, err)
|
||||
log.Warnf("binding dns UDP on %s is not available: %s", addrPort, err)
|
||||
return false
|
||||
}
|
||||
|
||||
err = probeListener.Close()
|
||||
if err != nil {
|
||||
log.Errorf("got an error closing the probe listener, error: %s", err)
|
||||
if err := udpLn.Close(); err != nil {
|
||||
log.Debugf("close UDP probe listener: %s", err)
|
||||
}
|
||||
|
||||
tcpAddr := net.TCPAddrFromAddrPort(addrPort)
|
||||
tcpLn, err := net.ListenTCP("tcp", tcpAddr)
|
||||
if err != nil {
|
||||
log.Warnf("binding dns TCP on %s is not available: %s", addrPort, err)
|
||||
return false
|
||||
}
|
||||
if err := tcpLn.Close(); err != nil {
|
||||
log.Debugf("close TCP probe listener: %s", err)
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
|
||||
86
client/internal/dns/service_listener_test.go
Normal file
86
client/internal/dns/service_listener_test.go
Normal file
@@ -0,0 +1,86 @@
|
||||
package dns
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestServiceViaListener_TCPAndUDP(t *testing.T) {
|
||||
handler := dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) {
|
||||
m := new(dns.Msg)
|
||||
m.SetReply(r)
|
||||
m.Answer = append(m.Answer, &dns.A{
|
||||
Hdr: dns.RR_Header{Name: r.Question[0].Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
|
||||
A: net.ParseIP("192.0.2.1"),
|
||||
})
|
||||
if err := w.WriteMsg(m); err != nil {
|
||||
t.Logf("write msg: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
// Create a service using a custom address to avoid needing root
|
||||
svc := newServiceViaListener(nil, nil, nil)
|
||||
svc.dnsMux.Handle(".", handler)
|
||||
|
||||
// Bind both transports up front to avoid TOCTOU races.
|
||||
udpAddr := net.UDPAddrFromAddrPort(netip.AddrPortFrom(customIP, 0))
|
||||
udpConn, err := net.ListenUDP("udp", udpAddr)
|
||||
if err != nil {
|
||||
t.Skip("cannot bind to 127.0.0.153, skipping")
|
||||
}
|
||||
port := uint16(udpConn.LocalAddr().(*net.UDPAddr).Port)
|
||||
|
||||
tcpAddr := net.TCPAddrFromAddrPort(netip.AddrPortFrom(customIP, port))
|
||||
tcpLn, err := net.ListenTCP("tcp", tcpAddr)
|
||||
if err != nil {
|
||||
udpConn.Close()
|
||||
t.Skip("cannot bind TCP on same port, skipping")
|
||||
}
|
||||
|
||||
addr := fmt.Sprintf("%s:%d", customIP, port)
|
||||
svc.server.PacketConn = udpConn
|
||||
svc.tcpServer.Listener = tcpLn
|
||||
svc.listenIP = customIP
|
||||
svc.listenPort = port
|
||||
|
||||
go func() {
|
||||
if err := svc.server.ActivateAndServe(); err != nil {
|
||||
t.Logf("udp server: %v", err)
|
||||
}
|
||||
}()
|
||||
go func() {
|
||||
if err := svc.tcpServer.ActivateAndServe(); err != nil {
|
||||
t.Logf("tcp server: %v", err)
|
||||
}
|
||||
}()
|
||||
svc.listenerIsRunning = true
|
||||
|
||||
defer func() {
|
||||
require.NoError(t, svc.Stop())
|
||||
}()
|
||||
|
||||
q := new(dns.Msg).SetQuestion("example.com.", dns.TypeA)
|
||||
|
||||
// Test UDP query
|
||||
udpClient := &dns.Client{Net: "udp", Timeout: 2 * time.Second}
|
||||
udpResp, _, err := udpClient.Exchange(q, addr)
|
||||
require.NoError(t, err, "UDP query should succeed")
|
||||
require.NotNil(t, udpResp)
|
||||
require.NotEmpty(t, udpResp.Answer)
|
||||
assert.Contains(t, udpResp.Answer[0].String(), "192.0.2.1", "UDP response should contain expected IP")
|
||||
|
||||
// Test TCP query
|
||||
tcpClient := &dns.Client{Net: "tcp", Timeout: 2 * time.Second}
|
||||
tcpResp, _, err := tcpClient.Exchange(q, addr)
|
||||
require.NoError(t, err, "TCP query should succeed")
|
||||
require.NotNil(t, tcpResp)
|
||||
require.NotEmpty(t, tcpResp.Answer)
|
||||
assert.Contains(t, tcpResp.Answer[0].String(), "192.0.2.1", "TCP response should contain expected IP")
|
||||
}
|
||||
@@ -1,6 +1,7 @@
|
||||
package dns
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"sync"
|
||||
@@ -10,6 +11,7 @@ import (
|
||||
"github.com/miekg/dns"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
nbnet "github.com/netbirdio/netbird/client/net"
|
||||
)
|
||||
|
||||
@@ -18,7 +20,8 @@ type ServiceViaMemory struct {
|
||||
dnsMux *dns.ServeMux
|
||||
runtimeIP netip.Addr
|
||||
runtimePort int
|
||||
udpFilterHookID string
|
||||
tcpDNS *tcpDNSServer
|
||||
tcpHookSet bool
|
||||
listenerIsRunning bool
|
||||
listenerFlagLock sync.Mutex
|
||||
}
|
||||
@@ -28,14 +31,13 @@ func NewServiceViaMemory(wgIface WGIface) *ServiceViaMemory {
|
||||
if err != nil {
|
||||
log.Errorf("get last ip from network: %v", err)
|
||||
}
|
||||
s := &ServiceViaMemory{
|
||||
|
||||
return &ServiceViaMemory{
|
||||
wgInterface: wgIface,
|
||||
dnsMux: dns.NewServeMux(),
|
||||
|
||||
runtimeIP: lastIP,
|
||||
runtimePort: DefaultPort,
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *ServiceViaMemory) Listen() error {
|
||||
@@ -46,10 +48,8 @@ func (s *ServiceViaMemory) Listen() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
var err error
|
||||
s.udpFilterHookID, err = s.filterDNSTraffic()
|
||||
if err != nil {
|
||||
return fmt.Errorf("filter dns traffice: %w", err)
|
||||
if err := s.filterDNSTraffic(); err != nil {
|
||||
return fmt.Errorf("filter dns traffic: %w", err)
|
||||
}
|
||||
s.listenerIsRunning = true
|
||||
|
||||
@@ -57,19 +57,29 @@ func (s *ServiceViaMemory) Listen() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *ServiceViaMemory) Stop() {
|
||||
func (s *ServiceViaMemory) Stop() error {
|
||||
s.listenerFlagLock.Lock()
|
||||
defer s.listenerFlagLock.Unlock()
|
||||
|
||||
if !s.listenerIsRunning {
|
||||
return
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := s.wgInterface.GetFilter().RemovePacketHook(s.udpFilterHookID); err != nil {
|
||||
log.Errorf("unable to remove DNS packet hook: %s", err)
|
||||
filter := s.wgInterface.GetFilter()
|
||||
if filter != nil {
|
||||
filter.SetUDPPacketHook(s.runtimeIP, uint16(s.runtimePort), nil)
|
||||
if s.tcpHookSet {
|
||||
filter.SetTCPPacketHook(s.runtimeIP, uint16(s.runtimePort), nil)
|
||||
}
|
||||
}
|
||||
|
||||
if s.tcpDNS != nil {
|
||||
s.tcpDNS.Stop()
|
||||
}
|
||||
|
||||
s.listenerIsRunning = false
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *ServiceViaMemory) RegisterMux(pattern string, handler dns.Handler) {
|
||||
@@ -88,10 +98,18 @@ func (s *ServiceViaMemory) RuntimeIP() netip.Addr {
|
||||
return s.runtimeIP
|
||||
}
|
||||
|
||||
func (s *ServiceViaMemory) filterDNSTraffic() (string, error) {
|
||||
func (s *ServiceViaMemory) filterDNSTraffic() error {
|
||||
filter := s.wgInterface.GetFilter()
|
||||
if filter == nil {
|
||||
return "", fmt.Errorf("can't set DNS filter, filter not initialized")
|
||||
return errors.New("DNS filter not initialized")
|
||||
}
|
||||
|
||||
// Create TCP DNS server lazily here since the device may not exist at construction time.
|
||||
if s.tcpDNS == nil {
|
||||
if dev := s.wgInterface.GetDevice(); dev != nil {
|
||||
// MTU only affects TCP segment sizing; DNS messages are small so this has no practical impact.
|
||||
s.tcpDNS = newTCPDNSServer(s.dnsMux, dev.Device, s.runtimeIP, uint16(s.runtimePort), iface.DefaultMTU)
|
||||
}
|
||||
}
|
||||
|
||||
firstLayerDecoder := layers.LayerTypeIPv4
|
||||
@@ -100,12 +118,16 @@ func (s *ServiceViaMemory) filterDNSTraffic() (string, error) {
|
||||
}
|
||||
|
||||
hook := func(packetData []byte) bool {
|
||||
// Decode the packet
|
||||
packet := gopacket.NewPacket(packetData, firstLayerDecoder, gopacket.Default)
|
||||
|
||||
// Get the UDP layer
|
||||
udpLayer := packet.Layer(layers.LayerTypeUDP)
|
||||
udp := udpLayer.(*layers.UDP)
|
||||
if udpLayer == nil {
|
||||
return true
|
||||
}
|
||||
udp, ok := udpLayer.(*layers.UDP)
|
||||
if !ok {
|
||||
return true
|
||||
}
|
||||
|
||||
msg := new(dns.Msg)
|
||||
if err := msg.Unpack(udp.Payload); err != nil {
|
||||
@@ -113,13 +135,30 @@ func (s *ServiceViaMemory) filterDNSTraffic() (string, error) {
|
||||
return true
|
||||
}
|
||||
|
||||
writer := responseWriter{
|
||||
packet: packet,
|
||||
device: s.wgInterface.GetDevice().Device,
|
||||
dev := s.wgInterface.GetDevice()
|
||||
if dev == nil {
|
||||
return true
|
||||
}
|
||||
go s.dnsMux.ServeDNS(&writer, msg)
|
||||
|
||||
writer := &responseWriter{
|
||||
remote: remoteAddrFromPacket(packet),
|
||||
packet: packet,
|
||||
device: dev.Device,
|
||||
}
|
||||
go s.dnsMux.ServeDNS(writer, msg)
|
||||
return true
|
||||
}
|
||||
|
||||
return filter.AddUDPPacketHook(false, s.runtimeIP, uint16(s.runtimePort), hook), nil
|
||||
filter.SetUDPPacketHook(s.runtimeIP, uint16(s.runtimePort), hook)
|
||||
|
||||
if s.tcpDNS != nil {
|
||||
tcpHook := func(packetData []byte) bool {
|
||||
s.tcpDNS.InjectPacket(packetData)
|
||||
return true
|
||||
}
|
||||
filter.SetTCPPacketHook(s.runtimeIP, uint16(s.runtimePort), tcpHook)
|
||||
s.tcpHookSet = true
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
444
client/internal/dns/tcpstack.go
Normal file
444
client/internal/dns/tcpstack.go
Normal file
@@ -0,0 +1,444 @@
|
||||
package dns
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/netip"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.zx2c4.com/wireguard/tun"
|
||||
"gvisor.dev/gvisor/pkg/buffer"
|
||||
"gvisor.dev/gvisor/pkg/tcpip"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/header"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/stack"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
|
||||
"gvisor.dev/gvisor/pkg/waiter"
|
||||
)
|
||||
|
||||
const (
|
||||
dnsTCPReceiveWindow = 8192
|
||||
dnsTCPMaxInFlight = 16
|
||||
dnsTCPIdleTimeout = 30 * time.Second
|
||||
dnsTCPReadTimeout = 5 * time.Second
|
||||
)
|
||||
|
||||
// tcpDNSServer is an on-demand TCP DNS server backed by a minimal gvisor stack.
|
||||
// It is started lazily when a truncated DNS response is detected and shuts down
|
||||
// after a period of inactivity to conserve resources.
|
||||
type tcpDNSServer struct {
|
||||
mu sync.Mutex
|
||||
s *stack.Stack
|
||||
ep *dnsEndpoint
|
||||
mux *dns.ServeMux
|
||||
tunDev tun.Device
|
||||
ip netip.Addr
|
||||
port uint16
|
||||
mtu uint16
|
||||
|
||||
running bool
|
||||
closed bool
|
||||
timerID uint64
|
||||
timer *time.Timer
|
||||
}
|
||||
|
||||
func newTCPDNSServer(mux *dns.ServeMux, tunDev tun.Device, ip netip.Addr, port uint16, mtu uint16) *tcpDNSServer {
|
||||
return &tcpDNSServer{
|
||||
mux: mux,
|
||||
tunDev: tunDev,
|
||||
ip: ip,
|
||||
port: port,
|
||||
mtu: mtu,
|
||||
}
|
||||
}
|
||||
|
||||
// InjectPacket ensures the stack is running and delivers a raw IP packet into
|
||||
// the gvisor stack for TCP processing. Combining both operations under a single
|
||||
// lock prevents a race where the idle timer could stop the stack between
|
||||
// start and delivery.
|
||||
func (t *tcpDNSServer) InjectPacket(payload []byte) {
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
|
||||
if t.closed {
|
||||
return
|
||||
}
|
||||
|
||||
if !t.running {
|
||||
if err := t.startLocked(); err != nil {
|
||||
log.Errorf("failed to start TCP DNS stack: %v", err)
|
||||
return
|
||||
}
|
||||
t.running = true
|
||||
log.Debugf("TCP DNS stack started on %s:%d (triggered by %s)", t.ip, t.port, srcAddrFromPacket(payload))
|
||||
}
|
||||
t.resetTimerLocked()
|
||||
|
||||
ep := t.ep
|
||||
if ep == nil || ep.dispatcher == nil {
|
||||
return
|
||||
}
|
||||
|
||||
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
|
||||
Payload: buffer.MakeWithData(payload),
|
||||
})
|
||||
// DeliverNetworkPacket takes ownership of the packet buffer; do not DecRef.
|
||||
ep.dispatcher.DeliverNetworkPacket(ipv4.ProtocolNumber, pkt)
|
||||
}
|
||||
|
||||
// Stop tears down the gvisor stack and releases resources permanently.
|
||||
// After Stop, InjectPacket becomes a no-op.
|
||||
func (t *tcpDNSServer) Stop() {
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
|
||||
t.stopLocked()
|
||||
t.closed = true
|
||||
}
|
||||
|
||||
func (t *tcpDNSServer) startLocked() error {
|
||||
// TODO: add ipv6.NewProtocol when IPv6 overlay support lands.
|
||||
s := stack.New(stack.Options{
|
||||
NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
|
||||
TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol},
|
||||
HandleLocal: false,
|
||||
})
|
||||
|
||||
nicID := tcpip.NICID(1)
|
||||
ep := &dnsEndpoint{
|
||||
tunDev: t.tunDev,
|
||||
}
|
||||
ep.mtu.Store(uint32(t.mtu))
|
||||
|
||||
if err := s.CreateNIC(nicID, ep); err != nil {
|
||||
s.Close()
|
||||
s.Wait()
|
||||
return fmt.Errorf("create NIC: %v", err)
|
||||
}
|
||||
|
||||
protoAddr := tcpip.ProtocolAddress{
|
||||
Protocol: ipv4.ProtocolNumber,
|
||||
AddressWithPrefix: tcpip.AddressWithPrefix{
|
||||
Address: tcpip.AddrFromSlice(t.ip.AsSlice()),
|
||||
PrefixLen: 32,
|
||||
},
|
||||
}
|
||||
if err := s.AddProtocolAddress(nicID, protoAddr, stack.AddressProperties{}); err != nil {
|
||||
s.Close()
|
||||
s.Wait()
|
||||
return fmt.Errorf("add protocol address: %s", err)
|
||||
}
|
||||
|
||||
if err := s.SetPromiscuousMode(nicID, true); err != nil {
|
||||
s.Close()
|
||||
s.Wait()
|
||||
return fmt.Errorf("set promiscuous mode: %s", err)
|
||||
}
|
||||
if err := s.SetSpoofing(nicID, true); err != nil {
|
||||
s.Close()
|
||||
s.Wait()
|
||||
return fmt.Errorf("set spoofing: %s", err)
|
||||
}
|
||||
|
||||
defaultSubnet, err := tcpip.NewSubnet(
|
||||
tcpip.AddrFrom4([4]byte{0, 0, 0, 0}),
|
||||
tcpip.MaskFromBytes([]byte{0, 0, 0, 0}),
|
||||
)
|
||||
if err != nil {
|
||||
s.Close()
|
||||
s.Wait()
|
||||
return fmt.Errorf("create default subnet: %w", err)
|
||||
}
|
||||
|
||||
s.SetRouteTable([]tcpip.Route{
|
||||
{Destination: defaultSubnet, NIC: nicID},
|
||||
})
|
||||
|
||||
tcpFwd := tcp.NewForwarder(s, dnsTCPReceiveWindow, dnsTCPMaxInFlight, func(r *tcp.ForwarderRequest) {
|
||||
t.handleTCPDNS(r)
|
||||
})
|
||||
s.SetTransportProtocolHandler(tcp.ProtocolNumber, tcpFwd.HandlePacket)
|
||||
|
||||
t.s = s
|
||||
t.ep = ep
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *tcpDNSServer) stopLocked() {
|
||||
if !t.running {
|
||||
return
|
||||
}
|
||||
|
||||
if t.timer != nil {
|
||||
t.timer.Stop()
|
||||
t.timer = nil
|
||||
}
|
||||
|
||||
if t.s != nil {
|
||||
t.s.Close()
|
||||
t.s.Wait()
|
||||
t.s = nil
|
||||
}
|
||||
t.ep = nil
|
||||
t.running = false
|
||||
|
||||
log.Debugf("TCP DNS stack stopped")
|
||||
}
|
||||
|
||||
func (t *tcpDNSServer) resetTimerLocked() {
|
||||
if t.timer != nil {
|
||||
t.timer.Stop()
|
||||
}
|
||||
t.timerID++
|
||||
id := t.timerID
|
||||
t.timer = time.AfterFunc(dnsTCPIdleTimeout, func() {
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
|
||||
// Only stop if this timer is still the active one.
|
||||
// A racing InjectPacket may have replaced it.
|
||||
if t.timerID != id {
|
||||
return
|
||||
}
|
||||
t.stopLocked()
|
||||
})
|
||||
}
|
||||
|
||||
func (t *tcpDNSServer) handleTCPDNS(r *tcp.ForwarderRequest) {
|
||||
id := r.ID()
|
||||
|
||||
wq := waiter.Queue{}
|
||||
ep, epErr := r.CreateEndpoint(&wq)
|
||||
if epErr != nil {
|
||||
log.Debugf("TCP DNS: failed to create endpoint: %v", epErr)
|
||||
r.Complete(true)
|
||||
return
|
||||
}
|
||||
r.Complete(false)
|
||||
|
||||
conn := gonet.NewTCPConn(&wq, ep)
|
||||
defer func() {
|
||||
if err := conn.Close(); err != nil {
|
||||
log.Tracef("TCP DNS: close conn: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// Reset idle timer on activity
|
||||
t.mu.Lock()
|
||||
t.resetTimerLocked()
|
||||
t.mu.Unlock()
|
||||
|
||||
localAddr := &net.TCPAddr{
|
||||
IP: id.LocalAddress.AsSlice(),
|
||||
Port: int(id.LocalPort),
|
||||
}
|
||||
remoteAddr := &net.TCPAddr{
|
||||
IP: id.RemoteAddress.AsSlice(),
|
||||
Port: int(id.RemotePort),
|
||||
}
|
||||
|
||||
for {
|
||||
if err := conn.SetReadDeadline(time.Now().Add(dnsTCPReadTimeout)); err != nil {
|
||||
log.Debugf("TCP DNS: set deadline for %s: %v", remoteAddr, err)
|
||||
break
|
||||
}
|
||||
|
||||
msg, err := readTCPDNSMessage(conn)
|
||||
if err != nil {
|
||||
if !errors.Is(err, io.EOF) && !errors.Is(err, io.ErrUnexpectedEOF) {
|
||||
log.Debugf("TCP DNS: read from %s: %v", remoteAddr, err)
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
writer := &tcpResponseWriter{
|
||||
conn: conn,
|
||||
localAddr: localAddr,
|
||||
remoteAddr: remoteAddr,
|
||||
}
|
||||
t.mux.ServeDNS(writer, msg)
|
||||
}
|
||||
}
|
||||
|
||||
// dnsEndpoint implements stack.LinkEndpoint for writing packets back via the tun device.
|
||||
type dnsEndpoint struct {
|
||||
dispatcher stack.NetworkDispatcher
|
||||
tunDev tun.Device
|
||||
mtu atomic.Uint32
|
||||
}
|
||||
|
||||
func (e *dnsEndpoint) Attach(dispatcher stack.NetworkDispatcher) { e.dispatcher = dispatcher }
|
||||
func (e *dnsEndpoint) IsAttached() bool { return e.dispatcher != nil }
|
||||
func (e *dnsEndpoint) MTU() uint32 { return e.mtu.Load() }
|
||||
func (e *dnsEndpoint) Capabilities() stack.LinkEndpointCapabilities { return stack.CapabilityNone }
|
||||
func (e *dnsEndpoint) MaxHeaderLength() uint16 { return 0 }
|
||||
func (e *dnsEndpoint) LinkAddress() tcpip.LinkAddress { return "" }
|
||||
func (e *dnsEndpoint) Wait() { /* no async work */ }
|
||||
func (e *dnsEndpoint) ARPHardwareType() header.ARPHardwareType { return header.ARPHardwareNone }
|
||||
func (e *dnsEndpoint) AddHeader(*stack.PacketBuffer) { /* IP-level endpoint, no link header */ }
|
||||
func (e *dnsEndpoint) ParseHeader(*stack.PacketBuffer) bool { return true }
|
||||
func (e *dnsEndpoint) Close() { /* lifecycle managed by tcpDNSServer */ }
|
||||
func (e *dnsEndpoint) SetLinkAddress(tcpip.LinkAddress) { /* no link address for tun */ }
|
||||
func (e *dnsEndpoint) SetMTU(mtu uint32) { e.mtu.Store(mtu) }
|
||||
func (e *dnsEndpoint) SetOnCloseAction(func()) { /* not needed */ }
|
||||
|
||||
const tunPacketOffset = 40
|
||||
|
||||
func (e *dnsEndpoint) WritePackets(pkts stack.PacketBufferList) (int, tcpip.Error) {
|
||||
var written int
|
||||
for _, pkt := range pkts.AsSlice() {
|
||||
data := stack.PayloadSince(pkt.NetworkHeader())
|
||||
if data == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
raw := data.AsSlice()
|
||||
buf := make([]byte, tunPacketOffset, tunPacketOffset+len(raw))
|
||||
buf = append(buf, raw...)
|
||||
data.Release()
|
||||
|
||||
if _, err := e.tunDev.Write([][]byte{buf}, tunPacketOffset); err != nil {
|
||||
log.Tracef("TCP DNS endpoint: failed to write packet: %v", err)
|
||||
continue
|
||||
}
|
||||
written++
|
||||
}
|
||||
return written, nil
|
||||
}
|
||||
|
||||
// tcpResponseWriter implements dns.ResponseWriter for TCP DNS connections.
|
||||
type tcpResponseWriter struct {
|
||||
conn *gonet.TCPConn
|
||||
localAddr net.Addr
|
||||
remoteAddr net.Addr
|
||||
}
|
||||
|
||||
func (w *tcpResponseWriter) LocalAddr() net.Addr {
|
||||
return w.localAddr
|
||||
}
|
||||
|
||||
func (w *tcpResponseWriter) RemoteAddr() net.Addr {
|
||||
return w.remoteAddr
|
||||
}
|
||||
|
||||
func (w *tcpResponseWriter) WriteMsg(msg *dns.Msg) error {
|
||||
data, err := msg.Pack()
|
||||
if err != nil {
|
||||
return fmt.Errorf("pack: %w", err)
|
||||
}
|
||||
|
||||
// DNS TCP: 2-byte length prefix + message
|
||||
buf := make([]byte, 2+len(data))
|
||||
buf[0] = byte(len(data) >> 8)
|
||||
buf[1] = byte(len(data))
|
||||
copy(buf[2:], data)
|
||||
|
||||
if _, err = w.conn.Write(buf); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *tcpResponseWriter) Write(data []byte) (int, error) {
|
||||
buf := make([]byte, 2+len(data))
|
||||
buf[0] = byte(len(data) >> 8)
|
||||
buf[1] = byte(len(data))
|
||||
copy(buf[2:], data)
|
||||
if _, err := w.conn.Write(buf); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return len(data), nil
|
||||
}
|
||||
|
||||
func (w *tcpResponseWriter) Close() error {
|
||||
return w.conn.Close()
|
||||
}
|
||||
|
||||
func (w *tcpResponseWriter) TsigStatus() error { return nil }
|
||||
func (w *tcpResponseWriter) TsigTimersOnly(bool) { /* TSIG not supported */ }
|
||||
func (w *tcpResponseWriter) Hijack() { /* not supported */ }
|
||||
|
||||
// readTCPDNSMessage reads a single DNS message from a TCP connection (length-prefixed).
|
||||
func readTCPDNSMessage(conn *gonet.TCPConn) (*dns.Msg, error) {
|
||||
// DNS over TCP uses a 2-byte length prefix
|
||||
lenBuf := make([]byte, 2)
|
||||
if _, err := io.ReadFull(conn, lenBuf); err != nil {
|
||||
return nil, fmt.Errorf("read length: %w", err)
|
||||
}
|
||||
|
||||
msgLen := int(lenBuf[0])<<8 | int(lenBuf[1])
|
||||
if msgLen == 0 || msgLen > 65535 {
|
||||
return nil, fmt.Errorf("invalid message length: %d", msgLen)
|
||||
}
|
||||
|
||||
msgBuf := make([]byte, msgLen)
|
||||
if _, err := io.ReadFull(conn, msgBuf); err != nil {
|
||||
return nil, fmt.Errorf("read message: %w", err)
|
||||
}
|
||||
|
||||
msg := new(dns.Msg)
|
||||
if err := msg.Unpack(msgBuf); err != nil {
|
||||
return nil, fmt.Errorf("unpack: %w", err)
|
||||
}
|
||||
return msg, nil
|
||||
}
|
||||
|
||||
// srcAddrFromPacket extracts the source IP:port from a raw IP+TCP packet for logging.
|
||||
// Supports both IPv4 and IPv6.
|
||||
func srcAddrFromPacket(pkt []byte) netip.AddrPort {
|
||||
if len(pkt) == 0 {
|
||||
return netip.AddrPort{}
|
||||
}
|
||||
|
||||
srcIP, transportOffset := srcIPFromPacket(pkt)
|
||||
if !srcIP.IsValid() || len(pkt) < transportOffset+2 {
|
||||
return netip.AddrPort{}
|
||||
}
|
||||
|
||||
srcPort := uint16(pkt[transportOffset])<<8 | uint16(pkt[transportOffset+1])
|
||||
return netip.AddrPortFrom(srcIP.Unmap(), srcPort)
|
||||
}
|
||||
|
||||
func srcIPFromPacket(pkt []byte) (netip.Addr, int) {
|
||||
switch header.IPVersion(pkt) {
|
||||
case 4:
|
||||
return srcIPv4(pkt)
|
||||
case 6:
|
||||
return srcIPv6(pkt)
|
||||
default:
|
||||
return netip.Addr{}, 0
|
||||
}
|
||||
}
|
||||
|
||||
func srcIPv4(pkt []byte) (netip.Addr, int) {
|
||||
if len(pkt) < header.IPv4MinimumSize {
|
||||
return netip.Addr{}, 0
|
||||
}
|
||||
hdr := header.IPv4(pkt)
|
||||
src := hdr.SourceAddress()
|
||||
ip, ok := netip.AddrFromSlice(src.AsSlice())
|
||||
if !ok {
|
||||
return netip.Addr{}, 0
|
||||
}
|
||||
return ip, int(hdr.HeaderLength())
|
||||
}
|
||||
|
||||
func srcIPv6(pkt []byte) (netip.Addr, int) {
|
||||
if len(pkt) < header.IPv6MinimumSize {
|
||||
return netip.Addr{}, 0
|
||||
}
|
||||
hdr := header.IPv6(pkt)
|
||||
src := hdr.SourceAddress()
|
||||
ip, ok := netip.AddrFromSlice(src.AsSlice())
|
||||
if !ok {
|
||||
return netip.Addr{}, 0
|
||||
}
|
||||
return ip, header.IPv6MinimumSize
|
||||
}
|
||||
@@ -41,10 +41,61 @@ const (
|
||||
|
||||
reactivatePeriod = 30 * time.Second
|
||||
probeTimeout = 2 * time.Second
|
||||
|
||||
// ipv6HeaderSize + udpHeaderSize, used to derive the maximum DNS UDP
|
||||
// payload from the tunnel MTU.
|
||||
ipUDPHeaderSize = 60 + 8
|
||||
)
|
||||
|
||||
const testRecord = "com."
|
||||
|
||||
const (
|
||||
protoUDP = "udp"
|
||||
protoTCP = "tcp"
|
||||
)
|
||||
|
||||
type dnsProtocolKey struct{}
|
||||
|
||||
// contextWithDNSProtocol stores the inbound DNS protocol ("udp" or "tcp") in context.
|
||||
func contextWithDNSProtocol(ctx context.Context, network string) context.Context {
|
||||
return context.WithValue(ctx, dnsProtocolKey{}, network)
|
||||
}
|
||||
|
||||
// dnsProtocolFromContext retrieves the inbound DNS protocol from context.
|
||||
func dnsProtocolFromContext(ctx context.Context) string {
|
||||
if ctx == nil {
|
||||
return ""
|
||||
}
|
||||
if v, ok := ctx.Value(dnsProtocolKey{}).(string); ok {
|
||||
return v
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
type upstreamProtocolKey struct{}
|
||||
|
||||
// upstreamProtocolResult holds the protocol used for the upstream exchange.
|
||||
// Stored as a pointer in context so the exchange function can set it.
|
||||
type upstreamProtocolResult struct {
|
||||
protocol string
|
||||
}
|
||||
|
||||
// contextWithupstreamProtocolResult stores a mutable result holder in the context.
|
||||
func contextWithupstreamProtocolResult(ctx context.Context) (context.Context, *upstreamProtocolResult) {
|
||||
r := &upstreamProtocolResult{}
|
||||
return context.WithValue(ctx, upstreamProtocolKey{}, r), r
|
||||
}
|
||||
|
||||
// setUpstreamProtocol sets the upstream protocol on the result holder in context, if present.
|
||||
func setUpstreamProtocol(ctx context.Context, protocol string) {
|
||||
if ctx == nil {
|
||||
return
|
||||
}
|
||||
if r, ok := ctx.Value(upstreamProtocolKey{}).(*upstreamProtocolResult); ok && r != nil {
|
||||
r.protocol = protocol
|
||||
}
|
||||
}
|
||||
|
||||
type upstreamClient interface {
|
||||
exchange(ctx context.Context, upstream string, r *dns.Msg) (*dns.Msg, time.Duration, error)
|
||||
}
|
||||
@@ -138,7 +189,16 @@ func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
return
|
||||
}
|
||||
|
||||
ok, failures := u.tryUpstreamServers(w, r, logger)
|
||||
// Propagate inbound protocol so upstream exchange can use TCP directly
|
||||
// when the request came in over TCP.
|
||||
ctx := u.ctx
|
||||
if addr := w.RemoteAddr(); addr != nil {
|
||||
network := addr.Network()
|
||||
ctx = contextWithDNSProtocol(ctx, network)
|
||||
resutil.SetMeta(w, "protocol", network)
|
||||
}
|
||||
|
||||
ok, failures := u.tryUpstreamServers(ctx, w, r, logger)
|
||||
if len(failures) > 0 {
|
||||
u.logUpstreamFailures(r.Question[0].Name, failures, ok, logger)
|
||||
}
|
||||
@@ -153,7 +213,7 @@ func (u *upstreamResolverBase) prepareRequest(r *dns.Msg) {
|
||||
}
|
||||
}
|
||||
|
||||
func (u *upstreamResolverBase) tryUpstreamServers(w dns.ResponseWriter, r *dns.Msg, logger *log.Entry) (bool, []upstreamFailure) {
|
||||
func (u *upstreamResolverBase) tryUpstreamServers(ctx context.Context, w dns.ResponseWriter, r *dns.Msg, logger *log.Entry) (bool, []upstreamFailure) {
|
||||
timeout := u.upstreamTimeout
|
||||
if len(u.upstreamServers) > 1 {
|
||||
maxTotal := 5 * time.Second
|
||||
@@ -168,7 +228,7 @@ func (u *upstreamResolverBase) tryUpstreamServers(w dns.ResponseWriter, r *dns.M
|
||||
|
||||
var failures []upstreamFailure
|
||||
for _, upstream := range u.upstreamServers {
|
||||
if failure := u.queryUpstream(w, r, upstream, timeout, logger); failure != nil {
|
||||
if failure := u.queryUpstream(ctx, w, r, upstream, timeout, logger); failure != nil {
|
||||
failures = append(failures, *failure)
|
||||
} else {
|
||||
return true, failures
|
||||
@@ -178,15 +238,17 @@ func (u *upstreamResolverBase) tryUpstreamServers(w dns.ResponseWriter, r *dns.M
|
||||
}
|
||||
|
||||
// queryUpstream queries a single upstream server. Returns nil on success, or failure info to try next upstream.
|
||||
func (u *upstreamResolverBase) queryUpstream(w dns.ResponseWriter, r *dns.Msg, upstream netip.AddrPort, timeout time.Duration, logger *log.Entry) *upstreamFailure {
|
||||
func (u *upstreamResolverBase) queryUpstream(parentCtx context.Context, w dns.ResponseWriter, r *dns.Msg, upstream netip.AddrPort, timeout time.Duration, logger *log.Entry) *upstreamFailure {
|
||||
var rm *dns.Msg
|
||||
var t time.Duration
|
||||
var err error
|
||||
|
||||
var startTime time.Time
|
||||
var upstreamProto *upstreamProtocolResult
|
||||
func() {
|
||||
ctx, cancel := context.WithTimeout(u.ctx, timeout)
|
||||
ctx, cancel := context.WithTimeout(parentCtx, timeout)
|
||||
defer cancel()
|
||||
ctx, upstreamProto = contextWithupstreamProtocolResult(ctx)
|
||||
startTime = time.Now()
|
||||
rm, t, err = u.upstreamClient.exchange(ctx, upstream.String(), r)
|
||||
}()
|
||||
@@ -203,7 +265,7 @@ func (u *upstreamResolverBase) queryUpstream(w dns.ResponseWriter, r *dns.Msg, u
|
||||
return &upstreamFailure{upstream: upstream, reason: dns.RcodeToString[rm.Rcode]}
|
||||
}
|
||||
|
||||
u.writeSuccessResponse(w, rm, upstream, r.Question[0].Name, t, logger)
|
||||
u.writeSuccessResponse(w, rm, upstream, r.Question[0].Name, t, upstreamProto, logger)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -220,10 +282,13 @@ func (u *upstreamResolverBase) handleUpstreamError(err error, upstream netip.Add
|
||||
return &upstreamFailure{upstream: upstream, reason: reason}
|
||||
}
|
||||
|
||||
func (u *upstreamResolverBase) writeSuccessResponse(w dns.ResponseWriter, rm *dns.Msg, upstream netip.AddrPort, domain string, t time.Duration, logger *log.Entry) bool {
|
||||
func (u *upstreamResolverBase) writeSuccessResponse(w dns.ResponseWriter, rm *dns.Msg, upstream netip.AddrPort, domain string, t time.Duration, upstreamProto *upstreamProtocolResult, logger *log.Entry) bool {
|
||||
u.successCount.Add(1)
|
||||
|
||||
resutil.SetMeta(w, "upstream", upstream.String())
|
||||
if upstreamProto != nil && upstreamProto.protocol != "" {
|
||||
resutil.SetMeta(w, "upstream_protocol", upstreamProto.protocol)
|
||||
}
|
||||
|
||||
// Clear Zero bit from external responses to prevent upstream servers from
|
||||
// manipulating our internal fallthrough signaling mechanism
|
||||
@@ -428,13 +493,42 @@ func (u *upstreamResolverBase) testNameserver(baseCtx context.Context, externalC
|
||||
return err
|
||||
}
|
||||
|
||||
// clientUDPMaxSize returns the maximum UDP response size the client accepts.
|
||||
func clientUDPMaxSize(r *dns.Msg) int {
|
||||
if opt := r.IsEdns0(); opt != nil {
|
||||
return int(opt.UDPSize())
|
||||
}
|
||||
return dns.MinMsgSize
|
||||
}
|
||||
|
||||
// ExchangeWithFallback exchanges a DNS message with the upstream server.
|
||||
// It first tries to use UDP, and if it is truncated, it falls back to TCP.
|
||||
// If the inbound request came over TCP (via context), it skips the UDP attempt.
|
||||
// If the passed context is nil, this will use Exchange instead of ExchangeContext.
|
||||
func ExchangeWithFallback(ctx context.Context, client *dns.Client, r *dns.Msg, upstream string) (*dns.Msg, time.Duration, error) {
|
||||
// MTU - ip + udp headers
|
||||
// Note: this could be sent out on an interface that is not ours, but higher MTU settings could break truncation handling.
|
||||
client.UDPSize = uint16(currentMTU - (60 + 8))
|
||||
// If the request came in over TCP, go straight to TCP upstream.
|
||||
if dnsProtocolFromContext(ctx) == protoTCP {
|
||||
tcpClient := *client
|
||||
tcpClient.Net = protoTCP
|
||||
rm, t, err := tcpClient.ExchangeContext(ctx, r, upstream)
|
||||
if err != nil {
|
||||
return nil, t, fmt.Errorf("with tcp: %w", err)
|
||||
}
|
||||
setUpstreamProtocol(ctx, protoTCP)
|
||||
return rm, t, nil
|
||||
}
|
||||
|
||||
clientMaxSize := clientUDPMaxSize(r)
|
||||
|
||||
// Cap EDNS0 to our tunnel MTU so the upstream doesn't send a
|
||||
// response larger than our read buffer.
|
||||
// Note: the query could be sent out on an interface that is not ours,
|
||||
// but higher MTU settings could break truncation handling.
|
||||
maxUDPPayload := uint16(currentMTU - ipUDPHeaderSize)
|
||||
client.UDPSize = maxUDPPayload
|
||||
if opt := r.IsEdns0(); opt != nil && opt.UDPSize() > maxUDPPayload {
|
||||
opt.SetUDPSize(maxUDPPayload)
|
||||
}
|
||||
|
||||
var (
|
||||
rm *dns.Msg
|
||||
@@ -453,25 +547,32 @@ func ExchangeWithFallback(ctx context.Context, client *dns.Client, r *dns.Msg, u
|
||||
}
|
||||
|
||||
if rm == nil || !rm.MsgHdr.Truncated {
|
||||
setUpstreamProtocol(ctx, protoUDP)
|
||||
return rm, t, nil
|
||||
}
|
||||
|
||||
log.Tracef("udp response for domain=%s type=%v class=%v is truncated, trying TCP.",
|
||||
r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass)
|
||||
// TODO: if the upstream's truncated UDP response already contains more
|
||||
// data than the client's buffer, we could truncate locally and skip
|
||||
// the TCP retry.
|
||||
|
||||
client.Net = "tcp"
|
||||
tcpClient := *client
|
||||
tcpClient.Net = protoTCP
|
||||
|
||||
if ctx == nil {
|
||||
rm, t, err = client.Exchange(r, upstream)
|
||||
rm, t, err = tcpClient.Exchange(r, upstream)
|
||||
} else {
|
||||
rm, t, err = client.ExchangeContext(ctx, r, upstream)
|
||||
rm, t, err = tcpClient.ExchangeContext(ctx, r, upstream)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, t, fmt.Errorf("with tcp: %w", err)
|
||||
}
|
||||
|
||||
// TODO: once TCP is implemented, rm.Truncate() if the request came in over UDP
|
||||
setUpstreamProtocol(ctx, protoTCP)
|
||||
|
||||
if rm.Len() > clientMaxSize {
|
||||
rm.Truncate(clientMaxSize)
|
||||
}
|
||||
|
||||
return rm, t, nil
|
||||
}
|
||||
@@ -479,18 +580,46 @@ func ExchangeWithFallback(ctx context.Context, client *dns.Client, r *dns.Msg, u
|
||||
// ExchangeWithNetstack performs a DNS exchange using netstack for dialing.
|
||||
// This is needed when netstack is enabled to reach peer IPs through the tunnel.
|
||||
func ExchangeWithNetstack(ctx context.Context, nsNet *netstack.Net, r *dns.Msg, upstream string) (*dns.Msg, error) {
|
||||
reply, err := netstackExchange(ctx, nsNet, r, upstream, "udp")
|
||||
// If request came in over TCP, go straight to TCP upstream
|
||||
if dnsProtocolFromContext(ctx) == protoTCP {
|
||||
rm, err := netstackExchange(ctx, nsNet, r, upstream, protoTCP)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
setUpstreamProtocol(ctx, protoTCP)
|
||||
return rm, nil
|
||||
}
|
||||
|
||||
clientMaxSize := clientUDPMaxSize(r)
|
||||
|
||||
// Cap EDNS0 to our tunnel MTU so the upstream doesn't send a
|
||||
// response larger than what we can read over UDP.
|
||||
maxUDPPayload := uint16(currentMTU - ipUDPHeaderSize)
|
||||
if opt := r.IsEdns0(); opt != nil && opt.UDPSize() > maxUDPPayload {
|
||||
opt.SetUDPSize(maxUDPPayload)
|
||||
}
|
||||
|
||||
reply, err := netstackExchange(ctx, nsNet, r, upstream, protoUDP)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// If response is truncated, retry with TCP
|
||||
if reply != nil && reply.MsgHdr.Truncated {
|
||||
log.Tracef("udp response for domain=%s type=%v class=%v is truncated, trying TCP",
|
||||
r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass)
|
||||
return netstackExchange(ctx, nsNet, r, upstream, "tcp")
|
||||
rm, err := netstackExchange(ctx, nsNet, r, upstream, protoTCP)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
setUpstreamProtocol(ctx, protoTCP)
|
||||
if rm.Len() > clientMaxSize {
|
||||
rm.Truncate(clientMaxSize)
|
||||
}
|
||||
|
||||
return rm, nil
|
||||
}
|
||||
|
||||
setUpstreamProtocol(ctx, protoUDP)
|
||||
|
||||
return reply, nil
|
||||
}
|
||||
|
||||
@@ -511,7 +640,7 @@ func netstackExchange(ctx context.Context, nsNet *netstack.Net, r *dns.Msg, upst
|
||||
}
|
||||
}
|
||||
|
||||
dnsConn := &dns.Conn{Conn: conn}
|
||||
dnsConn := &dns.Conn{Conn: conn, UDPSize: uint16(currentMTU - ipUDPHeaderSize)}
|
||||
|
||||
if err := dnsConn.WriteMsg(r); err != nil {
|
||||
return nil, fmt.Errorf("write %s message: %w", network, err)
|
||||
|
||||
@@ -51,7 +51,7 @@ func (u *upstreamResolver) exchangeWithinVPN(ctx context.Context, upstream strin
|
||||
upstreamExchangeClient := &dns.Client{
|
||||
Timeout: ClientTimeout,
|
||||
}
|
||||
return upstreamExchangeClient.ExchangeContext(ctx, r, upstream)
|
||||
return ExchangeWithFallback(ctx, upstreamExchangeClient, r, upstream)
|
||||
}
|
||||
|
||||
// exchangeWithoutVPN protect the UDP socket by Android SDK to avoid to goes through the VPN
|
||||
@@ -76,7 +76,7 @@ func (u *upstreamResolver) exchangeWithoutVPN(ctx context.Context, upstream stri
|
||||
Timeout: timeout,
|
||||
}
|
||||
|
||||
return upstreamExchangeClient.ExchangeContext(ctx, r, upstream)
|
||||
return ExchangeWithFallback(ctx, upstreamExchangeClient, r, upstream)
|
||||
}
|
||||
|
||||
func (u *upstreamResolver) isLocalResolver(upstream string) bool {
|
||||
|
||||
@@ -475,3 +475,298 @@ func TestFormatFailures(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNSProtocolContext(t *testing.T) {
|
||||
t.Run("roundtrip udp", func(t *testing.T) {
|
||||
ctx := contextWithDNSProtocol(context.Background(), protoUDP)
|
||||
assert.Equal(t, protoUDP, dnsProtocolFromContext(ctx))
|
||||
})
|
||||
|
||||
t.Run("roundtrip tcp", func(t *testing.T) {
|
||||
ctx := contextWithDNSProtocol(context.Background(), protoTCP)
|
||||
assert.Equal(t, protoTCP, dnsProtocolFromContext(ctx))
|
||||
})
|
||||
|
||||
t.Run("missing returns empty", func(t *testing.T) {
|
||||
assert.Equal(t, "", dnsProtocolFromContext(context.Background()))
|
||||
})
|
||||
}
|
||||
|
||||
func TestExchangeWithFallback_TCPContext(t *testing.T) {
|
||||
// Start a local DNS server that responds on TCP only
|
||||
tcpHandler := dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) {
|
||||
m := new(dns.Msg)
|
||||
m.SetReply(r)
|
||||
m.Answer = append(m.Answer, &dns.A{
|
||||
Hdr: dns.RR_Header{Name: r.Question[0].Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
|
||||
A: net.ParseIP("10.0.0.1"),
|
||||
})
|
||||
if err := w.WriteMsg(m); err != nil {
|
||||
t.Logf("write msg: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
tcpServer := &dns.Server{
|
||||
Addr: "127.0.0.1:0",
|
||||
Net: "tcp",
|
||||
Handler: tcpHandler,
|
||||
}
|
||||
|
||||
tcpLn, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
tcpServer.Listener = tcpLn
|
||||
|
||||
go func() {
|
||||
if err := tcpServer.ActivateAndServe(); err != nil {
|
||||
t.Logf("tcp server: %v", err)
|
||||
}
|
||||
}()
|
||||
defer func() {
|
||||
_ = tcpServer.Shutdown()
|
||||
}()
|
||||
|
||||
upstream := tcpLn.Addr().String()
|
||||
|
||||
// With TCP context, should connect directly via TCP without trying UDP
|
||||
ctx := contextWithDNSProtocol(context.Background(), protoTCP)
|
||||
client := &dns.Client{Timeout: 2 * time.Second}
|
||||
r := new(dns.Msg).SetQuestion("example.com.", dns.TypeA)
|
||||
|
||||
rm, _, err := ExchangeWithFallback(ctx, client, r, upstream)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, rm)
|
||||
require.NotEmpty(t, rm.Answer)
|
||||
assert.Contains(t, rm.Answer[0].String(), "10.0.0.1")
|
||||
}
|
||||
|
||||
func TestExchangeWithFallback_UDPFallbackToTCP(t *testing.T) {
|
||||
// UDP handler returns a truncated response to trigger TCP retry.
|
||||
udpHandler := dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) {
|
||||
m := new(dns.Msg)
|
||||
m.SetReply(r)
|
||||
m.Truncated = true
|
||||
if err := w.WriteMsg(m); err != nil {
|
||||
t.Logf("write msg: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
// TCP handler returns the full answer.
|
||||
tcpHandler := dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) {
|
||||
m := new(dns.Msg)
|
||||
m.SetReply(r)
|
||||
m.Answer = append(m.Answer, &dns.A{
|
||||
Hdr: dns.RR_Header{Name: r.Question[0].Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
|
||||
A: net.ParseIP("10.0.0.3"),
|
||||
})
|
||||
if err := w.WriteMsg(m); err != nil {
|
||||
t.Logf("write msg: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
udpPC, err := net.ListenPacket("udp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
addr := udpPC.LocalAddr().String()
|
||||
|
||||
udpServer := &dns.Server{
|
||||
PacketConn: udpPC,
|
||||
Net: "udp",
|
||||
Handler: udpHandler,
|
||||
}
|
||||
|
||||
tcpLn, err := net.Listen("tcp", addr)
|
||||
require.NoError(t, err)
|
||||
|
||||
tcpServer := &dns.Server{
|
||||
Listener: tcpLn,
|
||||
Net: "tcp",
|
||||
Handler: tcpHandler,
|
||||
}
|
||||
|
||||
go func() {
|
||||
if err := udpServer.ActivateAndServe(); err != nil {
|
||||
t.Logf("udp server: %v", err)
|
||||
}
|
||||
}()
|
||||
go func() {
|
||||
if err := tcpServer.ActivateAndServe(); err != nil {
|
||||
t.Logf("tcp server: %v", err)
|
||||
}
|
||||
}()
|
||||
defer func() {
|
||||
_ = udpServer.Shutdown()
|
||||
_ = tcpServer.Shutdown()
|
||||
}()
|
||||
|
||||
ctx := context.Background()
|
||||
client := &dns.Client{Timeout: 2 * time.Second}
|
||||
r := new(dns.Msg).SetQuestion("example.com.", dns.TypeA)
|
||||
|
||||
rm, _, err := ExchangeWithFallback(ctx, client, r, addr)
|
||||
require.NoError(t, err, "should fall back to TCP after truncated UDP response")
|
||||
require.NotNil(t, rm)
|
||||
require.NotEmpty(t, rm.Answer, "TCP response should contain the full answer")
|
||||
assert.Contains(t, rm.Answer[0].String(), "10.0.0.3")
|
||||
assert.False(t, rm.Truncated, "TCP response should not be truncated")
|
||||
}
|
||||
|
||||
func TestExchangeWithFallback_TCPContextSkipsUDP(t *testing.T) {
|
||||
// Start only a TCP server (no UDP). With TCP context it should succeed.
|
||||
tcpHandler := dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) {
|
||||
m := new(dns.Msg)
|
||||
m.SetReply(r)
|
||||
m.Answer = append(m.Answer, &dns.A{
|
||||
Hdr: dns.RR_Header{Name: r.Question[0].Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
|
||||
A: net.ParseIP("10.0.0.2"),
|
||||
})
|
||||
if err := w.WriteMsg(m); err != nil {
|
||||
t.Logf("write msg: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
tcpLn, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
|
||||
tcpServer := &dns.Server{
|
||||
Listener: tcpLn,
|
||||
Net: "tcp",
|
||||
Handler: tcpHandler,
|
||||
}
|
||||
|
||||
go func() {
|
||||
if err := tcpServer.ActivateAndServe(); err != nil {
|
||||
t.Logf("tcp server: %v", err)
|
||||
}
|
||||
}()
|
||||
defer func() {
|
||||
_ = tcpServer.Shutdown()
|
||||
}()
|
||||
|
||||
upstream := tcpLn.Addr().String()
|
||||
|
||||
// TCP context: should skip UDP entirely and go directly to TCP
|
||||
ctx := contextWithDNSProtocol(context.Background(), protoTCP)
|
||||
client := &dns.Client{Timeout: 2 * time.Second}
|
||||
r := new(dns.Msg).SetQuestion("example.com.", dns.TypeA)
|
||||
|
||||
rm, _, err := ExchangeWithFallback(ctx, client, r, upstream)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, rm)
|
||||
require.NotEmpty(t, rm.Answer)
|
||||
assert.Contains(t, rm.Answer[0].String(), "10.0.0.2")
|
||||
|
||||
// Without TCP context, trying to reach a TCP-only server via UDP should fail
|
||||
ctx2 := context.Background()
|
||||
client2 := &dns.Client{Timeout: 500 * time.Millisecond}
|
||||
_, _, err = ExchangeWithFallback(ctx2, client2, r, upstream)
|
||||
assert.Error(t, err, "should fail when no UDP server and no TCP context")
|
||||
}
|
||||
|
||||
func TestExchangeWithFallback_EDNS0Capped(t *testing.T) {
|
||||
// Verify that a client EDNS0 larger than our MTU-derived limit gets
|
||||
// capped in the outgoing request so the upstream doesn't send a
|
||||
// response larger than our read buffer.
|
||||
var receivedUDPSize uint16
|
||||
udpHandler := dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) {
|
||||
if opt := r.IsEdns0(); opt != nil {
|
||||
receivedUDPSize = opt.UDPSize()
|
||||
}
|
||||
m := new(dns.Msg)
|
||||
m.SetReply(r)
|
||||
m.Answer = append(m.Answer, &dns.A{
|
||||
Hdr: dns.RR_Header{Name: r.Question[0].Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
|
||||
A: net.ParseIP("10.0.0.1"),
|
||||
})
|
||||
if err := w.WriteMsg(m); err != nil {
|
||||
t.Logf("write msg: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
udpPC, err := net.ListenPacket("udp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
addr := udpPC.LocalAddr().String()
|
||||
|
||||
udpServer := &dns.Server{PacketConn: udpPC, Net: "udp", Handler: udpHandler}
|
||||
go func() { _ = udpServer.ActivateAndServe() }()
|
||||
t.Cleanup(func() { _ = udpServer.Shutdown() })
|
||||
|
||||
ctx := context.Background()
|
||||
client := &dns.Client{Timeout: 2 * time.Second}
|
||||
r := new(dns.Msg).SetQuestion("example.com.", dns.TypeA)
|
||||
r.SetEdns0(4096, false)
|
||||
|
||||
rm, _, err := ExchangeWithFallback(ctx, client, r, addr)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, rm)
|
||||
|
||||
expectedMax := uint16(currentMTU - ipUDPHeaderSize)
|
||||
assert.Equal(t, expectedMax, receivedUDPSize,
|
||||
"upstream should see capped EDNS0, not the client's 4096")
|
||||
}
|
||||
|
||||
func TestExchangeWithFallback_TCPTruncatesToClientSize(t *testing.T) {
|
||||
// When the client advertises a large EDNS0 (4096) and the upstream
|
||||
// truncates, the TCP response should NOT be truncated since the full
|
||||
// answer fits within the client's original buffer.
|
||||
udpHandler := dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) {
|
||||
m := new(dns.Msg)
|
||||
m.SetReply(r)
|
||||
m.Truncated = true
|
||||
if err := w.WriteMsg(m); err != nil {
|
||||
t.Logf("write msg: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
tcpHandler := dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) {
|
||||
m := new(dns.Msg)
|
||||
m.SetReply(r)
|
||||
// Add enough records to exceed MTU but fit within 4096
|
||||
for i := range 20 {
|
||||
m.Answer = append(m.Answer, &dns.TXT{
|
||||
Hdr: dns.RR_Header{Name: r.Question[0].Name, Rrtype: dns.TypeTXT, Class: dns.ClassINET, Ttl: 60},
|
||||
Txt: []string{fmt.Sprintf("record-%d-padding-data-to-make-it-longer", i)},
|
||||
})
|
||||
}
|
||||
if err := w.WriteMsg(m); err != nil {
|
||||
t.Logf("write msg: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
udpPC, err := net.ListenPacket("udp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
addr := udpPC.LocalAddr().String()
|
||||
|
||||
udpServer := &dns.Server{PacketConn: udpPC, Net: "udp", Handler: udpHandler}
|
||||
tcpLn, err := net.Listen("tcp", addr)
|
||||
require.NoError(t, err)
|
||||
tcpServer := &dns.Server{Listener: tcpLn, Net: "tcp", Handler: tcpHandler}
|
||||
|
||||
go func() { _ = udpServer.ActivateAndServe() }()
|
||||
go func() { _ = tcpServer.ActivateAndServe() }()
|
||||
t.Cleanup(func() {
|
||||
_ = udpServer.Shutdown()
|
||||
_ = tcpServer.Shutdown()
|
||||
})
|
||||
|
||||
ctx := context.Background()
|
||||
client := &dns.Client{Timeout: 2 * time.Second}
|
||||
|
||||
// Client with large buffer: should get all records without truncation
|
||||
r := new(dns.Msg).SetQuestion("example.com.", dns.TypeTXT)
|
||||
r.SetEdns0(4096, false)
|
||||
|
||||
rm, _, err := ExchangeWithFallback(ctx, client, r, addr)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, rm)
|
||||
assert.Len(t, rm.Answer, 20, "large EDNS0 client should get all records")
|
||||
assert.False(t, rm.Truncated, "response should not be truncated for large buffer client")
|
||||
|
||||
// Client with small buffer: should get truncated response
|
||||
r2 := new(dns.Msg).SetQuestion("example.com.", dns.TypeTXT)
|
||||
r2.SetEdns0(512, false)
|
||||
|
||||
rm2, _, err := ExchangeWithFallback(ctx, &dns.Client{Timeout: 2 * time.Second}, r2, addr)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, rm2)
|
||||
assert.Less(t, len(rm2.Answer), 20, "small EDNS0 client should get fewer records")
|
||||
assert.True(t, rm2.Truncated, "response should be truncated for small buffer client")
|
||||
}
|
||||
|
||||
@@ -237,8 +237,8 @@ func (f *DNSForwarder) writeResponse(logger *log.Entry, w dns.ResponseWriter, re
|
||||
return
|
||||
}
|
||||
|
||||
logger.Tracef("response: domain=%s rcode=%s answers=%s took=%s",
|
||||
qname, dns.RcodeToString[resp.Rcode], resutil.FormatAnswers(resp.Answer), time.Since(startTime))
|
||||
logger.Tracef("response: domain=%s rcode=%s answers=%s size=%dB took=%s",
|
||||
qname, dns.RcodeToString[resp.Rcode], resutil.FormatAnswers(resp.Answer), resp.Len(), time.Since(startTime))
|
||||
}
|
||||
|
||||
// udpResponseWriter wraps a dns.ResponseWriter to handle UDP-specific truncation.
|
||||
@@ -263,20 +263,28 @@ func (u *udpResponseWriter) WriteMsg(resp *dns.Msg) error {
|
||||
|
||||
func (f *DNSForwarder) handleDNSQueryUDP(w dns.ResponseWriter, query *dns.Msg) {
|
||||
startTime := time.Now()
|
||||
logger := log.WithFields(log.Fields{
|
||||
fields := log.Fields{
|
||||
"request_id": resutil.GenerateRequestID(),
|
||||
"dns_id": fmt.Sprintf("%04x", query.Id),
|
||||
})
|
||||
}
|
||||
if addr := w.RemoteAddr(); addr != nil {
|
||||
fields["client"] = addr.String()
|
||||
}
|
||||
logger := log.WithFields(fields)
|
||||
|
||||
f.handleDNSQuery(logger, &udpResponseWriter{ResponseWriter: w, query: query}, query, startTime)
|
||||
}
|
||||
|
||||
func (f *DNSForwarder) handleDNSQueryTCP(w dns.ResponseWriter, query *dns.Msg) {
|
||||
startTime := time.Now()
|
||||
logger := log.WithFields(log.Fields{
|
||||
fields := log.Fields{
|
||||
"request_id": resutil.GenerateRequestID(),
|
||||
"dns_id": fmt.Sprintf("%04x", query.Id),
|
||||
})
|
||||
}
|
||||
if addr := w.RemoteAddr(); addr != nil {
|
||||
fields["client"] = addr.String()
|
||||
}
|
||||
logger := log.WithFields(fields)
|
||||
|
||||
f.handleDNSQuery(logger, w, query, startTime)
|
||||
}
|
||||
|
||||
@@ -46,6 +46,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/peer/guard"
|
||||
icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"
|
||||
"github.com/netbirdio/netbird/client/internal/peerstore"
|
||||
"github.com/netbirdio/netbird/client/internal/portforward"
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
"github.com/netbirdio/netbird/client/internal/relay"
|
||||
"github.com/netbirdio/netbird/client/internal/rosenpass"
|
||||
@@ -210,9 +211,10 @@ type Engine struct {
|
||||
// checks are the client-applied posture checks that need to be evaluated on the client
|
||||
checks []*mgmProto.Checks
|
||||
|
||||
relayManager *relayClient.Manager
|
||||
stateManager *statemanager.Manager
|
||||
srWatcher *guard.SRWatcher
|
||||
relayManager *relayClient.Manager
|
||||
stateManager *statemanager.Manager
|
||||
portForwardManager *portforward.Manager
|
||||
srWatcher *guard.SRWatcher
|
||||
|
||||
// Sync response persistence (protected by syncRespMux)
|
||||
syncRespMux sync.RWMutex
|
||||
@@ -259,26 +261,27 @@ func NewEngine(
|
||||
mobileDep MobileDependency,
|
||||
) *Engine {
|
||||
engine := &Engine{
|
||||
clientCtx: clientCtx,
|
||||
clientCancel: clientCancel,
|
||||
signal: services.SignalClient,
|
||||
signaler: peer.NewSignaler(services.SignalClient, config.WgPrivateKey),
|
||||
mgmClient: services.MgmClient,
|
||||
relayManager: services.RelayManager,
|
||||
peerStore: peerstore.NewConnStore(),
|
||||
syncMsgMux: &sync.Mutex{},
|
||||
config: config,
|
||||
mobileDep: mobileDep,
|
||||
STUNs: []*stun.URI{},
|
||||
TURNs: []*stun.URI{},
|
||||
networkSerial: 0,
|
||||
statusRecorder: services.StatusRecorder,
|
||||
stateManager: services.StateManager,
|
||||
checks: services.Checks,
|
||||
probeStunTurn: relay.NewStunTurnProbe(relay.DefaultCacheTTL),
|
||||
jobExecutor: jobexec.NewExecutor(),
|
||||
clientMetrics: services.ClientMetrics,
|
||||
updateManager: services.UpdateManager,
|
||||
clientCtx: clientCtx,
|
||||
clientCancel: clientCancel,
|
||||
signal: services.SignalClient,
|
||||
signaler: peer.NewSignaler(services.SignalClient, config.WgPrivateKey),
|
||||
mgmClient: services.MgmClient,
|
||||
relayManager: services.RelayManager,
|
||||
peerStore: peerstore.NewConnStore(),
|
||||
syncMsgMux: &sync.Mutex{},
|
||||
config: config,
|
||||
mobileDep: mobileDep,
|
||||
STUNs: []*stun.URI{},
|
||||
TURNs: []*stun.URI{},
|
||||
networkSerial: 0,
|
||||
statusRecorder: services.StatusRecorder,
|
||||
stateManager: services.StateManager,
|
||||
portForwardManager: portforward.NewManager(),
|
||||
checks: services.Checks,
|
||||
probeStunTurn: relay.NewStunTurnProbe(relay.DefaultCacheTTL),
|
||||
jobExecutor: jobexec.NewExecutor(),
|
||||
clientMetrics: services.ClientMetrics,
|
||||
updateManager: services.UpdateManager,
|
||||
}
|
||||
|
||||
log.Infof("I am: %s", config.WgPrivateKey.PublicKey().String())
|
||||
@@ -500,7 +503,7 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
|
||||
e.routeManager.SetRouteChangeListener(e.mobileDep.NetworkChangeListener)
|
||||
|
||||
e.dnsServer.SetRouteChecker(func(ip netip.Addr) bool {
|
||||
for _, routes := range e.routeManager.GetClientRoutes() {
|
||||
for _, routes := range e.routeManager.GetSelectedClientRoutes() {
|
||||
for _, r := range routes {
|
||||
if r.Network.Contains(ip) {
|
||||
return true
|
||||
@@ -521,6 +524,11 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
|
||||
return err
|
||||
}
|
||||
|
||||
// Inject firewall into DNS server now that it's available.
|
||||
// The DNS server is created before the firewall because the route manager
|
||||
// depends on the DNS server, and the firewall depends on the wg interface.
|
||||
e.dnsServer.SetFirewall(e.firewall)
|
||||
|
||||
e.udpMux, err = e.wgInterface.Up()
|
||||
if err != nil {
|
||||
log.Errorf("failed to pull up wgInterface [%s]: %s", e.wgInterface.Name(), err.Error())
|
||||
@@ -532,6 +540,13 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
|
||||
// conntrack entries from being created before the rules are in place
|
||||
e.setupWGProxyNoTrack()
|
||||
|
||||
// Start after interface is up since port may have been resolved from 0 or changed if occupied
|
||||
e.shutdownWg.Add(1)
|
||||
go func() {
|
||||
defer e.shutdownWg.Done()
|
||||
e.portForwardManager.Start(e.ctx, uint16(e.config.WgPort))
|
||||
}()
|
||||
|
||||
// Set the WireGuard interface for rosenpass after interface is up
|
||||
if e.rpManager != nil {
|
||||
e.rpManager.SetInterface(e.wgInterface)
|
||||
@@ -1535,12 +1550,13 @@ func (e *Engine) createPeerConn(pubKey string, allowedIPs []netip.Prefix, agentV
|
||||
}
|
||||
|
||||
serviceDependencies := peer.ServiceDependencies{
|
||||
StatusRecorder: e.statusRecorder,
|
||||
Signaler: e.signaler,
|
||||
IFaceDiscover: e.mobileDep.IFaceDiscover,
|
||||
RelayManager: e.relayManager,
|
||||
SrWatcher: e.srWatcher,
|
||||
MetricsRecorder: e.clientMetrics,
|
||||
StatusRecorder: e.statusRecorder,
|
||||
Signaler: e.signaler,
|
||||
IFaceDiscover: e.mobileDep.IFaceDiscover,
|
||||
RelayManager: e.relayManager,
|
||||
SrWatcher: e.srWatcher,
|
||||
PortForwardManager: e.portForwardManager,
|
||||
MetricsRecorder: e.clientMetrics,
|
||||
}
|
||||
peerConn, err := peer.NewConn(config, serviceDependencies)
|
||||
if err != nil {
|
||||
@@ -1697,6 +1713,12 @@ func (e *Engine) close() {
|
||||
if e.rpManager != nil {
|
||||
_ = e.rpManager.Close()
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
if err := e.portForwardManager.GracefullyStop(ctx); err != nil {
|
||||
log.Warnf("failed to gracefully stop port forwarding manager: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (e *Engine) readInitialSettings() ([]*route.Route, *nbdns.Config, bool, error) {
|
||||
@@ -1800,7 +1822,7 @@ func (e *Engine) newDnsServer(dnsConfig *nbdns.Config) (dns.Server, error) {
|
||||
return dnsServer, nil
|
||||
|
||||
case "ios":
|
||||
dnsServer := dns.NewDefaultServerIos(e.ctx, e.wgInterface, e.mobileDep.DnsManager, e.statusRecorder, e.config.DisableDNS)
|
||||
dnsServer := dns.NewDefaultServerIos(e.ctx, e.wgInterface, e.mobileDep.DnsManager, e.mobileDep.HostDNSAddresses, e.statusRecorder, e.config.DisableDNS)
|
||||
return dnsServer, nil
|
||||
|
||||
default:
|
||||
@@ -1837,6 +1859,11 @@ func (e *Engine) GetExposeManager() *expose.Manager {
|
||||
return e.exposeManager
|
||||
}
|
||||
|
||||
// IsBlockInbound returns whether inbound connections are blocked.
|
||||
func (e *Engine) IsBlockInbound() bool {
|
||||
return e.config.BlockInbound
|
||||
}
|
||||
|
||||
// GetClientMetrics returns the client metrics
|
||||
func (e *Engine) GetClientMetrics() *metrics.ClientMetrics {
|
||||
return e.clientMetrics
|
||||
|
||||
@@ -828,7 +828,7 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
|
||||
WgPrivateKey: key,
|
||||
WgPort: 33100,
|
||||
MTU: iface.DefaultMTU,
|
||||
}, EngineServices{
|
||||
}, EngineServices{
|
||||
SignalClient: &signal.MockClient{},
|
||||
MgmClient: &mgmt.MockClient{},
|
||||
RelayManager: relayMgr,
|
||||
@@ -1035,7 +1035,7 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) {
|
||||
WgPrivateKey: key,
|
||||
WgPort: 33100,
|
||||
MTU: iface.DefaultMTU,
|
||||
}, EngineServices{
|
||||
}, EngineServices{
|
||||
SignalClient: &signal.MockClient{},
|
||||
MgmClient: &mgmt.MockClient{},
|
||||
RelayManager: relayMgr,
|
||||
@@ -1538,13 +1538,8 @@ func createEngine(ctx context.Context, cancel context.CancelFunc, setupKey strin
|
||||
return nil, err
|
||||
}
|
||||
|
||||
publicKey, err := mgmtClient.GetServerPublicKey()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
info := system.GetInfo(ctx)
|
||||
resp, err := mgmtClient.Register(*publicKey, setupKey, "", info, nil, nil)
|
||||
resp, err := mgmtClient.Register(setupKey, "", info, nil, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -1566,7 +1561,7 @@ func createEngine(ctx context.Context, cancel context.CancelFunc, setupKey strin
|
||||
}
|
||||
|
||||
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
|
||||
e, err := NewEngine(ctx, cancel, conf, EngineServices{
|
||||
e, err := NewEngine(ctx, cancel, conf, EngineServices{
|
||||
SignalClient: signalClient,
|
||||
MgmClient: mgmtClient,
|
||||
RelayManager: relayMgr,
|
||||
|
||||
@@ -22,6 +22,7 @@ import (
|
||||
icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"
|
||||
"github.com/netbirdio/netbird/client/internal/peer/id"
|
||||
"github.com/netbirdio/netbird/client/internal/peer/worker"
|
||||
"github.com/netbirdio/netbird/client/internal/portforward"
|
||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
relayClient "github.com/netbirdio/netbird/shared/relay/client"
|
||||
@@ -45,6 +46,7 @@ type ServiceDependencies struct {
|
||||
RelayManager *relayClient.Manager
|
||||
SrWatcher *guard.SRWatcher
|
||||
PeerConnDispatcher *dispatcher.ConnectionDispatcher
|
||||
PortForwardManager *portforward.Manager
|
||||
MetricsRecorder MetricsRecorder
|
||||
}
|
||||
|
||||
@@ -87,16 +89,17 @@ type ConnConfig struct {
|
||||
}
|
||||
|
||||
type Conn struct {
|
||||
Log *log.Entry
|
||||
mu sync.Mutex
|
||||
ctx context.Context
|
||||
ctxCancel context.CancelFunc
|
||||
config ConnConfig
|
||||
statusRecorder *Status
|
||||
signaler *Signaler
|
||||
iFaceDiscover stdnet.ExternalIFaceDiscover
|
||||
relayManager *relayClient.Manager
|
||||
srWatcher *guard.SRWatcher
|
||||
Log *log.Entry
|
||||
mu sync.Mutex
|
||||
ctx context.Context
|
||||
ctxCancel context.CancelFunc
|
||||
config ConnConfig
|
||||
statusRecorder *Status
|
||||
signaler *Signaler
|
||||
iFaceDiscover stdnet.ExternalIFaceDiscover
|
||||
relayManager *relayClient.Manager
|
||||
srWatcher *guard.SRWatcher
|
||||
portForwardManager *portforward.Manager
|
||||
|
||||
onConnected func(remoteWireGuardKey string, remoteRosenpassPubKey []byte, wireGuardIP string, remoteRosenpassAddr string)
|
||||
onDisconnected func(remotePeer string)
|
||||
@@ -145,19 +148,20 @@ func NewConn(config ConnConfig, services ServiceDependencies) (*Conn, error) {
|
||||
|
||||
dumpState := newStateDump(config.Key, connLog, services.StatusRecorder)
|
||||
var conn = &Conn{
|
||||
Log: connLog,
|
||||
config: config,
|
||||
statusRecorder: services.StatusRecorder,
|
||||
signaler: services.Signaler,
|
||||
iFaceDiscover: services.IFaceDiscover,
|
||||
relayManager: services.RelayManager,
|
||||
srWatcher: services.SrWatcher,
|
||||
statusRelay: worker.NewAtomicStatus(),
|
||||
statusICE: worker.NewAtomicStatus(),
|
||||
dumpState: dumpState,
|
||||
endpointUpdater: NewEndpointUpdater(connLog, config.WgConfig, isController(config)),
|
||||
wgWatcher: NewWGWatcher(connLog, config.WgConfig.WgInterface, config.Key, dumpState),
|
||||
metricsRecorder: services.MetricsRecorder,
|
||||
Log: connLog,
|
||||
config: config,
|
||||
statusRecorder: services.StatusRecorder,
|
||||
signaler: services.Signaler,
|
||||
iFaceDiscover: services.IFaceDiscover,
|
||||
relayManager: services.RelayManager,
|
||||
srWatcher: services.SrWatcher,
|
||||
portForwardManager: services.PortForwardManager,
|
||||
statusRelay: worker.NewAtomicStatus(),
|
||||
statusICE: worker.NewAtomicStatus(),
|
||||
dumpState: dumpState,
|
||||
endpointUpdater: NewEndpointUpdater(connLog, config.WgConfig, isController(config)),
|
||||
wgWatcher: NewWGWatcher(connLog, config.WgConfig.WgInterface, config.Key, dumpState),
|
||||
metricsRecorder: services.MetricsRecorder,
|
||||
}
|
||||
|
||||
return conn, nil
|
||||
|
||||
@@ -16,6 +16,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/iface/udpmux"
|
||||
"github.com/netbirdio/netbird/client/internal/peer/conntype"
|
||||
icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"
|
||||
"github.com/netbirdio/netbird/client/internal/portforward"
|
||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
@@ -61,6 +62,9 @@ type WorkerICE struct {
|
||||
|
||||
// we record the last known state of the ICE agent to avoid duplicate on disconnected events
|
||||
lastKnownState ice.ConnectionState
|
||||
|
||||
// portForwardAttempted tracks if we've already tried port forwarding this session
|
||||
portForwardAttempted bool
|
||||
}
|
||||
|
||||
func NewWorkerICE(ctx context.Context, log *log.Entry, config ConnConfig, conn *Conn, signaler *Signaler, ifaceDiscover stdnet.ExternalIFaceDiscover, statusRecorder *Status, hasRelayOnLocally bool) (*WorkerICE, error) {
|
||||
@@ -214,6 +218,8 @@ func (w *WorkerICE) Close() {
|
||||
}
|
||||
|
||||
func (w *WorkerICE) reCreateAgent(dialerCancel context.CancelFunc, candidates []ice.CandidateType) (*icemaker.ThreadSafeAgent, error) {
|
||||
w.portForwardAttempted = false
|
||||
|
||||
agent, err := icemaker.NewAgent(w.ctx, w.iFaceDiscover, w.config.ICEConfig, candidates, w.localUfrag, w.localPwd)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create agent: %w", err)
|
||||
@@ -370,6 +376,93 @@ func (w *WorkerICE) onICECandidate(candidate ice.Candidate) {
|
||||
w.log.Errorf("failed signaling candidate to the remote peer %s %s", w.config.Key, err)
|
||||
}
|
||||
}()
|
||||
|
||||
if candidate.Type() == ice.CandidateTypeServerReflexive {
|
||||
w.injectPortForwardedCandidate(candidate)
|
||||
}
|
||||
}
|
||||
|
||||
// injectPortForwardedCandidate signals an additional candidate using the pre-created port mapping.
|
||||
func (w *WorkerICE) injectPortForwardedCandidate(srflxCandidate ice.Candidate) {
|
||||
pfManager := w.conn.portForwardManager
|
||||
if pfManager == nil {
|
||||
return
|
||||
}
|
||||
|
||||
mapping := pfManager.GetMapping()
|
||||
if mapping == nil {
|
||||
return
|
||||
}
|
||||
|
||||
w.muxAgent.Lock()
|
||||
if w.portForwardAttempted {
|
||||
w.muxAgent.Unlock()
|
||||
return
|
||||
}
|
||||
w.portForwardAttempted = true
|
||||
w.muxAgent.Unlock()
|
||||
|
||||
forwardedCandidate, err := w.createForwardedCandidate(srflxCandidate, mapping)
|
||||
if err != nil {
|
||||
w.log.Warnf("create forwarded candidate: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
w.log.Debugf("injecting port-forwarded candidate: %s (mapping: %d -> %d via %s, priority: %d)",
|
||||
forwardedCandidate.String(), mapping.InternalPort, mapping.ExternalPort, mapping.NATType, forwardedCandidate.Priority())
|
||||
|
||||
go func() {
|
||||
if err := w.signaler.SignalICECandidate(forwardedCandidate, w.config.Key); err != nil {
|
||||
w.log.Errorf("signal port-forwarded candidate: %v", err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// createForwardedCandidate creates a new server reflexive candidate with the forwarded port.
|
||||
// It uses the NAT gateway's external IP with the forwarded port.
|
||||
func (w *WorkerICE) createForwardedCandidate(srflxCandidate ice.Candidate, mapping *portforward.Mapping) (ice.Candidate, error) {
|
||||
var externalIP string
|
||||
if mapping.ExternalIP != nil && !mapping.ExternalIP.IsUnspecified() {
|
||||
externalIP = mapping.ExternalIP.String()
|
||||
} else {
|
||||
// Fallback to STUN-discovered address if NAT didn't provide external IP
|
||||
externalIP = srflxCandidate.Address()
|
||||
}
|
||||
|
||||
// Per RFC 8445, the related address for srflx is the base (host candidate address).
|
||||
// If the original srflx has unspecified related address, use its own address as base.
|
||||
relAddr := srflxCandidate.RelatedAddress().Address
|
||||
if relAddr == "" || relAddr == "0.0.0.0" || relAddr == "::" {
|
||||
relAddr = srflxCandidate.Address()
|
||||
}
|
||||
|
||||
// Arbitrary +1000 boost on top of RFC 8445 priority to favor port-forwarded candidates
|
||||
// over regular srflx during ICE connectivity checks.
|
||||
priority := srflxCandidate.Priority() + 1000
|
||||
|
||||
candidate, err := ice.NewCandidateServerReflexive(&ice.CandidateServerReflexiveConfig{
|
||||
Network: srflxCandidate.NetworkType().String(),
|
||||
Address: externalIP,
|
||||
Port: int(mapping.ExternalPort),
|
||||
Component: srflxCandidate.Component(),
|
||||
Priority: priority,
|
||||
RelAddr: relAddr,
|
||||
RelPort: int(mapping.InternalPort),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create candidate: %w", err)
|
||||
}
|
||||
|
||||
for _, e := range srflxCandidate.Extensions() {
|
||||
if e.Key == ice.ExtensionKeyCandidateID {
|
||||
e.Value = srflxCandidate.ID()
|
||||
}
|
||||
if err := candidate.AddExtension(e); err != nil {
|
||||
return nil, fmt.Errorf("add extension: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return candidate, nil
|
||||
}
|
||||
|
||||
func (w *WorkerICE) onICESelectedCandidatePair(agent *icemaker.ThreadSafeAgent, c1, c2 ice.Candidate) {
|
||||
@@ -411,10 +504,10 @@ func (w *WorkerICE) logSuccessfulPaths(agent *icemaker.ThreadSafeAgent) {
|
||||
if !lok || !rok {
|
||||
continue
|
||||
}
|
||||
w.log.Debugf("successful ICE path %s: [%s %s %s] <-> [%s %s %s] rtt=%.3fms",
|
||||
w.log.Debugf("successful ICE path %s: [%s %s %s:%d] <-> [%s %s %s:%d] rtt=%.3fms",
|
||||
sessionID,
|
||||
local.NetworkType(), local.Type(), local.Address(),
|
||||
remote.NetworkType(), remote.Type(), remote.Address(),
|
||||
local.NetworkType(), local.Type(), local.Address(), local.Port(),
|
||||
remote.NetworkType(), remote.Type(), remote.Address(), remote.Port(),
|
||||
stat.CurrentRoundTripTime*1000)
|
||||
}
|
||||
}
|
||||
|
||||
26
client/internal/portforward/env.go
Normal file
26
client/internal/portforward/env.go
Normal file
@@ -0,0 +1,26 @@
|
||||
package portforward
|
||||
|
||||
import (
|
||||
"os"
|
||||
"strconv"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
envDisableNATMapper = "NB_DISABLE_NAT_MAPPER"
|
||||
)
|
||||
|
||||
func isDisabledByEnv() bool {
|
||||
val := os.Getenv(envDisableNATMapper)
|
||||
if val == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
disabled, err := strconv.ParseBool(val)
|
||||
if err != nil {
|
||||
log.Warnf("failed to parse %s: %v", envDisableNATMapper, err)
|
||||
return false
|
||||
}
|
||||
return disabled
|
||||
}
|
||||
280
client/internal/portforward/manager.go
Normal file
280
client/internal/portforward/manager.go
Normal file
@@ -0,0 +1,280 @@
|
||||
//go:build !js
|
||||
|
||||
package portforward
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"regexp"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/libp2p/go-nat"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultMappingTTL = 2 * time.Hour
|
||||
discoveryTimeout = 10 * time.Second
|
||||
mappingDescription = "NetBird"
|
||||
)
|
||||
|
||||
// upnpErrPermanentLeaseOnly matches UPnP error 725 in SOAP fault XML,
|
||||
// allowing for whitespace/newlines between tags from different router firmware.
|
||||
var upnpErrPermanentLeaseOnly = regexp.MustCompile(`<errorCode>\s*725\s*</errorCode>`)
|
||||
|
||||
// Mapping represents an active NAT port mapping.
|
||||
type Mapping struct {
|
||||
Protocol string
|
||||
InternalPort uint16
|
||||
ExternalPort uint16
|
||||
ExternalIP net.IP
|
||||
NATType string
|
||||
// TTL is the lease duration. Zero means a permanent lease that never expires.
|
||||
TTL time.Duration
|
||||
}
|
||||
|
||||
// TODO: persist mapping state for crash recovery cleanup of permanent leases.
|
||||
// Currently not done because State.Cleanup requires NAT gateway re-discovery,
|
||||
// which blocks startup for ~10s when no gateway is present (affects all clients).
|
||||
|
||||
type Manager struct {
|
||||
cancel context.CancelFunc
|
||||
|
||||
mapping *Mapping
|
||||
mappingLock sync.Mutex
|
||||
|
||||
wgPort uint16
|
||||
|
||||
done chan struct{}
|
||||
stopCtx chan context.Context
|
||||
|
||||
// protect exported functions
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
// NewManager creates a new port forwarding manager.
|
||||
func NewManager() *Manager {
|
||||
return &Manager{
|
||||
stopCtx: make(chan context.Context, 1),
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) Start(ctx context.Context, wgPort uint16) {
|
||||
m.mu.Lock()
|
||||
if m.cancel != nil {
|
||||
m.mu.Unlock()
|
||||
return
|
||||
}
|
||||
|
||||
if isDisabledByEnv() {
|
||||
log.Infof("NAT port mapper disabled via %s", envDisableNATMapper)
|
||||
m.mu.Unlock()
|
||||
return
|
||||
}
|
||||
|
||||
if wgPort == 0 {
|
||||
log.Warnf("invalid WireGuard port 0; NAT mapping disabled")
|
||||
m.mu.Unlock()
|
||||
return
|
||||
}
|
||||
m.wgPort = wgPort
|
||||
|
||||
m.done = make(chan struct{})
|
||||
defer close(m.done)
|
||||
|
||||
ctx, m.cancel = context.WithCancel(ctx)
|
||||
m.mu.Unlock()
|
||||
|
||||
gateway, mapping, err := m.setup(ctx)
|
||||
if err != nil {
|
||||
log.Infof("port forwarding setup: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
m.mappingLock.Lock()
|
||||
m.mapping = mapping
|
||||
m.mappingLock.Unlock()
|
||||
|
||||
m.renewLoop(ctx, gateway, mapping.TTL)
|
||||
|
||||
select {
|
||||
case cleanupCtx := <-m.stopCtx:
|
||||
// block the Start while cleaned up gracefully
|
||||
m.cleanup(cleanupCtx, gateway)
|
||||
default:
|
||||
// return Start immediately and cleanup in background
|
||||
cleanupCtx, cleanupCancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
go func() {
|
||||
defer cleanupCancel()
|
||||
m.cleanup(cleanupCtx, gateway)
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
// GetMapping returns the current mapping if ready, nil otherwise
|
||||
func (m *Manager) GetMapping() *Mapping {
|
||||
m.mappingLock.Lock()
|
||||
defer m.mappingLock.Unlock()
|
||||
|
||||
if m.mapping == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
mapping := *m.mapping
|
||||
return &mapping
|
||||
}
|
||||
|
||||
// GracefullyStop cancels the manager and attempts to delete the port mapping.
|
||||
// After GracefullyStop returns, the manager cannot be restarted.
|
||||
func (m *Manager) GracefullyStop(ctx context.Context) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
if m.cancel == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Send cleanup context before cancelling, so Start picks it up after renewLoop exits.
|
||||
m.startTearDown(ctx)
|
||||
|
||||
m.cancel()
|
||||
m.cancel = nil
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-m.done:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) setup(ctx context.Context) (nat.NAT, *Mapping, error) {
|
||||
discoverCtx, discoverCancel := context.WithTimeout(ctx, discoveryTimeout)
|
||||
defer discoverCancel()
|
||||
|
||||
gateway, err := nat.DiscoverGateway(discoverCtx)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("discover gateway: %w", err)
|
||||
}
|
||||
|
||||
log.Infof("discovered NAT gateway: %s", gateway.Type())
|
||||
|
||||
mapping, err := m.createMapping(ctx, gateway)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("create port mapping: %w", err)
|
||||
}
|
||||
return gateway, mapping, nil
|
||||
}
|
||||
|
||||
func (m *Manager) createMapping(ctx context.Context, gateway nat.NAT) (*Mapping, error) {
|
||||
ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
ttl := defaultMappingTTL
|
||||
externalPort, err := gateway.AddPortMapping(ctx, "udp", int(m.wgPort), mappingDescription, ttl)
|
||||
if err != nil {
|
||||
if !isPermanentLeaseRequired(err) {
|
||||
return nil, err
|
||||
}
|
||||
log.Infof("gateway only supports permanent leases, retrying with indefinite duration")
|
||||
ttl = 0
|
||||
externalPort, err = gateway.AddPortMapping(ctx, "udp", int(m.wgPort), mappingDescription, ttl)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
externalIP, err := gateway.GetExternalAddress()
|
||||
if err != nil {
|
||||
log.Debugf("failed to get external address: %v", err)
|
||||
// todo return with err?
|
||||
}
|
||||
|
||||
mapping := &Mapping{
|
||||
Protocol: "udp",
|
||||
InternalPort: m.wgPort,
|
||||
ExternalPort: uint16(externalPort),
|
||||
ExternalIP: externalIP,
|
||||
NATType: gateway.Type(),
|
||||
TTL: ttl,
|
||||
}
|
||||
|
||||
log.Infof("created port mapping: %d -> %d via %s (external IP: %s)",
|
||||
m.wgPort, externalPort, gateway.Type(), externalIP)
|
||||
return mapping, nil
|
||||
}
|
||||
|
||||
func (m *Manager) renewLoop(ctx context.Context, gateway nat.NAT, ttl time.Duration) {
|
||||
if ttl == 0 {
|
||||
// Permanent mappings don't expire, just wait for cancellation.
|
||||
<-ctx.Done()
|
||||
return
|
||||
}
|
||||
|
||||
ticker := time.NewTicker(ttl / 2)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
if err := m.renewMapping(ctx, gateway); err != nil {
|
||||
log.Warnf("failed to renew port mapping: %v", err)
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) renewMapping(ctx context.Context, gateway nat.NAT) error {
|
||||
ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
externalPort, err := gateway.AddPortMapping(ctx, m.mapping.Protocol, int(m.mapping.InternalPort), mappingDescription, m.mapping.TTL)
|
||||
if err != nil {
|
||||
return fmt.Errorf("add port mapping: %w", err)
|
||||
}
|
||||
|
||||
if uint16(externalPort) != m.mapping.ExternalPort {
|
||||
log.Warnf("external port changed on renewal: %d -> %d (candidate may be stale)", m.mapping.ExternalPort, externalPort)
|
||||
m.mappingLock.Lock()
|
||||
m.mapping.ExternalPort = uint16(externalPort)
|
||||
m.mappingLock.Unlock()
|
||||
}
|
||||
|
||||
log.Debugf("renewed port mapping: %d -> %d", m.mapping.InternalPort, m.mapping.ExternalPort)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) cleanup(ctx context.Context, gateway nat.NAT) {
|
||||
m.mappingLock.Lock()
|
||||
mapping := m.mapping
|
||||
m.mapping = nil
|
||||
m.mappingLock.Unlock()
|
||||
|
||||
if mapping == nil {
|
||||
return
|
||||
}
|
||||
|
||||
if err := gateway.DeletePortMapping(ctx, mapping.Protocol, int(mapping.InternalPort)); err != nil {
|
||||
log.Warnf("delete port mapping on stop: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
log.Infof("deleted port mapping for port %d", mapping.InternalPort)
|
||||
}
|
||||
|
||||
func (m *Manager) startTearDown(ctx context.Context) {
|
||||
select {
|
||||
case m.stopCtx <- ctx:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
// isPermanentLeaseRequired checks if a UPnP error indicates the gateway only supports permanent leases (error 725).
|
||||
func isPermanentLeaseRequired(err error) bool {
|
||||
return err != nil && upnpErrPermanentLeaseOnly.MatchString(err.Error())
|
||||
}
|
||||
39
client/internal/portforward/manager_js.go
Normal file
39
client/internal/portforward/manager_js.go
Normal file
@@ -0,0 +1,39 @@
|
||||
package portforward
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Mapping represents an active NAT port mapping.
|
||||
type Mapping struct {
|
||||
Protocol string
|
||||
InternalPort uint16
|
||||
ExternalPort uint16
|
||||
ExternalIP net.IP
|
||||
NATType string
|
||||
// TTL is the lease duration. Zero means a permanent lease that never expires.
|
||||
TTL time.Duration
|
||||
}
|
||||
|
||||
// Manager is a stub for js/wasm builds where NAT-PMP/UPnP is not supported.
|
||||
type Manager struct{}
|
||||
|
||||
// NewManager returns a stub manager for js/wasm builds.
|
||||
func NewManager() *Manager {
|
||||
return &Manager{}
|
||||
}
|
||||
|
||||
// Start is a no-op on js/wasm: NAT-PMP/UPnP is not available in browser environments.
|
||||
func (m *Manager) Start(context.Context, uint16) {
|
||||
// no NAT traversal in wasm
|
||||
}
|
||||
|
||||
// GracefullyStop is a no-op on js/wasm.
|
||||
func (m *Manager) GracefullyStop(context.Context) error { return nil }
|
||||
|
||||
// GetMapping always returns nil on js/wasm.
|
||||
func (m *Manager) GetMapping() *Mapping {
|
||||
return nil
|
||||
}
|
||||
201
client/internal/portforward/manager_test.go
Normal file
201
client/internal/portforward/manager_test.go
Normal file
@@ -0,0 +1,201 @@
|
||||
//go:build !js
|
||||
|
||||
package portforward
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type mockNAT struct {
|
||||
natType string
|
||||
deviceAddr net.IP
|
||||
externalAddr net.IP
|
||||
internalAddr net.IP
|
||||
mappings map[int]int
|
||||
addMappingErr error
|
||||
deleteMappingErr error
|
||||
onlyPermanentLeases bool
|
||||
lastTimeout time.Duration
|
||||
}
|
||||
|
||||
func newMockNAT() *mockNAT {
|
||||
return &mockNAT{
|
||||
natType: "Mock-NAT",
|
||||
deviceAddr: net.ParseIP("192.168.1.1"),
|
||||
externalAddr: net.ParseIP("203.0.113.50"),
|
||||
internalAddr: net.ParseIP("192.168.1.100"),
|
||||
mappings: make(map[int]int),
|
||||
}
|
||||
}
|
||||
|
||||
func (m *mockNAT) Type() string {
|
||||
return m.natType
|
||||
}
|
||||
|
||||
func (m *mockNAT) GetDeviceAddress() (net.IP, error) {
|
||||
return m.deviceAddr, nil
|
||||
}
|
||||
|
||||
func (m *mockNAT) GetExternalAddress() (net.IP, error) {
|
||||
return m.externalAddr, nil
|
||||
}
|
||||
|
||||
func (m *mockNAT) GetInternalAddress() (net.IP, error) {
|
||||
return m.internalAddr, nil
|
||||
}
|
||||
|
||||
func (m *mockNAT) AddPortMapping(ctx context.Context, protocol string, internalPort int, description string, timeout time.Duration) (int, error) {
|
||||
if m.addMappingErr != nil {
|
||||
return 0, m.addMappingErr
|
||||
}
|
||||
if m.onlyPermanentLeases && timeout != 0 {
|
||||
return 0, fmt.Errorf("SOAP fault. Code: | Explanation: | Detail: <UPnPError xmlns=\"urn:schemas-upnp-org:control-1-0\"><errorCode>725</errorCode><errorDescription>OnlyPermanentLeasesSupported</errorDescription></UPnPError>")
|
||||
}
|
||||
externalPort := internalPort
|
||||
m.mappings[internalPort] = externalPort
|
||||
m.lastTimeout = timeout
|
||||
return externalPort, nil
|
||||
}
|
||||
|
||||
func (m *mockNAT) DeletePortMapping(ctx context.Context, protocol string, internalPort int) error {
|
||||
if m.deleteMappingErr != nil {
|
||||
return m.deleteMappingErr
|
||||
}
|
||||
delete(m.mappings, internalPort)
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestManager_CreateMapping(t *testing.T) {
|
||||
m := NewManager()
|
||||
m.wgPort = 51820
|
||||
|
||||
gateway := newMockNAT()
|
||||
mapping, err := m.createMapping(context.Background(), gateway)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, mapping)
|
||||
|
||||
assert.Equal(t, "udp", mapping.Protocol)
|
||||
assert.Equal(t, uint16(51820), mapping.InternalPort)
|
||||
assert.Equal(t, uint16(51820), mapping.ExternalPort)
|
||||
assert.Equal(t, "Mock-NAT", mapping.NATType)
|
||||
assert.Equal(t, net.ParseIP("203.0.113.50").To4(), mapping.ExternalIP.To4())
|
||||
assert.Equal(t, defaultMappingTTL, mapping.TTL)
|
||||
}
|
||||
|
||||
func TestManager_GetMapping_ReturnsNilWhenNotReady(t *testing.T) {
|
||||
m := NewManager()
|
||||
assert.Nil(t, m.GetMapping())
|
||||
}
|
||||
|
||||
func TestManager_GetMapping_ReturnsCopy(t *testing.T) {
|
||||
m := NewManager()
|
||||
m.mapping = &Mapping{
|
||||
Protocol: "udp",
|
||||
InternalPort: 51820,
|
||||
ExternalPort: 51820,
|
||||
}
|
||||
|
||||
mapping := m.GetMapping()
|
||||
require.NotNil(t, mapping)
|
||||
assert.Equal(t, uint16(51820), mapping.InternalPort)
|
||||
|
||||
// Mutating the returned copy should not affect the manager's mapping.
|
||||
mapping.ExternalPort = 9999
|
||||
assert.Equal(t, uint16(51820), m.GetMapping().ExternalPort)
|
||||
}
|
||||
|
||||
func TestManager_Cleanup_DeletesMapping(t *testing.T) {
|
||||
m := NewManager()
|
||||
m.mapping = &Mapping{
|
||||
Protocol: "udp",
|
||||
InternalPort: 51820,
|
||||
ExternalPort: 51820,
|
||||
}
|
||||
|
||||
gateway := newMockNAT()
|
||||
// Seed the mock so we can verify deletion.
|
||||
gateway.mappings[51820] = 51820
|
||||
|
||||
m.cleanup(context.Background(), gateway)
|
||||
|
||||
_, exists := gateway.mappings[51820]
|
||||
assert.False(t, exists, "mapping should be deleted from gateway")
|
||||
assert.Nil(t, m.GetMapping(), "in-memory mapping should be cleared")
|
||||
}
|
||||
|
||||
func TestManager_Cleanup_NilMapping(t *testing.T) {
|
||||
m := NewManager()
|
||||
gateway := newMockNAT()
|
||||
|
||||
// Should not panic or call gateway.
|
||||
m.cleanup(context.Background(), gateway)
|
||||
}
|
||||
|
||||
|
||||
func TestManager_CreateMapping_PermanentLeaseFallback(t *testing.T) {
|
||||
m := NewManager()
|
||||
m.wgPort = 51820
|
||||
|
||||
gateway := newMockNAT()
|
||||
gateway.onlyPermanentLeases = true
|
||||
|
||||
mapping, err := m.createMapping(context.Background(), gateway)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, mapping)
|
||||
|
||||
assert.Equal(t, uint16(51820), mapping.InternalPort)
|
||||
assert.Equal(t, time.Duration(0), mapping.TTL, "should return zero TTL for permanent lease")
|
||||
assert.Equal(t, time.Duration(0), gateway.lastTimeout, "should have retried with zero duration")
|
||||
}
|
||||
|
||||
func TestIsPermanentLeaseRequired(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
err error
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "nil error",
|
||||
err: nil,
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "UPnP error 725",
|
||||
err: fmt.Errorf("SOAP fault. Code: | Detail: <UPnPError><errorCode>725</errorCode><errorDescription>OnlyPermanentLeasesSupported</errorDescription></UPnPError>"),
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "wrapped error with 725",
|
||||
err: fmt.Errorf("add port mapping: %w", fmt.Errorf("Detail: <errorCode>725</errorCode>")),
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "error 725 with newlines in XML",
|
||||
err: fmt.Errorf("<errorCode>\n 725\n</errorCode>"),
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "bare 725 without XML tag",
|
||||
err: fmt.Errorf("error code 725"),
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "unrelated error",
|
||||
err: fmt.Errorf("connection refused"),
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
assert.Equal(t, tt.expected, isPermanentLeaseRequired(tt.err))
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -41,7 +41,7 @@ const (
|
||||
|
||||
// mgmProber is the subset of management client needed for URL migration probes.
|
||||
type mgmProber interface {
|
||||
GetServerPublicKey() (*wgtypes.Key, error)
|
||||
HealthCheck() error
|
||||
Close() error
|
||||
}
|
||||
|
||||
@@ -777,8 +777,7 @@ func UpdateOldManagementURL(ctx context.Context, config *Config, configPath stri
|
||||
}()
|
||||
|
||||
// gRPC check
|
||||
_, err = client.GetServerPublicKey()
|
||||
if err != nil {
|
||||
if err = client.HealthCheck(); err != nil {
|
||||
log.Infof("couldn't switch to the new Management %s", newURL.String())
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -17,12 +17,10 @@ import (
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
type mockMgmProber struct {
|
||||
key wgtypes.Key
|
||||
}
|
||||
type mockMgmProber struct{}
|
||||
|
||||
func (m *mockMgmProber) GetServerPublicKey() (*wgtypes.Key, error) {
|
||||
return &m.key, nil
|
||||
func (m *mockMgmProber) HealthCheck() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockMgmProber) Close() error { return nil }
|
||||
@@ -247,11 +245,7 @@ func TestWireguardPortDefaultVsExplicit(t *testing.T) {
|
||||
func TestUpdateOldManagementURL(t *testing.T) {
|
||||
origProber := newMgmProber
|
||||
newMgmProber = func(_ context.Context, _ string, _ wgtypes.Key, _ bool) (mgmProber, error) {
|
||||
key, err := wgtypes.GenerateKey()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &mockMgmProber{key: key.PublicKey()}, nil
|
||||
return &mockMgmProber{}, nil
|
||||
}
|
||||
t.Cleanup(func() { newMgmProber = origProber })
|
||||
|
||||
|
||||
@@ -52,6 +52,7 @@ type Manager interface {
|
||||
TriggerSelection(route.HAMap)
|
||||
GetRouteSelector() *routeselector.RouteSelector
|
||||
GetClientRoutes() route.HAMap
|
||||
GetSelectedClientRoutes() route.HAMap
|
||||
GetClientRoutesWithNetID() map[route.NetID][]*route.Route
|
||||
SetRouteChangeListener(listener listener.NetworkChangeListener)
|
||||
InitialRouteRange() []string
|
||||
@@ -465,6 +466,16 @@ func (m *DefaultManager) GetClientRoutes() route.HAMap {
|
||||
return maps.Clone(m.clientRoutes)
|
||||
}
|
||||
|
||||
// GetSelectedClientRoutes returns only the currently selected/active client routes,
|
||||
// filtering out deselected exit nodes. Use this instead of GetClientRoutes when checking
|
||||
// if traffic should be routed through the tunnel.
|
||||
func (m *DefaultManager) GetSelectedClientRoutes() route.HAMap {
|
||||
m.mux.Lock()
|
||||
defer m.mux.Unlock()
|
||||
|
||||
return m.routeSelector.FilterSelectedExitNodes(maps.Clone(m.clientRoutes))
|
||||
}
|
||||
|
||||
// GetClientRoutesWithNetID returns the current routes from the route map, but the keys consist of the network ID only
|
||||
func (m *DefaultManager) GetClientRoutesWithNetID() map[route.NetID][]*route.Route {
|
||||
m.mux.Lock()
|
||||
|
||||
@@ -18,6 +18,7 @@ type MockManager struct {
|
||||
TriggerSelectionFunc func(haMap route.HAMap)
|
||||
GetRouteSelectorFunc func() *routeselector.RouteSelector
|
||||
GetClientRoutesFunc func() route.HAMap
|
||||
GetSelectedClientRoutesFunc func() route.HAMap
|
||||
GetClientRoutesWithNetIDFunc func() map[route.NetID][]*route.Route
|
||||
StopFunc func(manager *statemanager.Manager)
|
||||
}
|
||||
@@ -61,7 +62,7 @@ func (m *MockManager) GetRouteSelector() *routeselector.RouteSelector {
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetClientRoutes mock implementation of GetClientRoutes from Manager interface
|
||||
// GetClientRoutes mock implementation of GetClientRoutes from the Manager interface
|
||||
func (m *MockManager) GetClientRoutes() route.HAMap {
|
||||
if m.GetClientRoutesFunc != nil {
|
||||
return m.GetClientRoutesFunc()
|
||||
@@ -69,6 +70,14 @@ func (m *MockManager) GetClientRoutes() route.HAMap {
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetSelectedClientRoutes mock implementation of GetSelectedClientRoutes from the Manager interface
|
||||
func (m *MockManager) GetSelectedClientRoutes() route.HAMap {
|
||||
if m.GetSelectedClientRoutesFunc != nil {
|
||||
return m.GetSelectedClientRoutesFunc()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetClientRoutesWithNetID mock implementation of GetClientRoutesWithNetID from Manager interface
|
||||
func (m *MockManager) GetClientRoutesWithNetID() map[route.NetID][]*route.Route {
|
||||
if m.GetClientRoutesWithNetIDFunc != nil {
|
||||
|
||||
@@ -53,7 +53,6 @@ func (n *Notifier) OnNewPrefixes(prefixes []netip.Prefix) {
|
||||
n.currentPrefixes = newNets
|
||||
n.notify()
|
||||
}
|
||||
|
||||
func (n *Notifier) notify() {
|
||||
n.listenerMux.Lock()
|
||||
defer n.listenerMux.Unlock()
|
||||
|
||||
@@ -105,7 +105,20 @@ func (r *SysOps) FlushMarkedRoutes() error {
|
||||
}
|
||||
|
||||
func (r *SysOps) addToRouteTable(prefix netip.Prefix, nexthop Nexthop) error {
|
||||
return r.routeSocket(unix.RTM_ADD, prefix, nexthop)
|
||||
if err := r.routeSocket(unix.RTM_ADD, prefix, nexthop); err != nil {
|
||||
if !errors.Is(err, unix.EEXIST) {
|
||||
return err
|
||||
}
|
||||
|
||||
// Route already exists from a previous session that wasn't cleaned up properly
|
||||
// (e.g. macOS sleep/wake). Remove the stale route and retry.
|
||||
log.Infof("Route for %s already exists, replacing with new route", prefix)
|
||||
if err := r.routeSocket(unix.RTM_DELETE, prefix, nexthop); err != nil {
|
||||
log.Warnf("Failed to remove stale route for %s: %v", prefix, err)
|
||||
}
|
||||
return r.routeSocket(unix.RTM_ADD, prefix, nexthop)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *SysOps) removeFromRouteTable(prefix netip.Prefix, nexthop Nexthop) error {
|
||||
@@ -228,7 +241,7 @@ func (r *SysOps) parseRouteResponse(buf []byte) error {
|
||||
|
||||
rtMsg := (*unix.RtMsghdr)(unsafe.Pointer(&buf[0]))
|
||||
if rtMsg.Errno != 0 {
|
||||
return fmt.Errorf("parse: %d", rtMsg.Errno)
|
||||
return fmt.Errorf("route socket: %w", syscall.Errno(rtMsg.Errno))
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
@@ -161,7 +161,11 @@ func (c *Client) Run(fd int32, interfaceName string, envList *EnvList) error {
|
||||
cfg.WgIface = interfaceName
|
||||
|
||||
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder)
|
||||
return c.connectClient.RunOniOS(fd, c.networkChangeListener, c.dnsManager, c.stateFile)
|
||||
hostDNS := []netip.AddrPort{
|
||||
netip.MustParseAddrPort("9.9.9.9:53"),
|
||||
netip.MustParseAddrPort("149.112.112.112:53"),
|
||||
}
|
||||
return c.connectClient.RunOniOS(fd, c.networkChangeListener, c.dnsManager, hostDNS, c.stateFile)
|
||||
}
|
||||
|
||||
// Stop the internal client and free the resources
|
||||
|
||||
@@ -1359,6 +1359,10 @@ func (s *Server) ExposeService(req *proto.ExposeServiceRequest, srv proto.Daemon
|
||||
return gstatus.Errorf(codes.FailedPrecondition, "engine not initialized")
|
||||
}
|
||||
|
||||
if engine.IsBlockInbound() {
|
||||
return gstatus.Errorf(codes.FailedPrecondition, "expose requires inbound connections but 'block inbound' is enabled, disable it first")
|
||||
}
|
||||
|
||||
mgr := engine.GetExposeManager()
|
||||
if mgr == nil {
|
||||
return gstatus.Errorf(codes.Internal, "expose manager not available")
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/ssh/config"
|
||||
)
|
||||
|
||||
// registerStates registers all states that need crash recovery cleanup.
|
||||
func registerStates(mgr *statemanager.Manager) {
|
||||
mgr.RegisterState(&dns.ShutdownState{})
|
||||
mgr.RegisterState(&systemops.ShutdownState{})
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/ssh/config"
|
||||
)
|
||||
|
||||
// registerStates registers all states that need crash recovery cleanup.
|
||||
func registerStates(mgr *statemanager.Manager) {
|
||||
mgr.RegisterState(&dns.ShutdownState{})
|
||||
mgr.RegisterState(&systemops.ShutdownState{})
|
||||
|
||||
@@ -141,7 +141,7 @@ func (p *SSHProxy) runProxySSHServer(jwtToken string) error {
|
||||
|
||||
func (p *SSHProxy) handleSSHSession(session ssh.Session) {
|
||||
ptyReq, winCh, isPty := session.Pty()
|
||||
hasCommand := len(session.Command()) > 0
|
||||
hasCommand := session.RawCommand() != ""
|
||||
|
||||
sshClient, err := p.getOrCreateBackendClient(session.Context(), session.User())
|
||||
if err != nil {
|
||||
@@ -180,7 +180,7 @@ func (p *SSHProxy) handleSSHSession(session ssh.Session) {
|
||||
}
|
||||
|
||||
if hasCommand {
|
||||
if err := serverSession.Run(strings.Join(session.Command(), " ")); err != nil {
|
||||
if err := serverSession.Run(session.RawCommand()); err != nil {
|
||||
log.Debugf("run command: %v", err)
|
||||
p.handleProxyExitCode(session, err)
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
@@ -245,6 +246,191 @@ func TestSSHProxy_Connect(t *testing.T) {
|
||||
cancel()
|
||||
}
|
||||
|
||||
// TestSSHProxy_CommandQuoting verifies that the proxy preserves shell quoting
|
||||
// when forwarding commands to the backend. This is critical for tools like
|
||||
// Ansible that send commands such as:
|
||||
//
|
||||
// /bin/sh -c '( umask 77 && mkdir -p ... ) && sleep 0'
|
||||
//
|
||||
// The single quotes must be preserved so the backend shell receives the
|
||||
// subshell expression as a single argument to -c.
|
||||
func TestSSHProxy_CommandQuoting(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("Skipping integration test in short mode")
|
||||
}
|
||||
|
||||
sshClient, cleanup := setupProxySSHClient(t)
|
||||
defer cleanup()
|
||||
|
||||
// These commands simulate what the SSH protocol delivers as exec payloads.
|
||||
// When a user types: ssh host '/bin/sh -c "( echo hello )"'
|
||||
// the local shell strips the outer single quotes, and the SSH exec request
|
||||
// contains the raw string: /bin/sh -c "( echo hello )"
|
||||
//
|
||||
// The proxy must forward this string verbatim. Using session.Command()
|
||||
// (shlex.Split + strings.Join) strips the inner double quotes, breaking
|
||||
// the command on the backend.
|
||||
tests := []struct {
|
||||
name string
|
||||
command string
|
||||
expect string
|
||||
}{
|
||||
{
|
||||
name: "subshell_in_double_quotes",
|
||||
command: `/bin/sh -c "( echo from-subshell ) && echo outer"`,
|
||||
expect: "from-subshell\nouter\n",
|
||||
},
|
||||
{
|
||||
name: "printf_with_special_chars",
|
||||
command: `/bin/sh -c "printf '%s\n' 'hello world'"`,
|
||||
expect: "hello world\n",
|
||||
},
|
||||
{
|
||||
name: "nested_command_substitution",
|
||||
command: `/bin/sh -c "echo $(echo nested)"`,
|
||||
expect: "nested\n",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
session, err := sshClient.NewSession()
|
||||
require.NoError(t, err)
|
||||
defer func() { _ = session.Close() }()
|
||||
|
||||
var stderrBuf bytes.Buffer
|
||||
session.Stderr = &stderrBuf
|
||||
|
||||
outputCh := make(chan []byte, 1)
|
||||
errCh := make(chan error, 1)
|
||||
go func() {
|
||||
output, err := session.Output(tc.command)
|
||||
outputCh <- output
|
||||
errCh <- err
|
||||
}()
|
||||
|
||||
select {
|
||||
case output := <-outputCh:
|
||||
err := <-errCh
|
||||
if stderrBuf.Len() > 0 {
|
||||
t.Logf("stderr: %s", stderrBuf.String())
|
||||
}
|
||||
require.NoError(t, err, "command should succeed: %s", tc.command)
|
||||
assert.Equal(t, tc.expect, string(output), "output mismatch for: %s", tc.command)
|
||||
case <-time.After(5 * time.Second):
|
||||
t.Fatalf("command timed out: %s", tc.command)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// setupProxySSHClient creates a full proxy test environment and returns
|
||||
// an SSH client connected through the proxy to a backend NetBird SSH server.
|
||||
func setupProxySSHClient(t *testing.T) (*cryptossh.Client, func()) {
|
||||
t.Helper()
|
||||
|
||||
const (
|
||||
issuer = "https://test-issuer.example.com"
|
||||
audience = "test-audience"
|
||||
)
|
||||
|
||||
jwksServer, privateKey, jwksURL := setupJWKSServer(t)
|
||||
|
||||
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
hostPubKey, err := nbssh.GeneratePublicKey(hostKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
serverConfig := &server.Config{
|
||||
HostKeyPEM: hostKey,
|
||||
JWT: &server.JWTConfig{
|
||||
Issuer: issuer,
|
||||
Audiences: []string{audience},
|
||||
KeysLocation: jwksURL,
|
||||
},
|
||||
}
|
||||
sshServer := server.New(serverConfig)
|
||||
sshServer.SetAllowRootLogin(true)
|
||||
|
||||
testUsername := testutil.GetTestUsername(t)
|
||||
testJWTUser := "test-username"
|
||||
testUserHash, err := sshuserhash.HashUserID(testJWTUser)
|
||||
require.NoError(t, err)
|
||||
|
||||
authConfig := &sshauth.Config{
|
||||
UserIDClaim: sshauth.DefaultUserIDClaim,
|
||||
AuthorizedUsers: []sshuserhash.UserIDHash{testUserHash},
|
||||
MachineUsers: map[string][]uint32{
|
||||
testUsername: {0},
|
||||
},
|
||||
}
|
||||
sshServer.UpdateSSHAuth(authConfig)
|
||||
|
||||
sshServerAddr := server.StartTestServer(t, sshServer)
|
||||
|
||||
mockDaemon := startMockDaemon(t)
|
||||
|
||||
host, portStr, err := net.SplitHostPort(sshServerAddr)
|
||||
require.NoError(t, err)
|
||||
port, err := strconv.Atoi(portStr)
|
||||
require.NoError(t, err)
|
||||
|
||||
mockDaemon.setHostKey(host, hostPubKey)
|
||||
|
||||
validToken := generateValidJWT(t, privateKey, issuer, audience, testJWTUser)
|
||||
mockDaemon.setJWTToken(validToken)
|
||||
|
||||
proxyInstance, err := New(mockDaemon.addr, host, port, io.Discard, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
origStdin := os.Stdin
|
||||
origStdout := os.Stdout
|
||||
|
||||
stdinReader, stdinWriter, err := os.Pipe()
|
||||
require.NoError(t, err)
|
||||
stdoutReader, stdoutWriter, err := os.Pipe()
|
||||
require.NoError(t, err)
|
||||
|
||||
os.Stdin = stdinReader
|
||||
os.Stdout = stdoutWriter
|
||||
|
||||
clientConn, proxyConn := net.Pipe()
|
||||
|
||||
go func() { _, _ = io.Copy(stdinWriter, proxyConn) }()
|
||||
go func() { _, _ = io.Copy(proxyConn, stdoutReader) }()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
|
||||
go func() {
|
||||
_ = proxyInstance.Connect(ctx)
|
||||
}()
|
||||
|
||||
sshConfig := &cryptossh.ClientConfig{
|
||||
User: testutil.GetTestUsername(t),
|
||||
Auth: []cryptossh.AuthMethod{},
|
||||
HostKeyCallback: cryptossh.InsecureIgnoreHostKey(),
|
||||
Timeout: 5 * time.Second,
|
||||
}
|
||||
|
||||
sshClientConn, chans, reqs, err := cryptossh.NewClientConn(clientConn, "test", sshConfig)
|
||||
require.NoError(t, err)
|
||||
|
||||
client := cryptossh.NewClient(sshClientConn, chans, reqs)
|
||||
|
||||
cleanupFn := func() {
|
||||
_ = client.Close()
|
||||
_ = clientConn.Close()
|
||||
cancel()
|
||||
os.Stdin = origStdin
|
||||
os.Stdout = origStdout
|
||||
_ = sshServer.Stop()
|
||||
mockDaemon.stop()
|
||||
jwksServer.Close()
|
||||
}
|
||||
|
||||
return client, cleanupFn
|
||||
}
|
||||
|
||||
type mockDaemonServer struct {
|
||||
proto.UnimplementedDaemonServiceServer
|
||||
hostKeys map[string][]byte
|
||||
|
||||
@@ -284,19 +284,21 @@ func (s *Server) closeListener(ln net.Listener) {
|
||||
// Stop closes the SSH server
|
||||
func (s *Server) Stop() error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if s.sshServer == nil {
|
||||
sshServer := s.sshServer
|
||||
if sshServer == nil {
|
||||
s.mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
s.sshServer = nil
|
||||
s.listener = nil
|
||||
s.mu.Unlock()
|
||||
|
||||
if err := s.sshServer.Close(); err != nil {
|
||||
// Close outside the lock: session handlers need s.mu for unregisterSession.
|
||||
if err := sshServer.Close(); err != nil {
|
||||
log.Debugf("close SSH server: %v", err)
|
||||
}
|
||||
|
||||
s.sshServer = nil
|
||||
s.listener = nil
|
||||
|
||||
s.mu.Lock()
|
||||
maps.Clear(s.sessions)
|
||||
maps.Clear(s.pendingAuthJWT)
|
||||
maps.Clear(s.connections)
|
||||
@@ -307,6 +309,7 @@ func (s *Server) Stop() error {
|
||||
}
|
||||
}
|
||||
maps.Clear(s.remoteForwardListeners)
|
||||
s.mu.Unlock()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -60,7 +60,7 @@ func (s *Server) sessionHandler(session ssh.Session) {
|
||||
}
|
||||
|
||||
ptyReq, winCh, isPty := session.Pty()
|
||||
hasCommand := len(session.Command()) > 0
|
||||
hasCommand := session.RawCommand() != ""
|
||||
|
||||
if isPty && !hasCommand {
|
||||
// ssh <host> - PTY interactive session (login)
|
||||
|
||||
@@ -153,6 +153,9 @@ func networkAddresses() ([]NetworkAddress, error) {
|
||||
|
||||
var netAddresses []NetworkAddress
|
||||
for _, iface := range interfaces {
|
||||
if iface.Flags&net.FlagUp == 0 {
|
||||
continue
|
||||
}
|
||||
if iface.HardwareAddr.String() == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -43,18 +43,24 @@ func GetInfo(ctx context.Context) *Info {
|
||||
|
||||
systemHostname, _ := os.Hostname()
|
||||
|
||||
addrs, err := networkAddresses()
|
||||
if err != nil {
|
||||
log.Warnf("failed to discover network addresses: %s", err)
|
||||
}
|
||||
|
||||
return &Info{
|
||||
GoOS: runtime.GOOS,
|
||||
Kernel: osInfo[0],
|
||||
Platform: runtime.GOARCH,
|
||||
OS: osName,
|
||||
OSVersion: osVersion,
|
||||
Hostname: extractDeviceName(ctx, systemHostname),
|
||||
CPUs: runtime.NumCPU(),
|
||||
NetbirdVersion: version.NetbirdVersion(),
|
||||
UIVersion: extractUserAgent(ctx),
|
||||
KernelVersion: osInfo[1],
|
||||
Environment: env,
|
||||
GoOS: runtime.GOOS,
|
||||
Kernel: osInfo[0],
|
||||
Platform: runtime.GOARCH,
|
||||
OS: osName,
|
||||
OSVersion: osVersion,
|
||||
Hostname: extractDeviceName(ctx, systemHostname),
|
||||
CPUs: runtime.NumCPU(),
|
||||
NetbirdVersion: version.NetbirdVersion(),
|
||||
UIVersion: extractUserAgent(ctx),
|
||||
KernelVersion: osInfo[1],
|
||||
NetworkAddresses: addrs,
|
||||
Environment: env,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -24,9 +24,10 @@ import (
|
||||
|
||||
// Initial state for the debug collection
|
||||
type debugInitialState struct {
|
||||
wasDown bool
|
||||
logLevel proto.LogLevel
|
||||
isLevelTrace bool
|
||||
wasDown bool
|
||||
needsRestoreUp bool
|
||||
logLevel proto.LogLevel
|
||||
isLevelTrace bool
|
||||
}
|
||||
|
||||
// Debug collection parameters
|
||||
@@ -371,46 +372,51 @@ func (s *serviceClient) configureServiceForDebug(
|
||||
conn proto.DaemonServiceClient,
|
||||
state *debugInitialState,
|
||||
enablePersistence bool,
|
||||
) error {
|
||||
) {
|
||||
if state.wasDown {
|
||||
if _, err := conn.Up(s.ctx, &proto.UpRequest{}); err != nil {
|
||||
return fmt.Errorf("bring service up: %v", err)
|
||||
log.Warnf("failed to bring service up: %v", err)
|
||||
} else {
|
||||
log.Info("Service brought up for debug")
|
||||
time.Sleep(time.Second * 10)
|
||||
}
|
||||
log.Info("Service brought up for debug")
|
||||
time.Sleep(time.Second * 10)
|
||||
}
|
||||
|
||||
if !state.isLevelTrace {
|
||||
if _, err := conn.SetLogLevel(s.ctx, &proto.SetLogLevelRequest{Level: proto.LogLevel_TRACE}); err != nil {
|
||||
return fmt.Errorf("set log level to TRACE: %v", err)
|
||||
log.Warnf("failed to set log level to TRACE: %v", err)
|
||||
} else {
|
||||
log.Info("Log level set to TRACE for debug")
|
||||
}
|
||||
log.Info("Log level set to TRACE for debug")
|
||||
}
|
||||
|
||||
if _, err := conn.Down(s.ctx, &proto.DownRequest{}); err != nil {
|
||||
return fmt.Errorf("bring service down: %v", err)
|
||||
log.Warnf("failed to bring service down: %v", err)
|
||||
} else {
|
||||
state.needsRestoreUp = !state.wasDown
|
||||
time.Sleep(time.Second)
|
||||
}
|
||||
time.Sleep(time.Second)
|
||||
|
||||
if enablePersistence {
|
||||
if _, err := conn.SetSyncResponsePersistence(s.ctx, &proto.SetSyncResponsePersistenceRequest{
|
||||
Enabled: true,
|
||||
}); err != nil {
|
||||
return fmt.Errorf("enable sync response persistence: %v", err)
|
||||
log.Warnf("failed to enable sync response persistence: %v", err)
|
||||
} else {
|
||||
log.Info("Sync response persistence enabled for debug")
|
||||
}
|
||||
log.Info("Sync response persistence enabled for debug")
|
||||
}
|
||||
|
||||
if _, err := conn.Up(s.ctx, &proto.UpRequest{}); err != nil {
|
||||
return fmt.Errorf("bring service back up: %v", err)
|
||||
log.Warnf("failed to bring service back up: %v", err)
|
||||
} else {
|
||||
state.needsRestoreUp = false
|
||||
time.Sleep(time.Second * 3)
|
||||
}
|
||||
time.Sleep(time.Second * 3)
|
||||
|
||||
if _, err := conn.StartCPUProfile(s.ctx, &proto.StartCPUProfileRequest{}); err != nil {
|
||||
log.Warnf("failed to start CPU profiling: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *serviceClient) collectDebugData(
|
||||
@@ -424,9 +430,7 @@ func (s *serviceClient) collectDebugData(
|
||||
var wg sync.WaitGroup
|
||||
startProgressTracker(ctx, &wg, params.duration, progress)
|
||||
|
||||
if err := s.configureServiceForDebug(conn, state, params.enablePersistence); err != nil {
|
||||
return err
|
||||
}
|
||||
s.configureServiceForDebug(conn, state, params.enablePersistence)
|
||||
|
||||
wg.Wait()
|
||||
progress.progressBar.Hide()
|
||||
@@ -482,9 +486,17 @@ func (s *serviceClient) createDebugBundleFromCollection(
|
||||
|
||||
// Restore service to original state
|
||||
func (s *serviceClient) restoreServiceState(conn proto.DaemonServiceClient, state *debugInitialState) {
|
||||
if state.needsRestoreUp {
|
||||
if _, err := conn.Up(s.ctx, &proto.UpRequest{}); err != nil {
|
||||
log.Warnf("failed to restore up state: %v", err)
|
||||
} else {
|
||||
log.Info("Service state restored to up")
|
||||
}
|
||||
}
|
||||
|
||||
if state.wasDown {
|
||||
if _, err := conn.Down(s.ctx, &proto.DownRequest{}); err != nil {
|
||||
log.Errorf("Failed to restore down state: %v", err)
|
||||
log.Warnf("failed to restore down state: %v", err)
|
||||
} else {
|
||||
log.Info("Service state restored to down")
|
||||
}
|
||||
@@ -492,7 +504,7 @@ func (s *serviceClient) restoreServiceState(conn proto.DaemonServiceClient, stat
|
||||
|
||||
if !state.isLevelTrace {
|
||||
if _, err := conn.SetLogLevel(s.ctx, &proto.SetLogLevelRequest{Level: state.logLevel}); err != nil {
|
||||
log.Errorf("Failed to restore log level: %v", err)
|
||||
log.Warnf("failed to restore log level: %v", err)
|
||||
} else {
|
||||
log.Info("Log level restored to original setting")
|
||||
}
|
||||
|
||||
@@ -29,6 +29,7 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
"github.com/netbirdio/netbird/relay/healthcheck"
|
||||
relayServer "github.com/netbirdio/netbird/relay/server"
|
||||
"github.com/netbirdio/netbird/relay/server/listener"
|
||||
"github.com/netbirdio/netbird/relay/server/listener/ws"
|
||||
sharedMetrics "github.com/netbirdio/netbird/shared/metrics"
|
||||
"github.com/netbirdio/netbird/shared/relay/auth"
|
||||
@@ -523,7 +524,7 @@ func createManagementServer(cfg *CombinedConfig, mgmtConfig *nbconfig.Config) (*
|
||||
func createCombinedHandler(grpcServer *grpc.Server, httpHandler http.Handler, relaySrv *relayServer.Server, meter metric.Meter, cfg *CombinedConfig) http.Handler {
|
||||
wsProxy := wsproxyserver.New(grpcServer, wsproxyserver.WithOTelMeter(meter))
|
||||
|
||||
var relayAcceptFn func(conn net.Conn)
|
||||
var relayAcceptFn func(conn listener.Conn)
|
||||
if relaySrv != nil {
|
||||
relayAcceptFn = relaySrv.RelayAccept()
|
||||
}
|
||||
@@ -563,7 +564,7 @@ func createCombinedHandler(grpcServer *grpc.Server, httpHandler http.Handler, re
|
||||
}
|
||||
|
||||
// handleRelayWebSocket handles incoming WebSocket connections for the relay service
|
||||
func handleRelayWebSocket(w http.ResponseWriter, r *http.Request, acceptFn func(conn net.Conn), cfg *CombinedConfig) {
|
||||
func handleRelayWebSocket(w http.ResponseWriter, r *http.Request, acceptFn func(conn listener.Conn), cfg *CombinedConfig) {
|
||||
acceptOptions := &websocket.AcceptOptions{
|
||||
OriginPatterns: []string{"*"},
|
||||
}
|
||||
@@ -585,15 +586,9 @@ func handleRelayWebSocket(w http.ResponseWriter, r *http.Request, acceptFn func(
|
||||
return
|
||||
}
|
||||
|
||||
lAddr, err := net.ResolveTCPAddr("tcp", cfg.Server.ListenAddress)
|
||||
if err != nil {
|
||||
_ = wsConn.Close(websocket.StatusInternalError, "internal error")
|
||||
return
|
||||
}
|
||||
|
||||
log.Debugf("Relay WS client connected from: %s", rAddr)
|
||||
|
||||
conn := ws.NewConn(wsConn, lAddr, rAddr)
|
||||
conn := ws.NewConn(wsConn, rAddr)
|
||||
acceptFn(conn)
|
||||
}
|
||||
|
||||
|
||||
4
go.mod
4
go.mod
@@ -63,6 +63,7 @@ require (
|
||||
github.com/hashicorp/go-version v1.6.0
|
||||
github.com/jackc/pgx/v5 v5.5.5
|
||||
github.com/libdns/route53 v1.5.0
|
||||
github.com/libp2p/go-nat v0.2.0
|
||||
github.com/libp2p/go-netroute v0.2.1
|
||||
github.com/lrh3321/ipset-go v0.0.0-20250619021614-54a0a98ace81
|
||||
github.com/mdlayher/socket v0.5.1
|
||||
@@ -200,10 +201,12 @@ require (
|
||||
github.com/hashicorp/errwrap v1.1.0 // indirect
|
||||
github.com/hashicorp/go-uuid v1.0.3 // indirect
|
||||
github.com/huandu/xstrings v1.5.0 // indirect
|
||||
github.com/huin/goupnp v1.2.0 // indirect
|
||||
github.com/inconshreveable/mousetrap v1.1.0 // indirect
|
||||
github.com/jackc/pgpassfile v1.0.0 // indirect
|
||||
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect
|
||||
github.com/jackc/puddle/v2 v2.2.1 // indirect
|
||||
github.com/jackpal/go-nat-pmp v1.0.2 // indirect
|
||||
github.com/jeandeaual/go-locale v0.0.0-20250612000132-0ef82f21eade // indirect
|
||||
github.com/jinzhu/inflection v1.0.0 // indirect
|
||||
github.com/jinzhu/now v1.1.5 // indirect
|
||||
@@ -213,6 +216,7 @@ require (
|
||||
github.com/kelseyhightower/envconfig v1.4.0 // indirect
|
||||
github.com/klauspost/compress v1.18.0 // indirect
|
||||
github.com/klauspost/cpuid/v2 v2.2.7 // indirect
|
||||
github.com/koron/go-ssdp v0.0.4 // indirect
|
||||
github.com/kr/fs v0.1.0 // indirect
|
||||
github.com/lib/pq v1.10.9 // indirect
|
||||
github.com/libdns/libdns v0.2.2 // indirect
|
||||
|
||||
8
go.sum
8
go.sum
@@ -281,6 +281,8 @@ github.com/hashicorp/go-version v1.6.0/go.mod h1:fltr4n8CU8Ke44wwGCBoEymUuxUHl09
|
||||
github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU=
|
||||
github.com/huandu/xstrings v1.5.0 h1:2ag3IFq9ZDANvthTwTiqSSZLjDc+BedvHPAp5tJy2TI=
|
||||
github.com/huandu/xstrings v1.5.0/go.mod h1:y5/lhBue+AyNmUVz9RLU9xbLR0o4KIIExikq4ovT0aE=
|
||||
github.com/huin/goupnp v1.2.0 h1:uOKW26NG1hsSSbXIZ1IR7XP9Gjd1U8pnLaCMgntmkmY=
|
||||
github.com/huin/goupnp v1.2.0/go.mod h1:gnGPsThkYa7bFi/KWmEysQRf48l2dvR5bxr2OFckNX8=
|
||||
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
|
||||
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
|
||||
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
|
||||
@@ -291,6 +293,8 @@ github.com/jackc/pgx/v5 v5.5.5 h1:amBjrZVmksIdNjxGW/IiIMzxMKZFelXbUoPNb+8sjQw=
|
||||
github.com/jackc/pgx/v5 v5.5.5/go.mod h1:ez9gk+OAat140fv9ErkZDYFWmXLfV+++K0uAOiwgm1A=
|
||||
github.com/jackc/puddle/v2 v2.2.1 h1:RhxXJtFG022u4ibrCSMSiu5aOq1i77R3OHKNJj77OAk=
|
||||
github.com/jackc/puddle/v2 v2.2.1/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4=
|
||||
github.com/jackpal/go-nat-pmp v1.0.2 h1:KzKSgb7qkJvOUTqYl9/Hg/me3pWgBmERKrTGD7BdWus=
|
||||
github.com/jackpal/go-nat-pmp v1.0.2/go.mod h1:QPH045xvCAeXUZOxsnwmrtiCoxIr9eob+4orBN1SBKc=
|
||||
github.com/jcmturner/aescts/v2 v2.0.0 h1:9YKLH6ey7H4eDBXW8khjYslgyqG2xZikXP0EQFKrle8=
|
||||
github.com/jcmturner/aescts/v2 v2.0.0/go.mod h1:AiaICIRyfYg35RUkr8yESTqvSy7csK90qZ5xfvvsoNs=
|
||||
github.com/jcmturner/dnsutils/v2 v2.0.0 h1:lltnkeZGL0wILNvrNiVCR6Ro5PGU/SeBvVO/8c/iPbo=
|
||||
@@ -328,6 +332,8 @@ github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYW
|
||||
github.com/klauspost/cpuid/v2 v2.0.12/go.mod h1:g2LTdtYhdyuGPqyWyv7qRAmj1WBqxuObKfj5c0PQa7c=
|
||||
github.com/klauspost/cpuid/v2 v2.2.7 h1:ZWSB3igEs+d0qvnxR/ZBzXVmxkgt8DdzP6m9pfuVLDM=
|
||||
github.com/klauspost/cpuid/v2 v2.2.7/go.mod h1:Lcz8mBdAVJIBVzewtcLocK12l3Y+JytZYpaMropDUws=
|
||||
github.com/koron/go-ssdp v0.0.4 h1:1IDwrghSKYM7yLf7XCzbByg2sJ/JcNOZRXS2jczTwz0=
|
||||
github.com/koron/go-ssdp v0.0.4/go.mod h1:oDXq+E5IL5q0U8uSBcoAXzTzInwy5lEgC91HoKtbmZk=
|
||||
github.com/kr/fs v0.1.0 h1:Jskdu9ieNAYnjxsi0LbQp1ulIKZV1LAFgK1tWhpZgl8=
|
||||
github.com/kr/fs v0.1.0/go.mod h1:FFnZGqtBN9Gxj7eW1uZ42v5BccTP0vu6NEaFoC2HwRg=
|
||||
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
|
||||
@@ -346,6 +352,8 @@ github.com/libdns/libdns v0.2.2 h1:O6ws7bAfRPaBsgAYt8MDe2HcNBGC29hkZ9MX2eUSX3s=
|
||||
github.com/libdns/libdns v0.2.2/go.mod h1:4Bj9+5CQiNMVGf87wjX4CY3HQJypUHRuLvlsfsZqLWQ=
|
||||
github.com/libdns/route53 v1.5.0 h1:2SKdpPFl/qgWsXQvsLNJJAoX7rSxlk7zgoL4jnWdXVA=
|
||||
github.com/libdns/route53 v1.5.0/go.mod h1:joT4hKmaTNKHEwb7GmZ65eoDz1whTu7KKYPS8ZqIh6Q=
|
||||
github.com/libp2p/go-nat v0.2.0 h1:Tyz+bUFAYqGyJ/ppPPymMGbIgNRH+WqC5QrT5fKrrGk=
|
||||
github.com/libp2p/go-nat v0.2.0/go.mod h1:3MJr+GRpRkyT65EpVPBstXLvOlAPzUVlG6Pwg9ohLJk=
|
||||
github.com/lrh3321/ipset-go v0.0.0-20250619021614-54a0a98ace81 h1:J56rFEfUTFT9j9CiRXhi1r8lUJ4W5idG3CiaBZGojNU=
|
||||
github.com/lrh3321/ipset-go v0.0.0-20250619021614-54a0a98ace81/go.mod h1:RD8ML/YdXctQ7qbcizZkw5mZ6l8Ogrl1dodBzVJduwI=
|
||||
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0/go.mod h1:zJYVVT2jmtg6P3p1VtQj7WsuWi/y4VnjVBn7F8KPB3I=
|
||||
|
||||
@@ -85,8 +85,8 @@ func setupL4Test(t *testing.T, customPortsSupported *bool) (*Manager, store.Stor
|
||||
accountMgr := &mock_server.MockAccountManager{
|
||||
StoreEventFunc: func(_ context.Context, _, _, _ string, _ activity.ActivityDescriber, _ map[string]any) {},
|
||||
UpdateAccountPeersFunc: func(_ context.Context, _ string) {},
|
||||
GetGroupByNameFunc: func(ctx context.Context, accountID, groupName string) (*types.Group, error) {
|
||||
return testStore.GetGroupByName(ctx, store.LockingStrengthNone, groupName, accountID)
|
||||
GetGroupByNameFunc: func(ctx context.Context, groupName, accountID, userID string) (*types.Group, error) {
|
||||
return testStore.GetGroupByName(ctx, store.LockingStrengthNone, accountID, groupName)
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@@ -1119,7 +1119,7 @@ func (m *Manager) getGroupIDsFromNames(ctx context.Context, accountID string, gr
|
||||
}
|
||||
groupIDs := make([]string, 0, len(groupNames))
|
||||
for _, groupName := range groupNames {
|
||||
g, err := m.accountManager.GetGroupByName(ctx, groupName, accountID)
|
||||
g, err := m.accountManager.GetGroupByName(ctx, groupName, accountID, activity.SystemInitiator)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get group by name %s: %w", groupName, err)
|
||||
}
|
||||
|
||||
@@ -698,8 +698,8 @@ func setupIntegrationTest(t *testing.T) (*Manager, store.Store) {
|
||||
accountMgr := &mock_server.MockAccountManager{
|
||||
StoreEventFunc: func(_ context.Context, _, _, _ string, _ activity.ActivityDescriber, _ map[string]any) {},
|
||||
UpdateAccountPeersFunc: func(_ context.Context, _ string) {},
|
||||
GetGroupByNameFunc: func(ctx context.Context, accountID, groupName string) (*types.Group, error) {
|
||||
return testStore.GetGroupByName(ctx, store.LockingStrengthNone, groupName, accountID)
|
||||
GetGroupByNameFunc: func(ctx context.Context, groupName, accountID, userID string) (*types.Group, error) {
|
||||
return testStore.GetGroupByName(ctx, store.LockingStrengthNone, accountID, groupName)
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@@ -75,7 +75,7 @@ type Manager interface {
|
||||
GetUsersFromAccount(ctx context.Context, accountID, userID string) (map[string]*types.UserInfo, error)
|
||||
GetGroup(ctx context.Context, accountId, groupID, userID string) (*types.Group, error)
|
||||
GetAllGroups(ctx context.Context, accountID, userID string) ([]*types.Group, error)
|
||||
GetGroupByName(ctx context.Context, groupName, accountID string) (*types.Group, error)
|
||||
GetGroupByName(ctx context.Context, groupName, accountID, userID string) (*types.Group, error)
|
||||
CreateGroup(ctx context.Context, accountID, userID string, group *types.Group) error
|
||||
UpdateGroup(ctx context.Context, accountID, userID string, group *types.Group) error
|
||||
CreateGroups(ctx context.Context, accountID, userID string, newGroups []*types.Group) error
|
||||
|
||||
@@ -736,18 +736,18 @@ func (mr *MockManagerMockRecorder) GetGroup(ctx, accountId, groupID, userID inte
|
||||
}
|
||||
|
||||
// GetGroupByName mocks base method.
|
||||
func (m *MockManager) GetGroupByName(ctx context.Context, groupName, accountID string) (*types.Group, error) {
|
||||
func (m *MockManager) GetGroupByName(ctx context.Context, groupName, accountID, userID string) (*types.Group, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetGroupByName", ctx, groupName, accountID)
|
||||
ret := m.ctrl.Call(m, "GetGroupByName", ctx, groupName, accountID, userID)
|
||||
ret0, _ := ret[0].(*types.Group)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetGroupByName indicates an expected call of GetGroupByName.
|
||||
func (mr *MockManagerMockRecorder) GetGroupByName(ctx, groupName, accountID interface{}) *gomock.Call {
|
||||
func (mr *MockManagerMockRecorder) GetGroupByName(ctx, groupName, accountID, userID interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetGroupByName", reflect.TypeOf((*MockManager)(nil).GetGroupByName), ctx, groupName, accountID)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetGroupByName", reflect.TypeOf((*MockManager)(nil).GetGroupByName), ctx, groupName, accountID, userID)
|
||||
}
|
||||
|
||||
// GetIdentityProvider mocks base method.
|
||||
|
||||
@@ -61,7 +61,10 @@ func (am *DefaultAccountManager) GetAllGroups(ctx context.Context, accountID, us
|
||||
}
|
||||
|
||||
// GetGroupByName filters all groups in an account by name and returns the one with the most peers
|
||||
func (am *DefaultAccountManager) GetGroupByName(ctx context.Context, groupName, accountID string) (*types.Group, error) {
|
||||
func (am *DefaultAccountManager) GetGroupByName(ctx context.Context, groupName, accountID, userID string) (*types.Group, error) {
|
||||
if err := am.CheckGroupPermissions(ctx, accountID, userID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return am.Store.GetGroupByName(ctx, store.LockingStrengthNone, accountID, groupName)
|
||||
}
|
||||
|
||||
|
||||
@@ -52,7 +52,7 @@ func (h *handler) getAllGroups(w http.ResponseWriter, r *http.Request) {
|
||||
groupName := r.URL.Query().Get("name")
|
||||
if groupName != "" {
|
||||
// Get single group by name
|
||||
group, err := h.accountManager.GetGroupByName(r.Context(), groupName, accountID)
|
||||
group, err := h.accountManager.GetGroupByName(r.Context(), groupName, accountID, userID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@@ -118,7 +118,7 @@ func (h *handler) updateGroup(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
allGroup, err := h.accountManager.GetGroupByName(r.Context(), "All", accountID)
|
||||
allGroup, err := h.accountManager.GetGroupByName(r.Context(), "All", accountID, userID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
|
||||
@@ -71,7 +71,7 @@ func initGroupTestData(initGroups ...*types.Group) *handler {
|
||||
|
||||
return groups, nil
|
||||
},
|
||||
GetGroupByNameFunc: func(ctx context.Context, groupName, _ string) (*types.Group, error) {
|
||||
GetGroupByNameFunc: func(ctx context.Context, groupName, _, _ string) (*types.Group, error) {
|
||||
if groupName == "All" {
|
||||
return &types.Group{ID: "id-all", Name: "All", Issued: types.GroupIssuedAPI}, nil
|
||||
}
|
||||
|
||||
@@ -46,7 +46,7 @@ type MockAccountManager struct {
|
||||
AddPeerFunc func(ctx context.Context, accountID string, setupKey string, userId string, peer *nbpeer.Peer, temporary bool) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error)
|
||||
GetGroupFunc func(ctx context.Context, accountID, groupID, userID string) (*types.Group, error)
|
||||
GetAllGroupsFunc func(ctx context.Context, accountID, userID string) ([]*types.Group, error)
|
||||
GetGroupByNameFunc func(ctx context.Context, accountID, groupName string) (*types.Group, error)
|
||||
GetGroupByNameFunc func(ctx context.Context, groupName, accountID, userID string) (*types.Group, error)
|
||||
SaveGroupFunc func(ctx context.Context, accountID, userID string, group *types.Group, create bool) error
|
||||
SaveGroupsFunc func(ctx context.Context, accountID, userID string, groups []*types.Group, create bool) error
|
||||
DeleteGroupFunc func(ctx context.Context, accountID, userId, groupID string) error
|
||||
@@ -406,9 +406,9 @@ func (am *MockAccountManager) AddPeer(
|
||||
}
|
||||
|
||||
// GetGroupByName mock implementation of GetGroupByName from server.AccountManager interface
|
||||
func (am *MockAccountManager) GetGroupByName(ctx context.Context, accountID, groupName string) (*types.Group, error) {
|
||||
func (am *MockAccountManager) GetGroupByName(ctx context.Context, groupName, accountID, userID string) (*types.Group, error) {
|
||||
if am.GetGroupByNameFunc != nil {
|
||||
return am.GetGroupByNameFunc(ctx, accountID, groupName)
|
||||
return am.GetGroupByNameFunc(ctx, groupName, accountID, userID)
|
||||
}
|
||||
return nil, status.Errorf(codes.Unimplemented, "method GetGroupByName is not implemented")
|
||||
}
|
||||
|
||||
@@ -2080,7 +2080,8 @@ func (s *SqlStore) getPostureChecks(ctx context.Context, accountID string) ([]*p
|
||||
func (s *SqlStore) getServices(ctx context.Context, accountID string) ([]*rpservice.Service, error) {
|
||||
const serviceQuery = `SELECT id, account_id, name, domain, enabled, auth,
|
||||
meta_created_at, meta_certificate_issued_at, meta_status, proxy_cluster,
|
||||
pass_host_header, rewrite_redirects, session_private_key, session_public_key
|
||||
pass_host_header, rewrite_redirects, session_private_key, session_public_key,
|
||||
mode, listen_port, port_auto_assigned, source, source_peer, terminated
|
||||
FROM services WHERE account_id = $1`
|
||||
|
||||
const targetsQuery = `SELECT id, account_id, service_id, path, host, port, protocol,
|
||||
@@ -2097,6 +2098,8 @@ func (s *SqlStore) getServices(ctx context.Context, accountID string) ([]*rpserv
|
||||
var auth []byte
|
||||
var createdAt, certIssuedAt sql.NullTime
|
||||
var status, proxyCluster, sessionPrivateKey, sessionPublicKey sql.NullString
|
||||
var mode, source, sourcePeer sql.NullString
|
||||
var terminated sql.NullBool
|
||||
err := row.Scan(
|
||||
&s.ID,
|
||||
&s.AccountID,
|
||||
@@ -2112,6 +2115,12 @@ func (s *SqlStore) getServices(ctx context.Context, accountID string) ([]*rpserv
|
||||
&s.RewriteRedirects,
|
||||
&sessionPrivateKey,
|
||||
&sessionPublicKey,
|
||||
&mode,
|
||||
&s.ListenPort,
|
||||
&s.PortAutoAssigned,
|
||||
&source,
|
||||
&sourcePeer,
|
||||
&terminated,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -2143,7 +2152,18 @@ func (s *SqlStore) getServices(ctx context.Context, accountID string) ([]*rpserv
|
||||
if sessionPublicKey.Valid {
|
||||
s.SessionPublicKey = sessionPublicKey.String
|
||||
}
|
||||
|
||||
if mode.Valid {
|
||||
s.Mode = mode.String
|
||||
}
|
||||
if source.Valid {
|
||||
s.Source = source.String
|
||||
}
|
||||
if sourcePeer.Valid {
|
||||
s.SourcePeer = sourcePeer.String
|
||||
}
|
||||
if terminated.Valid {
|
||||
s.Terminated = terminated.Bool
|
||||
}
|
||||
s.Targets = []*rpservice.Target{}
|
||||
return &s, nil
|
||||
})
|
||||
|
||||
@@ -121,7 +121,7 @@ type Store interface {
|
||||
GetAccountGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.Group, error)
|
||||
GetResourceGroups(ctx context.Context, lockStrength LockingStrength, accountID, resourceID string) ([]*types.Group, error)
|
||||
GetGroupByID(ctx context.Context, lockStrength LockingStrength, accountID, groupID string) (*types.Group, error)
|
||||
GetGroupByName(ctx context.Context, lockStrength LockingStrength, groupName, accountID string) (*types.Group, error)
|
||||
GetGroupByName(ctx context.Context, lockStrength LockingStrength, accountID, groupName string) (*types.Group, error)
|
||||
GetGroupsByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, groupIDs []string) (map[string]*types.Group, error)
|
||||
CreateGroups(ctx context.Context, accountID string, groups []*types.Group) error
|
||||
UpdateGroups(ctx context.Context, accountID string, groups []*types.Group) error
|
||||
|
||||
@@ -165,34 +165,6 @@ func (mr *MockStoreMockRecorder) CleanupStaleProxies(ctx, inactivityDuration int
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CleanupStaleProxies", reflect.TypeOf((*MockStore)(nil).CleanupStaleProxies), ctx, inactivityDuration)
|
||||
}
|
||||
|
||||
// GetClusterSupportsCustomPorts mocks base method.
|
||||
func (m *MockStore) GetClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetClusterSupportsCustomPorts", ctx, clusterAddr)
|
||||
ret0, _ := ret[0].(*bool)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// GetClusterSupportsCustomPorts indicates an expected call of GetClusterSupportsCustomPorts.
|
||||
func (mr *MockStoreMockRecorder) GetClusterSupportsCustomPorts(ctx, clusterAddr interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetClusterSupportsCustomPorts", reflect.TypeOf((*MockStore)(nil).GetClusterSupportsCustomPorts), ctx, clusterAddr)
|
||||
}
|
||||
|
||||
// GetClusterRequireSubdomain mocks base method.
|
||||
func (m *MockStore) GetClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetClusterRequireSubdomain", ctx, clusterAddr)
|
||||
ret0, _ := ret[0].(*bool)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// GetClusterRequireSubdomain indicates an expected call of GetClusterRequireSubdomain.
|
||||
func (mr *MockStoreMockRecorder) GetClusterRequireSubdomain(ctx, clusterAddr interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetClusterRequireSubdomain", reflect.TypeOf((*MockStore)(nil).GetClusterRequireSubdomain), ctx, clusterAddr)
|
||||
}
|
||||
|
||||
// Close mocks base method.
|
||||
func (m *MockStore) Close(ctx context.Context) error {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -1389,6 +1361,34 @@ func (mr *MockStoreMockRecorder) GetAnyAccountID(ctx interface{}) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAnyAccountID", reflect.TypeOf((*MockStore)(nil).GetAnyAccountID), ctx)
|
||||
}
|
||||
|
||||
// GetClusterRequireSubdomain mocks base method.
|
||||
func (m *MockStore) GetClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetClusterRequireSubdomain", ctx, clusterAddr)
|
||||
ret0, _ := ret[0].(*bool)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// GetClusterRequireSubdomain indicates an expected call of GetClusterRequireSubdomain.
|
||||
func (mr *MockStoreMockRecorder) GetClusterRequireSubdomain(ctx, clusterAddr interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetClusterRequireSubdomain", reflect.TypeOf((*MockStore)(nil).GetClusterRequireSubdomain), ctx, clusterAddr)
|
||||
}
|
||||
|
||||
// GetClusterSupportsCustomPorts mocks base method.
|
||||
func (m *MockStore) GetClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetClusterSupportsCustomPorts", ctx, clusterAddr)
|
||||
ret0, _ := ret[0].(*bool)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// GetClusterSupportsCustomPorts indicates an expected call of GetClusterSupportsCustomPorts.
|
||||
func (mr *MockStoreMockRecorder) GetClusterSupportsCustomPorts(ctx, clusterAddr interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetClusterSupportsCustomPorts", reflect.TypeOf((*MockStore)(nil).GetClusterSupportsCustomPorts), ctx, clusterAddr)
|
||||
}
|
||||
|
||||
// GetCustomDomain mocks base method.
|
||||
func (m *MockStore) GetCustomDomain(ctx context.Context, accountID, domainID string) (*domain.Domain, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -1466,18 +1466,18 @@ func (mr *MockStoreMockRecorder) GetGroupByID(ctx, lockStrength, accountID, grou
|
||||
}
|
||||
|
||||
// GetGroupByName mocks base method.
|
||||
func (m *MockStore) GetGroupByName(ctx context.Context, lockStrength LockingStrength, groupName, accountID string) (*types2.Group, error) {
|
||||
func (m *MockStore) GetGroupByName(ctx context.Context, lockStrength LockingStrength, accountID, groupName string) (*types2.Group, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetGroupByName", ctx, lockStrength, groupName, accountID)
|
||||
ret := m.ctrl.Call(m, "GetGroupByName", ctx, lockStrength, accountID, groupName)
|
||||
ret0, _ := ret[0].(*types2.Group)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetGroupByName indicates an expected call of GetGroupByName.
|
||||
func (mr *MockStoreMockRecorder) GetGroupByName(ctx, lockStrength, groupName, accountID interface{}) *gomock.Call {
|
||||
func (mr *MockStoreMockRecorder) GetGroupByName(ctx, lockStrength, accountID, groupName interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetGroupByName", reflect.TypeOf((*MockStore)(nil).GetGroupByName), ctx, lockStrength, groupName, accountID)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetGroupByName", reflect.TypeOf((*MockStore)(nil).GetGroupByName), ctx, lockStrength, accountID, groupName)
|
||||
}
|
||||
|
||||
// GetGroupsByIDs mocks base method.
|
||||
@@ -1974,6 +1974,21 @@ func (mr *MockStoreMockRecorder) GetRouteByID(ctx, lockStrength, accountID, rout
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetRouteByID", reflect.TypeOf((*MockStore)(nil).GetRouteByID), ctx, lockStrength, accountID, routeID)
|
||||
}
|
||||
|
||||
// GetRoutingPeerNetworks mocks base method.
|
||||
func (m *MockStore) GetRoutingPeerNetworks(ctx context.Context, accountID, peerID string) ([]string, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetRoutingPeerNetworks", ctx, accountID, peerID)
|
||||
ret0, _ := ret[0].([]string)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetRoutingPeerNetworks indicates an expected call of GetRoutingPeerNetworks.
|
||||
func (mr *MockStoreMockRecorder) GetRoutingPeerNetworks(ctx, accountID, peerID interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetRoutingPeerNetworks", reflect.TypeOf((*MockStore)(nil).GetRoutingPeerNetworks), ctx, accountID, peerID)
|
||||
}
|
||||
|
||||
// GetServiceByDomain mocks base method.
|
||||
func (m *MockStore) GetServiceByDomain(ctx context.Context, domain string) (*service.Service, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -2361,21 +2376,6 @@ func (mr *MockStoreMockRecorder) IncrementSetupKeyUsage(ctx, setupKeyID interfac
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IncrementSetupKeyUsage", reflect.TypeOf((*MockStore)(nil).IncrementSetupKeyUsage), ctx, setupKeyID)
|
||||
}
|
||||
|
||||
// GetRoutingPeerNetworks mocks base method.
|
||||
func (m *MockStore) GetRoutingPeerNetworks(ctx context.Context, accountID, peerID string) ([]string, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetRoutingPeerNetworks", ctx, accountID, peerID)
|
||||
ret0, _ := ret[0].([]string)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetRoutingPeerNetworks indicates an expected call of GetRoutingPeerNetworks.
|
||||
func (mr *MockStoreMockRecorder) GetRoutingPeerNetworks(ctx, accountID, peerID interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetRoutingPeerNetworks", reflect.TypeOf((*MockStore)(nil).GetRoutingPeerNetworks), ctx, accountID, peerID)
|
||||
}
|
||||
|
||||
// IsPrimaryAccount mocks base method.
|
||||
func (m *MockStore) IsPrimaryAccount(ctx context.Context, accountID string) (bool, string, error) {
|
||||
m.ctrl.T.Helper()
|
||||
|
||||
217
management/server/types/networkmap_benchmark_test.go
Normal file
217
management/server/types/networkmap_benchmark_test.go
Normal file
@@ -0,0 +1,217 @@
|
||||
package types_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
)
|
||||
|
||||
type benchmarkScale struct {
|
||||
name string
|
||||
peers int
|
||||
groups int
|
||||
}
|
||||
|
||||
var defaultScales = []benchmarkScale{
|
||||
{"100peers_5groups", 100, 5},
|
||||
{"500peers_20groups", 500, 20},
|
||||
{"1000peers_50groups", 1000, 50},
|
||||
{"5000peers_100groups", 5000, 100},
|
||||
{"10000peers_200groups", 10000, 200},
|
||||
{"20000peers_200groups", 20000, 200},
|
||||
{"30000peers_300groups", 30000, 300},
|
||||
}
|
||||
|
||||
func skipCIBenchmark(b *testing.B) {
|
||||
if os.Getenv("CI") == "true" {
|
||||
b.Skip("Skipping benchmark in CI")
|
||||
}
|
||||
}
|
||||
|
||||
// ──────────────────────────────────────────────────────────────────────────────
|
||||
// Single Peer Network Map Generation
|
||||
// ──────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
// BenchmarkNetworkMapGeneration_Components benchmarks the components-based approach for a single peer.
|
||||
func BenchmarkNetworkMapGeneration_Components(b *testing.B) {
|
||||
skipCIBenchmark(b)
|
||||
for _, scale := range defaultScales {
|
||||
b.Run(scale.name, func(b *testing.B) {
|
||||
account, validatedPeers := scalableTestAccount(scale.peers, scale.groups)
|
||||
ctx := context.Background()
|
||||
resourcePolicies := account.GetResourcePoliciesMap()
|
||||
routers := account.GetResourceRoutersMap()
|
||||
groupIDToUserIDs := account.GetActiveGroupUsers()
|
||||
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
for range b.N {
|
||||
_ = account.GetPeerNetworkMapFromComponents(ctx, "peer-0", nbdns.CustomZone{}, nil, validatedPeers, resourcePolicies, routers, nil, groupIDToUserIDs)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ──────────────────────────────────────────────────────────────────────────────
|
||||
// All Peers (UpdateAccountPeers hot path)
|
||||
// ──────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
// BenchmarkNetworkMapGeneration_AllPeers benchmarks generating network maps for ALL peers.
|
||||
func BenchmarkNetworkMapGeneration_AllPeers(b *testing.B) {
|
||||
skipCIBenchmark(b)
|
||||
scales := []benchmarkScale{
|
||||
{"100peers_5groups", 100, 5},
|
||||
{"500peers_20groups", 500, 20},
|
||||
{"1000peers_50groups", 1000, 50},
|
||||
{"5000peers_100groups", 5000, 100},
|
||||
}
|
||||
|
||||
for _, scale := range scales {
|
||||
account, validatedPeers := scalableTestAccount(scale.peers, scale.groups)
|
||||
ctx := context.Background()
|
||||
|
||||
peerIDs := make([]string, 0, len(account.Peers))
|
||||
for peerID := range account.Peers {
|
||||
peerIDs = append(peerIDs, peerID)
|
||||
}
|
||||
|
||||
b.Run("components/"+scale.name, func(b *testing.B) {
|
||||
resourcePolicies := account.GetResourcePoliciesMap()
|
||||
routers := account.GetResourceRoutersMap()
|
||||
groupIDToUserIDs := account.GetActiveGroupUsers()
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
for range b.N {
|
||||
for _, peerID := range peerIDs {
|
||||
_ = account.GetPeerNetworkMapFromComponents(ctx, peerID, nbdns.CustomZone{}, nil, validatedPeers, resourcePolicies, routers, nil, groupIDToUserIDs)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ──────────────────────────────────────────────────────────────────────────────
|
||||
// Sub-operations
|
||||
// ──────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
// BenchmarkNetworkMapGeneration_ComponentsCreation benchmarks components extraction.
|
||||
func BenchmarkNetworkMapGeneration_ComponentsCreation(b *testing.B) {
|
||||
skipCIBenchmark(b)
|
||||
for _, scale := range defaultScales {
|
||||
b.Run(scale.name, func(b *testing.B) {
|
||||
account, validatedPeers := scalableTestAccount(scale.peers, scale.groups)
|
||||
ctx := context.Background()
|
||||
resourcePolicies := account.GetResourcePoliciesMap()
|
||||
routers := account.GetResourceRoutersMap()
|
||||
groupIDToUserIDs := account.GetActiveGroupUsers()
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
for range b.N {
|
||||
_ = account.GetPeerNetworkMapComponents(ctx, "peer-0", nbdns.CustomZone{}, nil, validatedPeers, resourcePolicies, routers, groupIDToUserIDs)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkNetworkMapGeneration_ComponentsCalculation benchmarks calculation from pre-built components.
|
||||
func BenchmarkNetworkMapGeneration_ComponentsCalculation(b *testing.B) {
|
||||
skipCIBenchmark(b)
|
||||
for _, scale := range defaultScales {
|
||||
b.Run(scale.name, func(b *testing.B) {
|
||||
account, validatedPeers := scalableTestAccount(scale.peers, scale.groups)
|
||||
ctx := context.Background()
|
||||
resourcePolicies := account.GetResourcePoliciesMap()
|
||||
routers := account.GetResourceRoutersMap()
|
||||
groupIDToUserIDs := account.GetActiveGroupUsers()
|
||||
components := account.GetPeerNetworkMapComponents(ctx, "peer-0", nbdns.CustomZone{}, nil, validatedPeers, resourcePolicies, routers, groupIDToUserIDs)
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
for range b.N {
|
||||
_ = types.CalculateNetworkMapFromComponents(ctx, components)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkNetworkMapGeneration_PrecomputeMaps benchmarks precomputed map costs.
|
||||
func BenchmarkNetworkMapGeneration_PrecomputeMaps(b *testing.B) {
|
||||
skipCIBenchmark(b)
|
||||
for _, scale := range defaultScales {
|
||||
b.Run("ResourcePoliciesMap/"+scale.name, func(b *testing.B) {
|
||||
account, _ := scalableTestAccount(scale.peers, scale.groups)
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
for range b.N {
|
||||
_ = account.GetResourcePoliciesMap()
|
||||
}
|
||||
})
|
||||
b.Run("ResourceRoutersMap/"+scale.name, func(b *testing.B) {
|
||||
account, _ := scalableTestAccount(scale.peers, scale.groups)
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
for range b.N {
|
||||
_ = account.GetResourceRoutersMap()
|
||||
}
|
||||
})
|
||||
b.Run("ActiveGroupUsers/"+scale.name, func(b *testing.B) {
|
||||
account, _ := scalableTestAccount(scale.peers, scale.groups)
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
for range b.N {
|
||||
_ = account.GetActiveGroupUsers()
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ──────────────────────────────────────────────────────────────────────────────
|
||||
// Scaling Analysis
|
||||
// ──────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
// BenchmarkNetworkMapGeneration_GroupScaling tests group count impact on performance.
|
||||
func BenchmarkNetworkMapGeneration_GroupScaling(b *testing.B) {
|
||||
skipCIBenchmark(b)
|
||||
groupCounts := []int{1, 5, 20, 50, 100, 200, 500}
|
||||
for _, numGroups := range groupCounts {
|
||||
b.Run(fmt.Sprintf("components_%dgroups", numGroups), func(b *testing.B) {
|
||||
account, validatedPeers := scalableTestAccount(1000, numGroups)
|
||||
ctx := context.Background()
|
||||
resourcePolicies := account.GetResourcePoliciesMap()
|
||||
routers := account.GetResourceRoutersMap()
|
||||
groupIDToUserIDs := account.GetActiveGroupUsers()
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
for range b.N {
|
||||
_ = account.GetPeerNetworkMapFromComponents(ctx, "peer-0", nbdns.CustomZone{}, nil, validatedPeers, resourcePolicies, routers, nil, groupIDToUserIDs)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkNetworkMapGeneration_PeerScaling tests peer count impact on performance.
|
||||
func BenchmarkNetworkMapGeneration_PeerScaling(b *testing.B) {
|
||||
skipCIBenchmark(b)
|
||||
peerCounts := []int{50, 100, 500, 1000, 2000, 5000, 10000, 20000, 30000}
|
||||
for _, numPeers := range peerCounts {
|
||||
numGroups := numPeers / 20
|
||||
if numGroups < 1 {
|
||||
numGroups = 1
|
||||
}
|
||||
b.Run(fmt.Sprintf("components_%dpeers", numPeers), func(b *testing.B) {
|
||||
account, validatedPeers := scalableTestAccount(numPeers, numGroups)
|
||||
ctx := context.Background()
|
||||
resourcePolicies := account.GetResourcePoliciesMap()
|
||||
routers := account.GetResourceRoutersMap()
|
||||
groupIDToUserIDs := account.GetActiveGroupUsers()
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
for range b.N {
|
||||
_ = account.GetPeerNetworkMapFromComponents(ctx, "peer-0", nbdns.CustomZone{}, nil, validatedPeers, resourcePolicies, routers, nil, groupIDToUserIDs)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
1192
management/server/types/networkmap_components_correctness_test.go
Normal file
1192
management/server/types/networkmap_components_correctness_test.go
Normal file
File diff suppressed because it is too large
Load Diff
260
proxy/web/package-lock.json
generated
260
proxy/web/package-lock.json
generated
@@ -15,7 +15,7 @@
|
||||
"tailwind-merge": "^2.6.0"
|
||||
},
|
||||
"devDependencies": {
|
||||
"@eslint/js": "^9.39.1",
|
||||
"@eslint/js": "9.39.2",
|
||||
"@tailwindcss/vite": "^4.1.18",
|
||||
"@types/node": "^24.10.1",
|
||||
"@types/react": "^19.2.5",
|
||||
@@ -29,7 +29,7 @@
|
||||
"tsx": "^4.21.0",
|
||||
"typescript": "~5.9.3",
|
||||
"typescript-eslint": "^8.46.4",
|
||||
"vite": "^7.2.4"
|
||||
"vite": "7.3.2"
|
||||
}
|
||||
},
|
||||
"node_modules/@babel/code-frame": {
|
||||
@@ -1024,9 +1024,9 @@
|
||||
"license": "MIT"
|
||||
},
|
||||
"node_modules/@rollup/rollup-android-arm-eabi": {
|
||||
"version": "4.57.1",
|
||||
"resolved": "https://registry.npmjs.org/@rollup/rollup-android-arm-eabi/-/rollup-android-arm-eabi-4.57.1.tgz",
|
||||
"integrity": "sha512-A6ehUVSiSaaliTxai040ZpZ2zTevHYbvu/lDoeAteHI8QnaosIzm4qwtezfRg1jOYaUmnzLX1AOD6Z+UJjtifg==",
|
||||
"version": "4.60.0",
|
||||
"resolved": "https://registry.npmjs.org/@rollup/rollup-android-arm-eabi/-/rollup-android-arm-eabi-4.60.0.tgz",
|
||||
"integrity": "sha512-WOhNW9K8bR3kf4zLxbfg6Pxu2ybOUbB2AjMDHSQx86LIF4rH4Ft7vmMwNt0loO0eonglSNy4cpD3MKXXKQu0/A==",
|
||||
"cpu": [
|
||||
"arm"
|
||||
],
|
||||
@@ -1038,9 +1038,9 @@
|
||||
]
|
||||
},
|
||||
"node_modules/@rollup/rollup-android-arm64": {
|
||||
"version": "4.57.1",
|
||||
"resolved": "https://registry.npmjs.org/@rollup/rollup-android-arm64/-/rollup-android-arm64-4.57.1.tgz",
|
||||
"integrity": "sha512-dQaAddCY9YgkFHZcFNS/606Exo8vcLHwArFZ7vxXq4rigo2bb494/xKMMwRRQW6ug7Js6yXmBZhSBRuBvCCQ3w==",
|
||||
"version": "4.60.0",
|
||||
"resolved": "https://registry.npmjs.org/@rollup/rollup-android-arm64/-/rollup-android-arm64-4.60.0.tgz",
|
||||
"integrity": "sha512-u6JHLll5QKRvjciE78bQXDmqRqNs5M/3GVqZeMwvmjaNODJih/WIrJlFVEihvV0MiYFmd+ZyPr9wxOVbPAG2Iw==",
|
||||
"cpu": [
|
||||
"arm64"
|
||||
],
|
||||
@@ -1052,9 +1052,9 @@
|
||||
]
|
||||
},
|
||||
"node_modules/@rollup/rollup-darwin-arm64": {
|
||||
"version": "4.57.1",
|
||||
"resolved": "https://registry.npmjs.org/@rollup/rollup-darwin-arm64/-/rollup-darwin-arm64-4.57.1.tgz",
|
||||
"integrity": "sha512-crNPrwJOrRxagUYeMn/DZwqN88SDmwaJ8Cvi/TN1HnWBU7GwknckyosC2gd0IqYRsHDEnXf328o9/HC6OkPgOg==",
|
||||
"version": "4.60.0",
|
||||
"resolved": "https://registry.npmjs.org/@rollup/rollup-darwin-arm64/-/rollup-darwin-arm64-4.60.0.tgz",
|
||||
"integrity": "sha512-qEF7CsKKzSRc20Ciu2Zw1wRrBz4g56F7r/vRwY430UPp/nt1x21Q/fpJ9N5l47WWvJlkNCPJz3QRVw008fi7yA==",
|
||||
"cpu": [
|
||||
"arm64"
|
||||
],
|
||||
@@ -1066,9 +1066,9 @@
|
||||
]
|
||||
},
|
||||
"node_modules/@rollup/rollup-darwin-x64": {
|
||||
"version": "4.57.1",
|
||||
"resolved": "https://registry.npmjs.org/@rollup/rollup-darwin-x64/-/rollup-darwin-x64-4.57.1.tgz",
|
||||
"integrity": "sha512-Ji8g8ChVbKrhFtig5QBV7iMaJrGtpHelkB3lsaKzadFBe58gmjfGXAOfI5FV0lYMH8wiqsxKQ1C9B0YTRXVy4w==",
|
||||
"version": "4.60.0",
|
||||
"resolved": "https://registry.npmjs.org/@rollup/rollup-darwin-x64/-/rollup-darwin-x64-4.60.0.tgz",
|
||||
"integrity": "sha512-WADYozJ4QCnXCH4wPB+3FuGmDPoFseVCUrANmA5LWwGmC6FL14BWC7pcq+FstOZv3baGX65tZ378uT6WG8ynTw==",
|
||||
"cpu": [
|
||||
"x64"
|
||||
],
|
||||
@@ -1080,9 +1080,9 @@
|
||||
]
|
||||
},
|
||||
"node_modules/@rollup/rollup-freebsd-arm64": {
|
||||
"version": "4.57.1",
|
||||
"resolved": "https://registry.npmjs.org/@rollup/rollup-freebsd-arm64/-/rollup-freebsd-arm64-4.57.1.tgz",
|
||||
"integrity": "sha512-R+/WwhsjmwodAcz65guCGFRkMb4gKWTcIeLy60JJQbXrJ97BOXHxnkPFrP+YwFlaS0m+uWJTstrUA9o+UchFug==",
|
||||
"version": "4.60.0",
|
||||
"resolved": "https://registry.npmjs.org/@rollup/rollup-freebsd-arm64/-/rollup-freebsd-arm64-4.60.0.tgz",
|
||||
"integrity": "sha512-6b8wGHJlDrGeSE3aH5mGNHBjA0TTkxdoNHik5EkvPHCt351XnigA4pS7Wsj/Eo9Y8RBU6f35cjN9SYmCFBtzxw==",
|
||||
"cpu": [
|
||||
"arm64"
|
||||
],
|
||||
@@ -1094,9 +1094,9 @@
|
||||
]
|
||||
},
|
||||
"node_modules/@rollup/rollup-freebsd-x64": {
|
||||
"version": "4.57.1",
|
||||
"resolved": "https://registry.npmjs.org/@rollup/rollup-freebsd-x64/-/rollup-freebsd-x64-4.57.1.tgz",
|
||||
"integrity": "sha512-IEQTCHeiTOnAUC3IDQdzRAGj3jOAYNr9kBguI7MQAAZK3caezRrg0GxAb6Hchg4lxdZEI5Oq3iov/w/hnFWY9Q==",
|
||||
"version": "4.60.0",
|
||||
"resolved": "https://registry.npmjs.org/@rollup/rollup-freebsd-x64/-/rollup-freebsd-x64-4.60.0.tgz",
|
||||
"integrity": "sha512-h25Ga0t4jaylMB8M/JKAyrvvfxGRjnPQIR8lnCayyzEjEOx2EJIlIiMbhpWxDRKGKF8jbNH01NnN663dH638mA==",
|
||||
"cpu": [
|
||||
"x64"
|
||||
],
|
||||
@@ -1108,9 +1108,9 @@
|
||||
]
|
||||
},
|
||||
"node_modules/@rollup/rollup-linux-arm-gnueabihf": {
|
||||
"version": "4.57.1",
|
||||
"resolved": "https://registry.npmjs.org/@rollup/rollup-linux-arm-gnueabihf/-/rollup-linux-arm-gnueabihf-4.57.1.tgz",
|
||||
"integrity": "sha512-F8sWbhZ7tyuEfsmOxwc2giKDQzN3+kuBLPwwZGyVkLlKGdV1nvnNwYD0fKQ8+XS6hp9nY7B+ZeK01EBUE7aHaw==",
|
||||
"version": "4.60.0",
|
||||
"resolved": "https://registry.npmjs.org/@rollup/rollup-linux-arm-gnueabihf/-/rollup-linux-arm-gnueabihf-4.60.0.tgz",
|
||||
"integrity": "sha512-RzeBwv0B3qtVBWtcuABtSuCzToo2IEAIQrcyB/b2zMvBWVbjo8bZDjACUpnaafaxhTw2W+imQbP2BD1usasK4g==",
|
||||
"cpu": [
|
||||
"arm"
|
||||
],
|
||||
@@ -1122,9 +1122,9 @@
|
||||
]
|
||||
},
|
||||
"node_modules/@rollup/rollup-linux-arm-musleabihf": {
|
||||
"version": "4.57.1",
|
||||
"resolved": "https://registry.npmjs.org/@rollup/rollup-linux-arm-musleabihf/-/rollup-linux-arm-musleabihf-4.57.1.tgz",
|
||||
"integrity": "sha512-rGfNUfn0GIeXtBP1wL5MnzSj98+PZe/AXaGBCRmT0ts80lU5CATYGxXukeTX39XBKsxzFpEeK+Mrp9faXOlmrw==",
|
||||
"version": "4.60.0",
|
||||
"resolved": "https://registry.npmjs.org/@rollup/rollup-linux-arm-musleabihf/-/rollup-linux-arm-musleabihf-4.60.0.tgz",
|
||||
"integrity": "sha512-Sf7zusNI2CIU1HLzuu9Tc5YGAHEZs5Lu7N1ssJG4Tkw6e0MEsN7NdjUDDfGNHy2IU+ENyWT+L2obgWiguWibWQ==",
|
||||
"cpu": [
|
||||
"arm"
|
||||
],
|
||||
@@ -1136,9 +1136,9 @@
|
||||
]
|
||||
},
|
||||
"node_modules/@rollup/rollup-linux-arm64-gnu": {
|
||||
"version": "4.57.1",
|
||||
"resolved": "https://registry.npmjs.org/@rollup/rollup-linux-arm64-gnu/-/rollup-linux-arm64-gnu-4.57.1.tgz",
|
||||
"integrity": "sha512-MMtej3YHWeg/0klK2Qodf3yrNzz6CGjo2UntLvk2RSPlhzgLvYEB3frRvbEF2wRKh1Z2fDIg9KRPe1fawv7C+g==",
|
||||
"version": "4.60.0",
|
||||
"resolved": "https://registry.npmjs.org/@rollup/rollup-linux-arm64-gnu/-/rollup-linux-arm64-gnu-4.60.0.tgz",
|
||||
"integrity": "sha512-DX2x7CMcrJzsE91q7/O02IJQ5/aLkVtYFryqCjduJhUfGKG6yJV8hxaw8pZa93lLEpPTP/ohdN4wFz7yp/ry9A==",
|
||||
"cpu": [
|
||||
"arm64"
|
||||
],
|
||||
@@ -1150,9 +1150,9 @@
|
||||
]
|
||||
},
|
||||
"node_modules/@rollup/rollup-linux-arm64-musl": {
|
||||
"version": "4.57.1",
|
||||
"resolved": "https://registry.npmjs.org/@rollup/rollup-linux-arm64-musl/-/rollup-linux-arm64-musl-4.57.1.tgz",
|
||||
"integrity": "sha512-1a/qhaaOXhqXGpMFMET9VqwZakkljWHLmZOX48R0I/YLbhdxr1m4gtG1Hq7++VhVUmf+L3sTAf9op4JlhQ5u1Q==",
|
||||
"version": "4.60.0",
|
||||
"resolved": "https://registry.npmjs.org/@rollup/rollup-linux-arm64-musl/-/rollup-linux-arm64-musl-4.60.0.tgz",
|
||||
"integrity": "sha512-09EL+yFVbJZlhcQfShpswwRZ0Rg+z/CsSELFCnPt3iK+iqwGsI4zht3secj5vLEs957QvFFXnzAT0FFPIxSrkQ==",
|
||||
"cpu": [
|
||||
"arm64"
|
||||
],
|
||||
@@ -1164,9 +1164,9 @@
|
||||
]
|
||||
},
|
||||
"node_modules/@rollup/rollup-linux-loong64-gnu": {
|
||||
"version": "4.57.1",
|
||||
"resolved": "https://registry.npmjs.org/@rollup/rollup-linux-loong64-gnu/-/rollup-linux-loong64-gnu-4.57.1.tgz",
|
||||
"integrity": "sha512-QWO6RQTZ/cqYtJMtxhkRkidoNGXc7ERPbZN7dVW5SdURuLeVU7lwKMpo18XdcmpWYd0qsP1bwKPf7DNSUinhvA==",
|
||||
"version": "4.60.0",
|
||||
"resolved": "https://registry.npmjs.org/@rollup/rollup-linux-loong64-gnu/-/rollup-linux-loong64-gnu-4.60.0.tgz",
|
||||
"integrity": "sha512-i9IcCMPr3EXm8EQg5jnja0Zyc1iFxJjZWlb4wr7U2Wx/GrddOuEafxRdMPRYVaXjgbhvqalp6np07hN1w9kAKw==",
|
||||
"cpu": [
|
||||
"loong64"
|
||||
],
|
||||
@@ -1178,9 +1178,9 @@
|
||||
]
|
||||
},
|
||||
"node_modules/@rollup/rollup-linux-loong64-musl": {
|
||||
"version": "4.57.1",
|
||||
"resolved": "https://registry.npmjs.org/@rollup/rollup-linux-loong64-musl/-/rollup-linux-loong64-musl-4.57.1.tgz",
|
||||
"integrity": "sha512-xpObYIf+8gprgWaPP32xiN5RVTi/s5FCR+XMXSKmhfoJjrpRAjCuuqQXyxUa/eJTdAE6eJ+KDKaoEqjZQxh3Gw==",
|
||||
"version": "4.60.0",
|
||||
"resolved": "https://registry.npmjs.org/@rollup/rollup-linux-loong64-musl/-/rollup-linux-loong64-musl-4.60.0.tgz",
|
||||
"integrity": "sha512-DGzdJK9kyJ+B78MCkWeGnpXJ91tK/iKA6HwHxF4TAlPIY7GXEvMe8hBFRgdrR9Ly4qebR/7gfUs9y2IoaVEyog==",
|
||||
"cpu": [
|
||||
"loong64"
|
||||
],
|
||||
@@ -1192,9 +1192,9 @@
|
||||
]
|
||||
},
|
||||
"node_modules/@rollup/rollup-linux-ppc64-gnu": {
|
||||
"version": "4.57.1",
|
||||
"resolved": "https://registry.npmjs.org/@rollup/rollup-linux-ppc64-gnu/-/rollup-linux-ppc64-gnu-4.57.1.tgz",
|
||||
"integrity": "sha512-4BrCgrpZo4hvzMDKRqEaW1zeecScDCR+2nZ86ATLhAoJ5FQ+lbHVD3ttKe74/c7tNT9c6F2viwB3ufwp01Oh2w==",
|
||||
"version": "4.60.0",
|
||||
"resolved": "https://registry.npmjs.org/@rollup/rollup-linux-ppc64-gnu/-/rollup-linux-ppc64-gnu-4.60.0.tgz",
|
||||
"integrity": "sha512-RwpnLsqC8qbS8z1H1AxBA1H6qknR4YpPR9w2XX0vo2Sz10miu57PkNcnHVaZkbqyw/kUWfKMI73jhmfi9BRMUQ==",
|
||||
"cpu": [
|
||||
"ppc64"
|
||||
],
|
||||
@@ -1206,9 +1206,9 @@
|
||||
]
|
||||
},
|
||||
"node_modules/@rollup/rollup-linux-ppc64-musl": {
|
||||
"version": "4.57.1",
|
||||
"resolved": "https://registry.npmjs.org/@rollup/rollup-linux-ppc64-musl/-/rollup-linux-ppc64-musl-4.57.1.tgz",
|
||||
"integrity": "sha512-NOlUuzesGauESAyEYFSe3QTUguL+lvrN1HtwEEsU2rOwdUDeTMJdO5dUYl/2hKf9jWydJrO9OL/XSSf65R5+Xw==",
|
||||
"version": "4.60.0",
|
||||
"resolved": "https://registry.npmjs.org/@rollup/rollup-linux-ppc64-musl/-/rollup-linux-ppc64-musl-4.60.0.tgz",
|
||||
"integrity": "sha512-Z8pPf54Ly3aqtdWC3G4rFigZgNvd+qJlOE52fmko3KST9SoGfAdSRCwyoyG05q1HrrAblLbk1/PSIV+80/pxLg==",
|
||||
"cpu": [
|
||||
"ppc64"
|
||||
],
|
||||
@@ -1220,9 +1220,9 @@
|
||||
]
|
||||
},
|
||||
"node_modules/@rollup/rollup-linux-riscv64-gnu": {
|
||||
"version": "4.57.1",
|
||||
"resolved": "https://registry.npmjs.org/@rollup/rollup-linux-riscv64-gnu/-/rollup-linux-riscv64-gnu-4.57.1.tgz",
|
||||
"integrity": "sha512-ptA88htVp0AwUUqhVghwDIKlvJMD/fmL/wrQj99PRHFRAG6Z5nbWoWG4o81Nt9FT+IuqUQi+L31ZKAFeJ5Is+A==",
|
||||
"version": "4.60.0",
|
||||
"resolved": "https://registry.npmjs.org/@rollup/rollup-linux-riscv64-gnu/-/rollup-linux-riscv64-gnu-4.60.0.tgz",
|
||||
"integrity": "sha512-3a3qQustp3COCGvnP4SvrMHnPQ9d1vzCakQVRTliaz8cIp/wULGjiGpbcqrkv0WrHTEp8bQD/B3HBjzujVWLOA==",
|
||||
"cpu": [
|
||||
"riscv64"
|
||||
],
|
||||
@@ -1234,9 +1234,9 @@
|
||||
]
|
||||
},
|
||||
"node_modules/@rollup/rollup-linux-riscv64-musl": {
|
||||
"version": "4.57.1",
|
||||
"resolved": "https://registry.npmjs.org/@rollup/rollup-linux-riscv64-musl/-/rollup-linux-riscv64-musl-4.57.1.tgz",
|
||||
"integrity": "sha512-S51t7aMMTNdmAMPpBg7OOsTdn4tySRQvklmL3RpDRyknk87+Sp3xaumlatU+ppQ+5raY7sSTcC2beGgvhENfuw==",
|
||||
"version": "4.60.0",
|
||||
"resolved": "https://registry.npmjs.org/@rollup/rollup-linux-riscv64-musl/-/rollup-linux-riscv64-musl-4.60.0.tgz",
|
||||
"integrity": "sha512-pjZDsVH/1VsghMJ2/kAaxt6dL0psT6ZexQVrijczOf+PeP2BUqTHYejk3l6TlPRydggINOeNRhvpLa0AYpCWSQ==",
|
||||
"cpu": [
|
||||
"riscv64"
|
||||
],
|
||||
@@ -1248,9 +1248,9 @@
|
||||
]
|
||||
},
|
||||
"node_modules/@rollup/rollup-linux-s390x-gnu": {
|
||||
"version": "4.57.1",
|
||||
"resolved": "https://registry.npmjs.org/@rollup/rollup-linux-s390x-gnu/-/rollup-linux-s390x-gnu-4.57.1.tgz",
|
||||
"integrity": "sha512-Bl00OFnVFkL82FHbEqy3k5CUCKH6OEJL54KCyx2oqsmZnFTR8IoNqBF+mjQVcRCT5sB6yOvK8A37LNm/kPJiZg==",
|
||||
"version": "4.60.0",
|
||||
"resolved": "https://registry.npmjs.org/@rollup/rollup-linux-s390x-gnu/-/rollup-linux-s390x-gnu-4.60.0.tgz",
|
||||
"integrity": "sha512-3ObQs0BhvPgiUVZrN7gqCSvmFuMWvWvsjG5ayJ3Lraqv+2KhOsp+pUbigqbeWqueGIsnn+09HBw27rJ+gYK4VQ==",
|
||||
"cpu": [
|
||||
"s390x"
|
||||
],
|
||||
@@ -1262,9 +1262,9 @@
|
||||
]
|
||||
},
|
||||
"node_modules/@rollup/rollup-linux-x64-gnu": {
|
||||
"version": "4.57.1",
|
||||
"resolved": "https://registry.npmjs.org/@rollup/rollup-linux-x64-gnu/-/rollup-linux-x64-gnu-4.57.1.tgz",
|
||||
"integrity": "sha512-ABca4ceT4N+Tv/GtotnWAeXZUZuM/9AQyCyKYyKnpk4yoA7QIAuBt6Hkgpw8kActYlew2mvckXkvx0FfoInnLg==",
|
||||
"version": "4.60.0",
|
||||
"resolved": "https://registry.npmjs.org/@rollup/rollup-linux-x64-gnu/-/rollup-linux-x64-gnu-4.60.0.tgz",
|
||||
"integrity": "sha512-EtylprDtQPdS5rXvAayrNDYoJhIz1/vzN2fEubo3yLE7tfAw+948dO0g4M0vkTVFhKojnF+n6C8bDNe+gDRdTg==",
|
||||
"cpu": [
|
||||
"x64"
|
||||
],
|
||||
@@ -1276,9 +1276,9 @@
|
||||
]
|
||||
},
|
||||
"node_modules/@rollup/rollup-linux-x64-musl": {
|
||||
"version": "4.57.1",
|
||||
"resolved": "https://registry.npmjs.org/@rollup/rollup-linux-x64-musl/-/rollup-linux-x64-musl-4.57.1.tgz",
|
||||
"integrity": "sha512-HFps0JeGtuOR2convgRRkHCekD7j+gdAuXM+/i6kGzQtFhlCtQkpwtNzkNj6QhCDp7DRJ7+qC/1Vg2jt5iSOFw==",
|
||||
"version": "4.60.0",
|
||||
"resolved": "https://registry.npmjs.org/@rollup/rollup-linux-x64-musl/-/rollup-linux-x64-musl-4.60.0.tgz",
|
||||
"integrity": "sha512-k09oiRCi/bHU9UVFqD17r3eJR9bn03TyKraCrlz5ULFJGdJGi7VOmm9jl44vOJvRJ6P7WuBi/s2A97LxxHGIdw==",
|
||||
"cpu": [
|
||||
"x64"
|
||||
],
|
||||
@@ -1290,9 +1290,9 @@
|
||||
]
|
||||
},
|
||||
"node_modules/@rollup/rollup-openbsd-x64": {
|
||||
"version": "4.57.1",
|
||||
"resolved": "https://registry.npmjs.org/@rollup/rollup-openbsd-x64/-/rollup-openbsd-x64-4.57.1.tgz",
|
||||
"integrity": "sha512-H+hXEv9gdVQuDTgnqD+SQffoWoc0Of59AStSzTEj/feWTBAnSfSD3+Dql1ZruJQxmykT/JVY0dE8Ka7z0DH1hw==",
|
||||
"version": "4.60.0",
|
||||
"resolved": "https://registry.npmjs.org/@rollup/rollup-openbsd-x64/-/rollup-openbsd-x64-4.60.0.tgz",
|
||||
"integrity": "sha512-1o/0/pIhozoSaDJoDcec+IVLbnRtQmHwPV730+AOD29lHEEo4F5BEUB24H0OBdhbBBDwIOSuf7vgg0Ywxdfiiw==",
|
||||
"cpu": [
|
||||
"x64"
|
||||
],
|
||||
@@ -1304,9 +1304,9 @@
|
||||
]
|
||||
},
|
||||
"node_modules/@rollup/rollup-openharmony-arm64": {
|
||||
"version": "4.57.1",
|
||||
"resolved": "https://registry.npmjs.org/@rollup/rollup-openharmony-arm64/-/rollup-openharmony-arm64-4.57.1.tgz",
|
||||
"integrity": "sha512-4wYoDpNg6o/oPximyc/NG+mYUejZrCU2q+2w6YZqrAs2UcNUChIZXjtafAiiZSUc7On8v5NyNj34Kzj/Ltk6dQ==",
|
||||
"version": "4.60.0",
|
||||
"resolved": "https://registry.npmjs.org/@rollup/rollup-openharmony-arm64/-/rollup-openharmony-arm64-4.60.0.tgz",
|
||||
"integrity": "sha512-pESDkos/PDzYwtyzB5p/UoNU/8fJo68vcXM9ZW2V0kjYayj1KaaUfi1NmTUTUpMn4UhU4gTuK8gIaFO4UGuMbA==",
|
||||
"cpu": [
|
||||
"arm64"
|
||||
],
|
||||
@@ -1318,9 +1318,9 @@
|
||||
]
|
||||
},
|
||||
"node_modules/@rollup/rollup-win32-arm64-msvc": {
|
||||
"version": "4.57.1",
|
||||
"resolved": "https://registry.npmjs.org/@rollup/rollup-win32-arm64-msvc/-/rollup-win32-arm64-msvc-4.57.1.tgz",
|
||||
"integrity": "sha512-O54mtsV/6LW3P8qdTcamQmuC990HDfR71lo44oZMZlXU4tzLrbvTii87Ni9opq60ds0YzuAlEr/GNwuNluZyMQ==",
|
||||
"version": "4.60.0",
|
||||
"resolved": "https://registry.npmjs.org/@rollup/rollup-win32-arm64-msvc/-/rollup-win32-arm64-msvc-4.60.0.tgz",
|
||||
"integrity": "sha512-hj1wFStD7B1YBeYmvY+lWXZ7ey73YGPcViMShYikqKT1GtstIKQAtfUI6yrzPjAy/O7pO0VLXGmUVWXQMaYgTQ==",
|
||||
"cpu": [
|
||||
"arm64"
|
||||
],
|
||||
@@ -1332,9 +1332,9 @@
|
||||
]
|
||||
},
|
||||
"node_modules/@rollup/rollup-win32-ia32-msvc": {
|
||||
"version": "4.57.1",
|
||||
"resolved": "https://registry.npmjs.org/@rollup/rollup-win32-ia32-msvc/-/rollup-win32-ia32-msvc-4.57.1.tgz",
|
||||
"integrity": "sha512-P3dLS+IerxCT/7D2q2FYcRdWRl22dNbrbBEtxdWhXrfIMPP9lQhb5h4Du04mdl5Woq05jVCDPCMF7Ub0NAjIew==",
|
||||
"version": "4.60.0",
|
||||
"resolved": "https://registry.npmjs.org/@rollup/rollup-win32-ia32-msvc/-/rollup-win32-ia32-msvc-4.60.0.tgz",
|
||||
"integrity": "sha512-SyaIPFoxmUPlNDq5EHkTbiKzmSEmq/gOYFI/3HHJ8iS/v1mbugVa7dXUzcJGQfoytp9DJFLhHH4U3/eTy2Bq4w==",
|
||||
"cpu": [
|
||||
"ia32"
|
||||
],
|
||||
@@ -1346,9 +1346,9 @@
|
||||
]
|
||||
},
|
||||
"node_modules/@rollup/rollup-win32-x64-gnu": {
|
||||
"version": "4.57.1",
|
||||
"resolved": "https://registry.npmjs.org/@rollup/rollup-win32-x64-gnu/-/rollup-win32-x64-gnu-4.57.1.tgz",
|
||||
"integrity": "sha512-VMBH2eOOaKGtIJYleXsi2B8CPVADrh+TyNxJ4mWPnKfLB/DBUmzW+5m1xUrcwWoMfSLagIRpjUFeW5CO5hyciQ==",
|
||||
"version": "4.60.0",
|
||||
"resolved": "https://registry.npmjs.org/@rollup/rollup-win32-x64-gnu/-/rollup-win32-x64-gnu-4.60.0.tgz",
|
||||
"integrity": "sha512-RdcryEfzZr+lAr5kRm2ucN9aVlCCa2QNq4hXelZxb8GG0NJSazq44Z3PCCc8wISRuCVnGs0lQJVX5Vp6fKA+IA==",
|
||||
"cpu": [
|
||||
"x64"
|
||||
],
|
||||
@@ -1360,9 +1360,9 @@
|
||||
]
|
||||
},
|
||||
"node_modules/@rollup/rollup-win32-x64-msvc": {
|
||||
"version": "4.57.1",
|
||||
"resolved": "https://registry.npmjs.org/@rollup/rollup-win32-x64-msvc/-/rollup-win32-x64-msvc-4.57.1.tgz",
|
||||
"integrity": "sha512-mxRFDdHIWRxg3UfIIAwCm6NzvxG0jDX/wBN6KsQFTvKFqqg9vTrWUE68qEjHt19A5wwx5X5aUi2zuZT7YR0jrA==",
|
||||
"version": "4.60.0",
|
||||
"resolved": "https://registry.npmjs.org/@rollup/rollup-win32-x64-msvc/-/rollup-win32-x64-msvc-4.60.0.tgz",
|
||||
"integrity": "sha512-PrsWNQ8BuE00O3Xsx3ALh2Df8fAj9+cvvX9AIA6o4KpATR98c9mud4XtDWVvsEuyia5U4tVSTKygawyJkjm60w==",
|
||||
"cpu": [
|
||||
"x64"
|
||||
],
|
||||
@@ -1926,9 +1926,9 @@
|
||||
}
|
||||
},
|
||||
"node_modules/@typescript-eslint/typescript-estree/node_modules/brace-expansion": {
|
||||
"version": "2.0.2",
|
||||
"resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-2.0.2.tgz",
|
||||
"integrity": "sha512-Jt0vHyM+jmUBqojB7E1NIYadt0vI0Qxjxd2TErW94wDz+E2LAm5vKMXXwg6ZZBTHPuUlDgQHKXvjGBdfcF1ZDQ==",
|
||||
"version": "2.0.3",
|
||||
"resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-2.0.3.tgz",
|
||||
"integrity": "sha512-MCV/fYJEbqx68aE58kv2cA/kiky1G8vux3OR6/jbS+jIMe/6fJWa0DTzJU7dqijOWYwHi1t29FlfYI9uytqlpA==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
@@ -1936,13 +1936,13 @@
|
||||
}
|
||||
},
|
||||
"node_modules/@typescript-eslint/typescript-estree/node_modules/minimatch": {
|
||||
"version": "9.0.5",
|
||||
"resolved": "https://registry.npmjs.org/minimatch/-/minimatch-9.0.5.tgz",
|
||||
"integrity": "sha512-G6T0ZX48xgozx7587koeX9Ys2NYy6Gmv//P89sEte9V9whIapMNF4idKxnW2QtCcLiTWlb/wfCabAtAFWhhBow==",
|
||||
"version": "9.0.9",
|
||||
"resolved": "https://registry.npmjs.org/minimatch/-/minimatch-9.0.9.tgz",
|
||||
"integrity": "sha512-OBwBN9AL4dqmETlpS2zasx+vTeWclWzkblfZk7KTA5j3jeOONz/tRCnZomUyvNg83wL5Zv9Ss6HMJXAgL8R2Yg==",
|
||||
"dev": true,
|
||||
"license": "ISC",
|
||||
"dependencies": {
|
||||
"brace-expansion": "^2.0.1"
|
||||
"brace-expansion": "^2.0.2"
|
||||
},
|
||||
"engines": {
|
||||
"node": ">=16 || 14 >=14.17"
|
||||
@@ -2052,9 +2052,9 @@
|
||||
}
|
||||
},
|
||||
"node_modules/ajv": {
|
||||
"version": "6.12.6",
|
||||
"resolved": "https://registry.npmjs.org/ajv/-/ajv-6.12.6.tgz",
|
||||
"integrity": "sha512-j3fVLgvTo527anyYyJOGTYJbG+vnnQYvE0m5mmkc1TK+nxAppkCLMIL0aZ4dblVCNoGShhm+kzE4ZUykBoMg4g==",
|
||||
"version": "6.14.0",
|
||||
"resolved": "https://registry.npmjs.org/ajv/-/ajv-6.14.0.tgz",
|
||||
"integrity": "sha512-IWrosm/yrn43eiKqkfkHis7QioDleaXQHdDVPKg0FSwwd/DuvyX79TZnFOnYpB7dcsFAMmtFztZuXPDvSePkFw==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
@@ -2109,9 +2109,9 @@
|
||||
}
|
||||
},
|
||||
"node_modules/brace-expansion": {
|
||||
"version": "1.1.12",
|
||||
"resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-1.1.12.tgz",
|
||||
"integrity": "sha512-9T9UjW3r0UW5c1Q7GTwllptXwhvYmEzFhzMfZ9H7FQWt+uZePjZPjBP/W1ZEyZ1twGWom5/56TF4lPcqjnDHcg==",
|
||||
"version": "1.1.13",
|
||||
"resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-1.1.13.tgz",
|
||||
"integrity": "sha512-9ZLprWS6EENmhEOpjCYW2c8VkmOvckIJZfkr7rBW6dObmfgJ/L1GpSYW5Hpo9lDz4D1+n0Ckz8rU7FwHDQiG/w==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
@@ -2657,9 +2657,9 @@
|
||||
}
|
||||
},
|
||||
"node_modules/flatted": {
|
||||
"version": "3.3.3",
|
||||
"resolved": "https://registry.npmjs.org/flatted/-/flatted-3.3.3.tgz",
|
||||
"integrity": "sha512-GX+ysw4PBCz0PzosHDepZGANEuFCMLrnRTiEy9McGjmkCQYwRq4A/X786G/fjM/+OjsWSU1ZrY5qyARZmO/uwg==",
|
||||
"version": "3.4.2",
|
||||
"resolved": "https://registry.npmjs.org/flatted/-/flatted-3.4.2.tgz",
|
||||
"integrity": "sha512-PjDse7RzhcPkIJwy5t7KPWQSZ9cAbzQXcafsetQoD7sOJRQlGikNbx7yZp2OotDnJyrDcbyRq3Ttb18iYOqkxA==",
|
||||
"dev": true,
|
||||
"license": "ISC"
|
||||
},
|
||||
@@ -3243,9 +3243,9 @@
|
||||
}
|
||||
},
|
||||
"node_modules/minimatch": {
|
||||
"version": "3.1.2",
|
||||
"resolved": "https://registry.npmjs.org/minimatch/-/minimatch-3.1.2.tgz",
|
||||
"integrity": "sha512-J7p63hRiAjw1NDEww1W7i37+ByIrOWO5XQQAzZ3VOcL0PNybwpfmV/N05zFAzwQ9USyEcX6t3UO+K5aqBQOIHw==",
|
||||
"version": "3.1.5",
|
||||
"resolved": "https://registry.npmjs.org/minimatch/-/minimatch-3.1.5.tgz",
|
||||
"integrity": "sha512-VgjWUsnnT6n+NUk6eZq77zeFdpW2LWDzP6zFGrCbHXiYNul5Dzqk2HHQ5uFH2DNW5Xbp8+jVzaeNt94ssEEl4w==",
|
||||
"dev": true,
|
||||
"license": "ISC",
|
||||
"dependencies": {
|
||||
@@ -3386,9 +3386,9 @@
|
||||
"license": "ISC"
|
||||
},
|
||||
"node_modules/picomatch": {
|
||||
"version": "4.0.3",
|
||||
"resolved": "https://registry.npmjs.org/picomatch/-/picomatch-4.0.3.tgz",
|
||||
"integrity": "sha512-5gTmgEY/sqK6gFXLIsQNH19lWb4ebPDLA4SdLP7dsWkIXHWlG66oPuVvXSGFPppYZz8ZDZq0dYYrbHfBCVUb1Q==",
|
||||
"version": "4.0.4",
|
||||
"resolved": "https://registry.npmjs.org/picomatch/-/picomatch-4.0.4.tgz",
|
||||
"integrity": "sha512-QP88BAKvMam/3NxH6vj2o21R6MjxZUAd6nlwAS/pnGvN9IVLocLHxGYIzFhg6fUQ+5th6P4dv4eW9jX3DSIj7A==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
@@ -3501,9 +3501,9 @@
|
||||
}
|
||||
},
|
||||
"node_modules/rollup": {
|
||||
"version": "4.57.1",
|
||||
"resolved": "https://registry.npmjs.org/rollup/-/rollup-4.57.1.tgz",
|
||||
"integrity": "sha512-oQL6lgK3e2QZeQ7gcgIkS2YZPg5slw37hYufJ3edKlfQSGGm8ICoxswK15ntSzF/a8+h7ekRy7k7oWc3BQ7y8A==",
|
||||
"version": "4.60.0",
|
||||
"resolved": "https://registry.npmjs.org/rollup/-/rollup-4.60.0.tgz",
|
||||
"integrity": "sha512-yqjxruMGBQJ2gG4HtjZtAfXArHomazDHoFwFFmZZl0r7Pdo7qCIXKqKHZc8yeoMgzJJ+pO6pEEHa+V7uzWlrAQ==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
@@ -3517,31 +3517,31 @@
|
||||
"npm": ">=8.0.0"
|
||||
},
|
||||
"optionalDependencies": {
|
||||
"@rollup/rollup-android-arm-eabi": "4.57.1",
|
||||
"@rollup/rollup-android-arm64": "4.57.1",
|
||||
"@rollup/rollup-darwin-arm64": "4.57.1",
|
||||
"@rollup/rollup-darwin-x64": "4.57.1",
|
||||
"@rollup/rollup-freebsd-arm64": "4.57.1",
|
||||
"@rollup/rollup-freebsd-x64": "4.57.1",
|
||||
"@rollup/rollup-linux-arm-gnueabihf": "4.57.1",
|
||||
"@rollup/rollup-linux-arm-musleabihf": "4.57.1",
|
||||
"@rollup/rollup-linux-arm64-gnu": "4.57.1",
|
||||
"@rollup/rollup-linux-arm64-musl": "4.57.1",
|
||||
"@rollup/rollup-linux-loong64-gnu": "4.57.1",
|
||||
"@rollup/rollup-linux-loong64-musl": "4.57.1",
|
||||
"@rollup/rollup-linux-ppc64-gnu": "4.57.1",
|
||||
"@rollup/rollup-linux-ppc64-musl": "4.57.1",
|
||||
"@rollup/rollup-linux-riscv64-gnu": "4.57.1",
|
||||
"@rollup/rollup-linux-riscv64-musl": "4.57.1",
|
||||
"@rollup/rollup-linux-s390x-gnu": "4.57.1",
|
||||
"@rollup/rollup-linux-x64-gnu": "4.57.1",
|
||||
"@rollup/rollup-linux-x64-musl": "4.57.1",
|
||||
"@rollup/rollup-openbsd-x64": "4.57.1",
|
||||
"@rollup/rollup-openharmony-arm64": "4.57.1",
|
||||
"@rollup/rollup-win32-arm64-msvc": "4.57.1",
|
||||
"@rollup/rollup-win32-ia32-msvc": "4.57.1",
|
||||
"@rollup/rollup-win32-x64-gnu": "4.57.1",
|
||||
"@rollup/rollup-win32-x64-msvc": "4.57.1",
|
||||
"@rollup/rollup-android-arm-eabi": "4.60.0",
|
||||
"@rollup/rollup-android-arm64": "4.60.0",
|
||||
"@rollup/rollup-darwin-arm64": "4.60.0",
|
||||
"@rollup/rollup-darwin-x64": "4.60.0",
|
||||
"@rollup/rollup-freebsd-arm64": "4.60.0",
|
||||
"@rollup/rollup-freebsd-x64": "4.60.0",
|
||||
"@rollup/rollup-linux-arm-gnueabihf": "4.60.0",
|
||||
"@rollup/rollup-linux-arm-musleabihf": "4.60.0",
|
||||
"@rollup/rollup-linux-arm64-gnu": "4.60.0",
|
||||
"@rollup/rollup-linux-arm64-musl": "4.60.0",
|
||||
"@rollup/rollup-linux-loong64-gnu": "4.60.0",
|
||||
"@rollup/rollup-linux-loong64-musl": "4.60.0",
|
||||
"@rollup/rollup-linux-ppc64-gnu": "4.60.0",
|
||||
"@rollup/rollup-linux-ppc64-musl": "4.60.0",
|
||||
"@rollup/rollup-linux-riscv64-gnu": "4.60.0",
|
||||
"@rollup/rollup-linux-riscv64-musl": "4.60.0",
|
||||
"@rollup/rollup-linux-s390x-gnu": "4.60.0",
|
||||
"@rollup/rollup-linux-x64-gnu": "4.60.0",
|
||||
"@rollup/rollup-linux-x64-musl": "4.60.0",
|
||||
"@rollup/rollup-openbsd-x64": "4.60.0",
|
||||
"@rollup/rollup-openharmony-arm64": "4.60.0",
|
||||
"@rollup/rollup-win32-arm64-msvc": "4.60.0",
|
||||
"@rollup/rollup-win32-ia32-msvc": "4.60.0",
|
||||
"@rollup/rollup-win32-x64-gnu": "4.60.0",
|
||||
"@rollup/rollup-win32-x64-msvc": "4.60.0",
|
||||
"fsevents": "~2.3.2"
|
||||
}
|
||||
},
|
||||
@@ -3803,9 +3803,9 @@
|
||||
}
|
||||
},
|
||||
"node_modules/vite": {
|
||||
"version": "7.3.1",
|
||||
"resolved": "https://registry.npmjs.org/vite/-/vite-7.3.1.tgz",
|
||||
"integrity": "sha512-w+N7Hifpc3gRjZ63vYBXA56dvvRlNWRczTdmCBBa+CotUzAPf5b7YMdMR/8CQoeYE5LX3W4wj6RYTgonm1b9DA==",
|
||||
"version": "7.3.2",
|
||||
"resolved": "https://registry.npmjs.org/vite/-/vite-7.3.2.tgz",
|
||||
"integrity": "sha512-Bby3NOsna2jsjfLVOHKes8sGwgl4TT0E6vvpYgnAYDIF/tie7MRaFthmKuHx1NSXjiTueXH3do80FMQgvEktRg==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
|
||||
@@ -17,7 +17,7 @@
|
||||
"tailwind-merge": "^2.6.0"
|
||||
},
|
||||
"devDependencies": {
|
||||
"@eslint/js": "^9.39.1",
|
||||
"@eslint/js": "9.39.2",
|
||||
"@tailwindcss/vite": "^4.1.18",
|
||||
"@types/node": "^24.10.1",
|
||||
"@types/react": "^19.2.5",
|
||||
@@ -31,6 +31,6 @@
|
||||
"tsx": "^4.21.0",
|
||||
"typescript": "~5.9.3",
|
||||
"typescript-eslint": "^8.46.4",
|
||||
"vite": "^7.2.4"
|
||||
"vite": "7.3.2"
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,11 +1,13 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/relay/server/listener"
|
||||
"github.com/netbirdio/netbird/shared/relay/messages"
|
||||
//nolint:staticcheck
|
||||
"github.com/netbirdio/netbird/shared/relay/messages/address"
|
||||
@@ -13,6 +15,12 @@ import (
|
||||
authmsg "github.com/netbirdio/netbird/shared/relay/messages/auth"
|
||||
)
|
||||
|
||||
const (
|
||||
// handshakeTimeout bounds how long a connection may remain in the
|
||||
// pre-authentication handshake phase before being closed.
|
||||
handshakeTimeout = 10 * time.Second
|
||||
)
|
||||
|
||||
type Validator interface {
|
||||
Validate(any) error
|
||||
// Deprecated: Use Validate instead.
|
||||
@@ -58,7 +66,7 @@ func marshalResponseHelloMsg(instanceURL string) ([]byte, error) {
|
||||
}
|
||||
|
||||
type handshake struct {
|
||||
conn net.Conn
|
||||
conn listener.Conn
|
||||
validator Validator
|
||||
preparedMsg *preparedMsg
|
||||
|
||||
@@ -66,9 +74,9 @@ type handshake struct {
|
||||
peerID *messages.PeerID
|
||||
}
|
||||
|
||||
func (h *handshake) handshakeReceive() (*messages.PeerID, error) {
|
||||
func (h *handshake) handshakeReceive(ctx context.Context) (*messages.PeerID, error) {
|
||||
buf := make([]byte, messages.MaxHandshakeSize)
|
||||
n, err := h.conn.Read(buf)
|
||||
n, err := h.conn.Read(ctx, buf)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read from %s: %w", h.conn.RemoteAddr(), err)
|
||||
}
|
||||
@@ -103,7 +111,7 @@ func (h *handshake) handshakeReceive() (*messages.PeerID, error) {
|
||||
return peerID, nil
|
||||
}
|
||||
|
||||
func (h *handshake) handshakeResponse() error {
|
||||
func (h *handshake) handshakeResponse(ctx context.Context) error {
|
||||
var responseMsg []byte
|
||||
if h.handshakeMethodAuth {
|
||||
responseMsg = h.preparedMsg.responseAuthMsg
|
||||
@@ -111,7 +119,7 @@ func (h *handshake) handshakeResponse() error {
|
||||
responseMsg = h.preparedMsg.responseHelloMsg
|
||||
}
|
||||
|
||||
if _, err := h.conn.Write(responseMsg); err != nil {
|
||||
if _, err := h.conn.Write(ctx, responseMsg); err != nil {
|
||||
return fmt.Errorf("handshake response write to %s (%s): %w", h.peerID, h.conn.RemoteAddr(), err)
|
||||
}
|
||||
|
||||
|
||||
14
relay/server/listener/conn.go
Normal file
14
relay/server/listener/conn.go
Normal file
@@ -0,0 +1,14 @@
|
||||
package listener
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
)
|
||||
|
||||
// Conn is the relay connection contract implemented by WS and QUIC transports.
|
||||
type Conn interface {
|
||||
Read(ctx context.Context, b []byte) (n int, err error)
|
||||
Write(ctx context.Context, b []byte) (n int, err error)
|
||||
RemoteAddr() net.Addr
|
||||
Close() error
|
||||
}
|
||||
@@ -1,14 +0,0 @@
|
||||
package listener
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
|
||||
"github.com/netbirdio/netbird/relay/protocol"
|
||||
)
|
||||
|
||||
type Listener interface {
|
||||
Listen(func(conn net.Conn)) error
|
||||
Shutdown(ctx context.Context) error
|
||||
Protocol() protocol.Protocol
|
||||
}
|
||||
@@ -3,33 +3,26 @@ package quic
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/quic-go/quic-go"
|
||||
)
|
||||
|
||||
type Conn struct {
|
||||
session *quic.Conn
|
||||
closed bool
|
||||
closedMu sync.Mutex
|
||||
ctx context.Context
|
||||
ctxCancel context.CancelFunc
|
||||
session *quic.Conn
|
||||
closed bool
|
||||
closedMu sync.Mutex
|
||||
}
|
||||
|
||||
func NewConn(session *quic.Conn) *Conn {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
return &Conn{
|
||||
session: session,
|
||||
ctx: ctx,
|
||||
ctxCancel: cancel,
|
||||
session: session,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Conn) Read(b []byte) (n int, err error) {
|
||||
dgram, err := c.session.ReceiveDatagram(c.ctx)
|
||||
func (c *Conn) Read(ctx context.Context, b []byte) (n int, err error) {
|
||||
dgram, err := c.session.ReceiveDatagram(ctx)
|
||||
if err != nil {
|
||||
return 0, c.remoteCloseErrHandling(err)
|
||||
}
|
||||
@@ -38,33 +31,17 @@ func (c *Conn) Read(b []byte) (n int, err error) {
|
||||
return n, nil
|
||||
}
|
||||
|
||||
func (c *Conn) Write(b []byte) (int, error) {
|
||||
func (c *Conn) Write(_ context.Context, b []byte) (int, error) {
|
||||
if err := c.session.SendDatagram(b); err != nil {
|
||||
return 0, c.remoteCloseErrHandling(err)
|
||||
}
|
||||
return len(b), nil
|
||||
}
|
||||
|
||||
func (c *Conn) LocalAddr() net.Addr {
|
||||
return c.session.LocalAddr()
|
||||
}
|
||||
|
||||
func (c *Conn) RemoteAddr() net.Addr {
|
||||
return c.session.RemoteAddr()
|
||||
}
|
||||
|
||||
func (c *Conn) SetReadDeadline(t time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Conn) SetWriteDeadline(t time.Time) error {
|
||||
return fmt.Errorf("SetWriteDeadline is not implemented")
|
||||
}
|
||||
|
||||
func (c *Conn) SetDeadline(t time.Time) error {
|
||||
return fmt.Errorf("SetDeadline is not implemented")
|
||||
}
|
||||
|
||||
func (c *Conn) Close() error {
|
||||
c.closedMu.Lock()
|
||||
if c.closed {
|
||||
@@ -74,8 +51,6 @@ func (c *Conn) Close() error {
|
||||
c.closed = true
|
||||
c.closedMu.Unlock()
|
||||
|
||||
c.ctxCancel() // Cancel the context
|
||||
|
||||
sessionErr := c.session.CloseWithError(0, "normal closure")
|
||||
return sessionErr
|
||||
}
|
||||
|
||||
@@ -5,12 +5,12 @@ import (
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
|
||||
"github.com/quic-go/quic-go"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/relay/protocol"
|
||||
relaylistener "github.com/netbirdio/netbird/relay/server/listener"
|
||||
nbRelay "github.com/netbirdio/netbird/shared/relay"
|
||||
)
|
||||
|
||||
@@ -25,7 +25,7 @@ type Listener struct {
|
||||
listener *quic.Listener
|
||||
}
|
||||
|
||||
func (l *Listener) Listen(acceptFn func(conn net.Conn)) error {
|
||||
func (l *Listener) Listen(acceptFn func(conn relaylistener.Conn)) error {
|
||||
quicCfg := &quic.Config{
|
||||
EnableDatagrams: true,
|
||||
InitialPacketSize: nbRelay.QUICInitialPacketSize,
|
||||
|
||||
@@ -18,25 +18,21 @@ const (
|
||||
|
||||
type Conn struct {
|
||||
*websocket.Conn
|
||||
lAddr *net.TCPAddr
|
||||
rAddr *net.TCPAddr
|
||||
|
||||
closed bool
|
||||
closedMu sync.Mutex
|
||||
ctx context.Context
|
||||
}
|
||||
|
||||
func NewConn(wsConn *websocket.Conn, lAddr, rAddr *net.TCPAddr) *Conn {
|
||||
func NewConn(wsConn *websocket.Conn, rAddr *net.TCPAddr) *Conn {
|
||||
return &Conn{
|
||||
Conn: wsConn,
|
||||
lAddr: lAddr,
|
||||
rAddr: rAddr,
|
||||
ctx: context.Background(),
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Conn) Read(b []byte) (n int, err error) {
|
||||
t, r, err := c.Reader(c.ctx)
|
||||
func (c *Conn) Read(ctx context.Context, b []byte) (n int, err error) {
|
||||
t, r, err := c.Reader(ctx)
|
||||
if err != nil {
|
||||
return 0, c.ioErrHandling(err)
|
||||
}
|
||||
@@ -56,34 +52,18 @@ func (c *Conn) Read(b []byte) (n int, err error) {
|
||||
// Write writes a binary message with the given payload.
|
||||
// It does not block until fill the internal buffer.
|
||||
// If the buffer filled up, wait until the buffer is drained or timeout.
|
||||
func (c *Conn) Write(b []byte) (int, error) {
|
||||
ctx, ctxCancel := context.WithTimeout(c.ctx, writeTimeout)
|
||||
func (c *Conn) Write(ctx context.Context, b []byte) (int, error) {
|
||||
ctx, ctxCancel := context.WithTimeout(ctx, writeTimeout)
|
||||
defer ctxCancel()
|
||||
|
||||
err := c.Conn.Write(ctx, websocket.MessageBinary, b)
|
||||
return len(b), err
|
||||
}
|
||||
|
||||
func (c *Conn) LocalAddr() net.Addr {
|
||||
return c.lAddr
|
||||
}
|
||||
|
||||
func (c *Conn) RemoteAddr() net.Addr {
|
||||
return c.rAddr
|
||||
}
|
||||
|
||||
func (c *Conn) SetReadDeadline(t time.Time) error {
|
||||
return fmt.Errorf("SetReadDeadline is not implemented")
|
||||
}
|
||||
|
||||
func (c *Conn) SetWriteDeadline(t time.Time) error {
|
||||
return fmt.Errorf("SetWriteDeadline is not implemented")
|
||||
}
|
||||
|
||||
func (c *Conn) SetDeadline(t time.Time) error {
|
||||
return fmt.Errorf("SetDeadline is not implemented")
|
||||
}
|
||||
|
||||
func (c *Conn) Close() error {
|
||||
c.closedMu.Lock()
|
||||
c.closed = true
|
||||
|
||||
@@ -7,11 +7,13 @@ import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/coder/websocket"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/relay/protocol"
|
||||
relaylistener "github.com/netbirdio/netbird/relay/server/listener"
|
||||
"github.com/netbirdio/netbird/shared/relay"
|
||||
)
|
||||
|
||||
@@ -27,18 +29,19 @@ type Listener struct {
|
||||
TLSConfig *tls.Config
|
||||
|
||||
server *http.Server
|
||||
acceptFn func(conn net.Conn)
|
||||
acceptFn func(conn relaylistener.Conn)
|
||||
}
|
||||
|
||||
func (l *Listener) Listen(acceptFn func(conn net.Conn)) error {
|
||||
func (l *Listener) Listen(acceptFn func(conn relaylistener.Conn)) error {
|
||||
l.acceptFn = acceptFn
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc(URLPath, l.onAccept)
|
||||
|
||||
l.server = &http.Server{
|
||||
Addr: l.Address,
|
||||
Handler: mux,
|
||||
TLSConfig: l.TLSConfig,
|
||||
Addr: l.Address,
|
||||
Handler: mux,
|
||||
TLSConfig: l.TLSConfig,
|
||||
ReadHeaderTimeout: 5 * time.Second,
|
||||
}
|
||||
|
||||
log.Infof("WS server listening address: %s", l.Address)
|
||||
@@ -93,18 +96,9 @@ func (l *Listener) onAccept(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
lAddr, err := net.ResolveTCPAddr("tcp", l.server.Addr)
|
||||
if err != nil {
|
||||
err = wsConn.Close(websocket.StatusInternalError, "internal error")
|
||||
if err != nil {
|
||||
log.Errorf("failed to close ws connection: %s", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
log.Infof("WS client connected from: %s", rAddr)
|
||||
|
||||
conn := NewConn(wsConn, lAddr, rAddr)
|
||||
conn := NewConn(wsConn, rAddr)
|
||||
l.acceptFn(conn)
|
||||
}
|
||||
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/relay/metrics"
|
||||
"github.com/netbirdio/netbird/relay/server/listener"
|
||||
"github.com/netbirdio/netbird/relay/server/store"
|
||||
"github.com/netbirdio/netbird/shared/relay/healthcheck"
|
||||
"github.com/netbirdio/netbird/shared/relay/messages"
|
||||
@@ -26,11 +27,14 @@ type Peer struct {
|
||||
metrics *metrics.Metrics
|
||||
log *log.Entry
|
||||
id messages.PeerID
|
||||
conn net.Conn
|
||||
conn listener.Conn
|
||||
connMu sync.RWMutex
|
||||
store *store.Store
|
||||
notifier *store.PeerNotifier
|
||||
|
||||
ctx context.Context
|
||||
ctxCancel context.CancelFunc
|
||||
|
||||
peersListener *store.Listener
|
||||
|
||||
// between the online peer collection step and the notification sending should not be sent offline notifications from another thread
|
||||
@@ -38,14 +42,17 @@ type Peer struct {
|
||||
}
|
||||
|
||||
// NewPeer creates a new Peer instance and prepare custom logging
|
||||
func NewPeer(metrics *metrics.Metrics, id messages.PeerID, conn net.Conn, store *store.Store, notifier *store.PeerNotifier) *Peer {
|
||||
func NewPeer(metrics *metrics.Metrics, id messages.PeerID, conn listener.Conn, store *store.Store, notifier *store.PeerNotifier) *Peer {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
p := &Peer{
|
||||
metrics: metrics,
|
||||
log: log.WithField("peer_id", id.String()),
|
||||
id: id,
|
||||
conn: conn,
|
||||
store: store,
|
||||
notifier: notifier,
|
||||
metrics: metrics,
|
||||
log: log.WithField("peer_id", id.String()),
|
||||
id: id,
|
||||
conn: conn,
|
||||
store: store,
|
||||
notifier: notifier,
|
||||
ctx: ctx,
|
||||
ctxCancel: cancel,
|
||||
}
|
||||
|
||||
return p
|
||||
@@ -57,6 +64,7 @@ func NewPeer(metrics *metrics.Metrics, id messages.PeerID, conn net.Conn, store
|
||||
func (p *Peer) Work() {
|
||||
p.peersListener = p.notifier.NewListener(p.sendPeersOnline, p.sendPeersWentOffline)
|
||||
defer func() {
|
||||
p.ctxCancel()
|
||||
p.notifier.RemoveListener(p.peersListener)
|
||||
|
||||
if err := p.conn.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
|
||||
@@ -64,8 +72,7 @@ func (p *Peer) Work() {
|
||||
}
|
||||
}()
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
ctx := p.ctx
|
||||
|
||||
hc := healthcheck.NewSender(p.log)
|
||||
go hc.StartHealthCheck(ctx)
|
||||
@@ -73,7 +80,7 @@ func (p *Peer) Work() {
|
||||
|
||||
buf := make([]byte, bufferSize)
|
||||
for {
|
||||
n, err := p.conn.Read(buf)
|
||||
n, err := p.conn.Read(ctx, buf)
|
||||
if err != nil {
|
||||
if !errors.Is(err, net.ErrClosed) {
|
||||
p.log.Errorf("failed to read message: %s", err)
|
||||
@@ -131,10 +138,10 @@ func (p *Peer) handleMsgType(ctx context.Context, msgType messages.MsgType, hc *
|
||||
}
|
||||
|
||||
// Write writes data to the connection
|
||||
func (p *Peer) Write(b []byte) (int, error) {
|
||||
func (p *Peer) Write(ctx context.Context, b []byte) (int, error) {
|
||||
p.connMu.RLock()
|
||||
defer p.connMu.RUnlock()
|
||||
return p.conn.Write(b)
|
||||
return p.conn.Write(ctx, b)
|
||||
}
|
||||
|
||||
// CloseGracefully closes the connection with the peer gracefully. Send a close message to the client and close the
|
||||
@@ -147,6 +154,7 @@ func (p *Peer) CloseGracefully(ctx context.Context) {
|
||||
p.log.Errorf("failed to send close message to peer: %s", p.String())
|
||||
}
|
||||
|
||||
p.ctxCancel()
|
||||
if err := p.conn.Close(); err != nil {
|
||||
p.log.Errorf(errCloseConn, err)
|
||||
}
|
||||
@@ -156,6 +164,7 @@ func (p *Peer) Close() {
|
||||
p.connMu.Lock()
|
||||
defer p.connMu.Unlock()
|
||||
|
||||
p.ctxCancel()
|
||||
if err := p.conn.Close(); err != nil {
|
||||
p.log.Errorf(errCloseConn, err)
|
||||
}
|
||||
@@ -170,26 +179,15 @@ func (p *Peer) writeWithTimeout(ctx context.Context, buf []byte) error {
|
||||
ctx, cancel := context.WithTimeout(ctx, 3*time.Second)
|
||||
defer cancel()
|
||||
|
||||
writeDone := make(chan struct{})
|
||||
var err error
|
||||
go func() {
|
||||
_, err = p.conn.Write(buf)
|
||||
close(writeDone)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-writeDone:
|
||||
return err
|
||||
}
|
||||
_, err := p.conn.Write(ctx, buf)
|
||||
return err
|
||||
}
|
||||
|
||||
func (p *Peer) handleHealthcheckEvents(ctx context.Context, hc *healthcheck.Sender) {
|
||||
for {
|
||||
select {
|
||||
case <-hc.HealthCheck:
|
||||
_, err := p.Write(messages.MarshalHealthcheck())
|
||||
_, err := p.Write(ctx, messages.MarshalHealthcheck())
|
||||
if err != nil {
|
||||
p.log.Errorf("failed to send healthcheck message: %s", err)
|
||||
return
|
||||
@@ -228,12 +226,12 @@ func (p *Peer) handleTransportMsg(msg []byte) {
|
||||
return
|
||||
}
|
||||
|
||||
n, err := dp.Write(msg)
|
||||
n, err := dp.Write(dp.ctx, msg)
|
||||
if err != nil {
|
||||
p.log.Errorf("failed to write transport message to: %s", dp.String())
|
||||
return
|
||||
}
|
||||
p.metrics.TransferBytesSent.Add(context.Background(), int64(n))
|
||||
p.metrics.TransferBytesSent.Add(p.ctx, int64(n))
|
||||
}
|
||||
|
||||
func (p *Peer) handleSubscribePeerState(msg []byte) {
|
||||
@@ -276,7 +274,7 @@ func (p *Peer) sendPeersOnline(peers []messages.PeerID) {
|
||||
}
|
||||
|
||||
for n, msg := range msgs {
|
||||
if _, err := p.Write(msg); err != nil {
|
||||
if _, err := p.Write(p.ctx, msg); err != nil {
|
||||
p.log.Errorf("failed to write %d. peers offline message: %s", n, err)
|
||||
}
|
||||
}
|
||||
@@ -293,7 +291,7 @@ func (p *Peer) sendPeersWentOffline(peers []messages.PeerID) {
|
||||
}
|
||||
|
||||
for n, msg := range msgs {
|
||||
if _, err := p.Write(msg); err != nil {
|
||||
if _, err := p.Write(p.ctx, msg); err != nil {
|
||||
p.log.Errorf("failed to write %d. peers offline message: %s", n, err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,7 +3,6 @@ package server
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/url"
|
||||
"sync"
|
||||
"time"
|
||||
@@ -13,11 +12,20 @@ import (
|
||||
"go.opentelemetry.io/otel/metric"
|
||||
|
||||
"github.com/netbirdio/netbird/relay/healthcheck/peerid"
|
||||
"github.com/netbirdio/netbird/relay/protocol"
|
||||
"github.com/netbirdio/netbird/relay/server/listener"
|
||||
|
||||
//nolint:staticcheck
|
||||
"github.com/netbirdio/netbird/relay/metrics"
|
||||
"github.com/netbirdio/netbird/relay/server/store"
|
||||
)
|
||||
|
||||
type Listener interface {
|
||||
Listen(func(conn listener.Conn)) error
|
||||
Shutdown(ctx context.Context) error
|
||||
Protocol() protocol.Protocol
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
Meter metric.Meter
|
||||
ExposedAddress string
|
||||
@@ -109,7 +117,7 @@ func NewRelay(config Config) (*Relay, error) {
|
||||
}
|
||||
|
||||
// Accept start to handle a new peer connection
|
||||
func (r *Relay) Accept(conn net.Conn) {
|
||||
func (r *Relay) Accept(conn listener.Conn) {
|
||||
acceptTime := time.Now()
|
||||
r.closeMu.RLock()
|
||||
defer r.closeMu.RUnlock()
|
||||
@@ -117,12 +125,15 @@ func (r *Relay) Accept(conn net.Conn) {
|
||||
return
|
||||
}
|
||||
|
||||
hsCtx, hsCancel := context.WithTimeout(context.Background(), handshakeTimeout)
|
||||
defer hsCancel()
|
||||
|
||||
h := handshake{
|
||||
conn: conn,
|
||||
validator: r.validator,
|
||||
preparedMsg: r.preparedMsg,
|
||||
}
|
||||
peerID, err := h.handshakeReceive()
|
||||
peerID, err := h.handshakeReceive(hsCtx)
|
||||
if err != nil {
|
||||
if peerid.IsHealthCheck(peerID) {
|
||||
log.Debugf("health check connection from %s", conn.RemoteAddr())
|
||||
@@ -154,7 +165,7 @@ func (r *Relay) Accept(conn net.Conn) {
|
||||
r.metrics.PeerDisconnected(peer.String())
|
||||
}()
|
||||
|
||||
if err := h.handshakeResponse(); err != nil {
|
||||
if err := h.handshakeResponse(hsCtx); err != nil {
|
||||
log.Errorf("failed to send handshake response, close peer: %s", err)
|
||||
peer.Close()
|
||||
}
|
||||
|
||||
@@ -3,7 +3,6 @@ package server
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"net"
|
||||
"net/url"
|
||||
"sync"
|
||||
|
||||
@@ -31,7 +30,7 @@ type ListenerConfig struct {
|
||||
// In a new HTTP connection, the server will accept the connection and pass it to the Relay server via the Accept method.
|
||||
type Server struct {
|
||||
relay *Relay
|
||||
listeners []listener.Listener
|
||||
listeners []Listener
|
||||
listenerMux sync.Mutex
|
||||
}
|
||||
|
||||
@@ -56,7 +55,7 @@ func NewServer(config Config) (*Server, error) {
|
||||
}
|
||||
return &Server{
|
||||
relay: relay,
|
||||
listeners: make([]listener.Listener, 0, 2),
|
||||
listeners: make([]Listener, 0, 2),
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -86,7 +85,7 @@ func (r *Server) Listen(cfg ListenerConfig) error {
|
||||
wg := sync.WaitGroup{}
|
||||
for _, l := range r.listeners {
|
||||
wg.Add(1)
|
||||
go func(listener listener.Listener) {
|
||||
go func(listener Listener) {
|
||||
defer wg.Done()
|
||||
errChan <- listener.Listen(r.relay.Accept)
|
||||
}(l)
|
||||
@@ -139,6 +138,6 @@ func (r *Server) InstanceURL() url.URL {
|
||||
// RelayAccept returns the relay's Accept function for handling incoming connections.
|
||||
// This allows external HTTP handlers to route connections to the relay without
|
||||
// starting the relay's own listeners.
|
||||
func (r *Server) RelayAccept() func(conn net.Conn) {
|
||||
func (r *Server) RelayAccept() func(conn listener.Conn) {
|
||||
return r.relay.Accept
|
||||
}
|
||||
|
||||
@@ -4,8 +4,6 @@ import (
|
||||
"context"
|
||||
"io"
|
||||
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
@@ -16,14 +14,18 @@ type Client interface {
|
||||
io.Closer
|
||||
Sync(ctx context.Context, sysInfo *system.Info, msgHandler func(msg *proto.SyncResponse) error) error
|
||||
Job(ctx context.Context, msgHandler func(msg *proto.JobRequest) *proto.JobResponse) error
|
||||
GetServerPublicKey() (*wgtypes.Key, error)
|
||||
Register(serverKey wgtypes.Key, setupKey string, jwtToken string, sysInfo *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error)
|
||||
Login(serverKey wgtypes.Key, sysInfo *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error)
|
||||
GetDeviceAuthorizationFlow(serverKey wgtypes.Key) (*proto.DeviceAuthorizationFlow, error)
|
||||
GetPKCEAuthorizationFlow(serverKey wgtypes.Key) (*proto.PKCEAuthorizationFlow, error)
|
||||
Register(setupKey string, jwtToken string, sysInfo *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error)
|
||||
Login(sysInfo *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error)
|
||||
GetDeviceAuthorizationFlow() (*proto.DeviceAuthorizationFlow, error)
|
||||
GetPKCEAuthorizationFlow() (*proto.PKCEAuthorizationFlow, error)
|
||||
GetNetworkMap(sysInfo *system.Info) (*proto.NetworkMap, error)
|
||||
GetServerURL() string
|
||||
// IsHealthy returns the current connection status without blocking.
|
||||
// Used by the engine to monitor connectivity in the background.
|
||||
IsHealthy() bool
|
||||
// HealthCheck actively probes the management server and returns an error if unreachable.
|
||||
// Used to validate connectivity before committing configuration changes.
|
||||
HealthCheck() error
|
||||
SyncMeta(sysInfo *system.Info) error
|
||||
Logout() error
|
||||
CreateExpose(ctx context.Context, req ExposeRequest) (*ExposeResponse, error)
|
||||
|
||||
@@ -189,7 +189,7 @@ func closeManagementSilently(s *grpc.Server, listener net.Listener) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestClient_GetServerPublicKey(t *testing.T) {
|
||||
func TestClient_HealthCheck(t *testing.T) {
|
||||
testKey, err := wgtypes.GenerateKey()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
@@ -203,12 +203,8 @@ func TestClient_GetServerPublicKey(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
key, err := client.GetServerPublicKey()
|
||||
if err != nil {
|
||||
t.Error("couldn't retrieve management public key")
|
||||
}
|
||||
if key == nil {
|
||||
t.Error("got an empty management public key")
|
||||
if err := client.HealthCheck(); err != nil {
|
||||
t.Errorf("health check failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -225,12 +221,8 @@ func TestClient_LoginUnregistered_ShouldThrow_401(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
key, err := client.GetServerPublicKey()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
sysInfo := system.GetInfo(context.TODO())
|
||||
_, err = client.Login(*key, sysInfo, nil, nil)
|
||||
_, err = client.Login(sysInfo, nil, nil)
|
||||
if err == nil {
|
||||
t.Error("expecting err on unregistered login, got nil")
|
||||
}
|
||||
@@ -253,12 +245,8 @@ func TestClient_LoginRegistered(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
key, err := client.GetServerPublicKey()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
info := system.GetInfo(context.TODO())
|
||||
resp, err := client.Register(*key, ValidKey, "", info, nil, nil)
|
||||
resp, err := client.Register(ValidKey, "", info, nil, nil)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
@@ -282,13 +270,8 @@ func TestClient_Sync(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
serverKey, err := client.GetServerPublicKey()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
info := system.GetInfo(context.TODO())
|
||||
_, err = client.Register(*serverKey, ValidKey, "", info, nil, nil)
|
||||
_, err = client.Register(ValidKey, "", info, nil, nil)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
@@ -304,7 +287,7 @@ func TestClient_Sync(t *testing.T) {
|
||||
}
|
||||
|
||||
info = system.GetInfo(context.TODO())
|
||||
_, err = remoteClient.Register(*serverKey, ValidKey, "", info, nil, nil)
|
||||
_, err = remoteClient.Register(ValidKey, "", info, nil, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -364,11 +347,6 @@ func Test_SystemMetaDataFromClient(t *testing.T) {
|
||||
t.Fatalf("error while creating testClient: %v", err)
|
||||
}
|
||||
|
||||
key, err := testClient.GetServerPublicKey()
|
||||
if err != nil {
|
||||
t.Fatalf("error while getting server public key from testclient, %v", err)
|
||||
}
|
||||
|
||||
var actualMeta *mgmtProto.PeerSystemMeta
|
||||
var actualValidKey string
|
||||
var wg sync.WaitGroup
|
||||
@@ -405,7 +383,7 @@ func Test_SystemMetaDataFromClient(t *testing.T) {
|
||||
}
|
||||
|
||||
info := system.GetInfo(context.TODO())
|
||||
_, err = testClient.Register(*key, ValidKey, "", info, nil, nil)
|
||||
_, err = testClient.Register(ValidKey, "", info, nil, nil)
|
||||
if err != nil {
|
||||
t.Errorf("error while trying to register client: %v", err)
|
||||
}
|
||||
@@ -505,7 +483,7 @@ func Test_GetDeviceAuthorizationFlow(t *testing.T) {
|
||||
}
|
||||
|
||||
mgmtMockServer.GetDeviceAuthorizationFlowFunc = func(ctx context.Context, req *mgmtProto.EncryptedMessage) (*mgmtProto.EncryptedMessage, error) {
|
||||
encryptedResp, err := encryption.EncryptMessage(serverKey, client.key, expectedFlowInfo)
|
||||
encryptedResp, err := encryption.EncryptMessage(client.key.PublicKey(), serverKey, expectedFlowInfo)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -517,7 +495,7 @@ func Test_GetDeviceAuthorizationFlow(t *testing.T) {
|
||||
}, nil
|
||||
}
|
||||
|
||||
flowInfo, err := client.GetDeviceAuthorizationFlow(serverKey)
|
||||
flowInfo, err := client.GetDeviceAuthorizationFlow()
|
||||
if err != nil {
|
||||
t.Error("error while retrieving device auth flow information")
|
||||
}
|
||||
@@ -551,7 +529,7 @@ func Test_GetPKCEAuthorizationFlow(t *testing.T) {
|
||||
}
|
||||
|
||||
mgmtMockServer.GetPKCEAuthorizationFlowFunc = func(ctx context.Context, req *mgmtProto.EncryptedMessage) (*mgmtProto.EncryptedMessage, error) {
|
||||
encryptedResp, err := encryption.EncryptMessage(serverKey, client.key, expectedFlowInfo)
|
||||
encryptedResp, err := encryption.EncryptMessage(client.key.PublicKey(), serverKey, expectedFlowInfo)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -563,7 +541,7 @@ func Test_GetPKCEAuthorizationFlow(t *testing.T) {
|
||||
}, nil
|
||||
}
|
||||
|
||||
flowInfo, err := client.GetPKCEAuthorizationFlow(serverKey)
|
||||
flowInfo, err := client.GetPKCEAuthorizationFlow()
|
||||
if err != nil {
|
||||
t.Error("error while retrieving pkce auth flow information")
|
||||
}
|
||||
|
||||
@@ -202,7 +202,7 @@ func (c *GrpcClient) withMgmtStream(
|
||||
return fmt.Errorf("connection to management is not ready and in %s state", connState)
|
||||
}
|
||||
|
||||
serverPubKey, err := c.GetServerPublicKey()
|
||||
serverPubKey, err := c.getServerPublicKey()
|
||||
if err != nil {
|
||||
log.Debugf(errMsgMgmtPublicKey, err)
|
||||
return err
|
||||
@@ -404,7 +404,7 @@ func (c *GrpcClient) handleSyncStream(ctx context.Context, serverPubKey wgtypes.
|
||||
|
||||
// GetNetworkMap return with the network map
|
||||
func (c *GrpcClient) GetNetworkMap(sysInfo *system.Info) (*proto.NetworkMap, error) {
|
||||
serverPubKey, err := c.GetServerPublicKey()
|
||||
serverPubKey, err := c.getServerPublicKey()
|
||||
if err != nil {
|
||||
log.Debugf("failed getting Management Service public key: %s", err)
|
||||
return nil, err
|
||||
@@ -490,18 +490,24 @@ func (c *GrpcClient) receiveUpdatesEvents(stream proto.ManagementService_SyncCli
|
||||
}
|
||||
}
|
||||
|
||||
// GetServerPublicKey returns server's WireGuard public key (used later for encrypting messages sent to the server)
|
||||
func (c *GrpcClient) GetServerPublicKey() (*wgtypes.Key, error) {
|
||||
// HealthCheck actively probes the management server and returns an error if unreachable.
|
||||
// Used to validate connectivity before committing configuration changes.
|
||||
func (c *GrpcClient) HealthCheck() error {
|
||||
if !c.ready() {
|
||||
return nil, errors.New(errMsgNoMgmtConnection)
|
||||
return errors.New(errMsgNoMgmtConnection)
|
||||
}
|
||||
|
||||
_, err := c.getServerPublicKey()
|
||||
return err
|
||||
}
|
||||
|
||||
// getServerPublicKey fetches the server's WireGuard public key.
|
||||
func (c *GrpcClient) getServerPublicKey() (*wgtypes.Key, error) {
|
||||
mgmCtx, cancel := context.WithTimeout(c.ctx, 5*time.Second)
|
||||
defer cancel()
|
||||
resp, err := c.realClient.GetServerKey(mgmCtx, &proto.Empty{})
|
||||
if err != nil {
|
||||
log.Errorf("failed while getting Management Service public key: %v", err)
|
||||
return nil, fmt.Errorf("failed while getting Management Service public key")
|
||||
return nil, fmt.Errorf("failed getting Management Service public key: %w", err)
|
||||
}
|
||||
|
||||
serverKey, err := wgtypes.ParseKey(resp.Key)
|
||||
@@ -512,7 +518,8 @@ func (c *GrpcClient) GetServerPublicKey() (*wgtypes.Key, error) {
|
||||
return &serverKey, nil
|
||||
}
|
||||
|
||||
// IsHealthy probes the gRPC connection and returns false on errors
|
||||
// IsHealthy returns the current connection status without blocking.
|
||||
// Used by the engine to monitor connectivity in the background.
|
||||
func (c *GrpcClient) IsHealthy() bool {
|
||||
switch c.conn.GetState() {
|
||||
case connectivity.TransientFailure:
|
||||
@@ -538,12 +545,17 @@ func (c *GrpcClient) IsHealthy() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (c *GrpcClient) login(serverKey wgtypes.Key, req *proto.LoginRequest) (*proto.LoginResponse, error) {
|
||||
func (c *GrpcClient) login(req *proto.LoginRequest) (*proto.LoginResponse, error) {
|
||||
if !c.ready() {
|
||||
return nil, errors.New(errMsgNoMgmtConnection)
|
||||
}
|
||||
|
||||
loginReq, err := encryption.EncryptMessage(serverKey, c.key, req)
|
||||
serverKey, err := c.getServerPublicKey()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
loginReq, err := encryption.EncryptMessage(*serverKey, c.key, req)
|
||||
if err != nil {
|
||||
log.Errorf("failed to encrypt message: %s", err)
|
||||
return nil, err
|
||||
@@ -577,7 +589,7 @@ func (c *GrpcClient) login(serverKey wgtypes.Key, req *proto.LoginRequest) (*pro
|
||||
}
|
||||
|
||||
loginResp := &proto.LoginResponse{}
|
||||
err = encryption.DecryptMessage(serverKey, c.key, resp.Body, loginResp)
|
||||
err = encryption.DecryptMessage(*serverKey, c.key, resp.Body, loginResp)
|
||||
if err != nil {
|
||||
log.Errorf("failed to decrypt login response: %s", err)
|
||||
return nil, err
|
||||
@@ -589,34 +601,40 @@ func (c *GrpcClient) login(serverKey wgtypes.Key, req *proto.LoginRequest) (*pro
|
||||
// Register registers peer on Management Server. It actually calls a Login endpoint with a provided setup key
|
||||
// Takes care of encrypting and decrypting messages.
|
||||
// This method will also collect system info and send it with the request (e.g. hostname, os, etc)
|
||||
func (c *GrpcClient) Register(serverKey wgtypes.Key, setupKey string, jwtToken string, sysInfo *system.Info, pubSSHKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error) {
|
||||
func (c *GrpcClient) Register(setupKey string, jwtToken string, sysInfo *system.Info, pubSSHKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error) {
|
||||
keys := &proto.PeerKeys{
|
||||
SshPubKey: pubSSHKey,
|
||||
WgPubKey: []byte(c.key.PublicKey().String()),
|
||||
}
|
||||
return c.login(serverKey, &proto.LoginRequest{SetupKey: setupKey, Meta: infoToMetaData(sysInfo), JwtToken: jwtToken, PeerKeys: keys, DnsLabels: dnsLabels.ToPunycodeList()})
|
||||
return c.login(&proto.LoginRequest{SetupKey: setupKey, Meta: infoToMetaData(sysInfo), JwtToken: jwtToken, PeerKeys: keys, DnsLabels: dnsLabels.ToPunycodeList()})
|
||||
}
|
||||
|
||||
// Login attempts login to Management Server. Takes care of encrypting and decrypting messages.
|
||||
func (c *GrpcClient) Login(serverKey wgtypes.Key, sysInfo *system.Info, pubSSHKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error) {
|
||||
func (c *GrpcClient) Login(sysInfo *system.Info, pubSSHKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error) {
|
||||
keys := &proto.PeerKeys{
|
||||
SshPubKey: pubSSHKey,
|
||||
WgPubKey: []byte(c.key.PublicKey().String()),
|
||||
}
|
||||
return c.login(serverKey, &proto.LoginRequest{Meta: infoToMetaData(sysInfo), PeerKeys: keys, DnsLabels: dnsLabels.ToPunycodeList()})
|
||||
return c.login(&proto.LoginRequest{Meta: infoToMetaData(sysInfo), PeerKeys: keys, DnsLabels: dnsLabels.ToPunycodeList()})
|
||||
}
|
||||
|
||||
// GetDeviceAuthorizationFlow returns a device authorization flow information.
|
||||
// It also takes care of encrypting and decrypting messages.
|
||||
func (c *GrpcClient) GetDeviceAuthorizationFlow(serverKey wgtypes.Key) (*proto.DeviceAuthorizationFlow, error) {
|
||||
func (c *GrpcClient) GetDeviceAuthorizationFlow() (*proto.DeviceAuthorizationFlow, error) {
|
||||
if !c.ready() {
|
||||
return nil, fmt.Errorf("no connection to management in order to get device authorization flow")
|
||||
}
|
||||
|
||||
serverKey, err := c.getServerPublicKey()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
mgmCtx, cancel := context.WithTimeout(c.ctx, time.Second*2)
|
||||
defer cancel()
|
||||
|
||||
message := &proto.DeviceAuthorizationFlowRequest{}
|
||||
encryptedMSG, err := encryption.EncryptMessage(serverKey, c.key, message)
|
||||
encryptedMSG, err := encryption.EncryptMessage(*serverKey, c.key, message)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -630,7 +648,7 @@ func (c *GrpcClient) GetDeviceAuthorizationFlow(serverKey wgtypes.Key) (*proto.D
|
||||
}
|
||||
|
||||
flowInfoResp := &proto.DeviceAuthorizationFlow{}
|
||||
err = encryption.DecryptMessage(serverKey, c.key, resp.Body, flowInfoResp)
|
||||
err = encryption.DecryptMessage(*serverKey, c.key, resp.Body, flowInfoResp)
|
||||
if err != nil {
|
||||
errWithMSG := fmt.Errorf("failed to decrypt device authorization flow message: %s", err)
|
||||
log.Error(errWithMSG)
|
||||
@@ -642,15 +660,21 @@ func (c *GrpcClient) GetDeviceAuthorizationFlow(serverKey wgtypes.Key) (*proto.D
|
||||
|
||||
// GetPKCEAuthorizationFlow returns a pkce authorization flow information.
|
||||
// It also takes care of encrypting and decrypting messages.
|
||||
func (c *GrpcClient) GetPKCEAuthorizationFlow(serverKey wgtypes.Key) (*proto.PKCEAuthorizationFlow, error) {
|
||||
func (c *GrpcClient) GetPKCEAuthorizationFlow() (*proto.PKCEAuthorizationFlow, error) {
|
||||
if !c.ready() {
|
||||
return nil, fmt.Errorf("no connection to management in order to get pkce authorization flow")
|
||||
}
|
||||
|
||||
serverKey, err := c.getServerPublicKey()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
mgmCtx, cancel := context.WithTimeout(c.ctx, time.Second*2)
|
||||
defer cancel()
|
||||
|
||||
message := &proto.PKCEAuthorizationFlowRequest{}
|
||||
encryptedMSG, err := encryption.EncryptMessage(serverKey, c.key, message)
|
||||
encryptedMSG, err := encryption.EncryptMessage(*serverKey, c.key, message)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -664,7 +688,7 @@ func (c *GrpcClient) GetPKCEAuthorizationFlow(serverKey wgtypes.Key) (*proto.PKC
|
||||
}
|
||||
|
||||
flowInfoResp := &proto.PKCEAuthorizationFlow{}
|
||||
err = encryption.DecryptMessage(serverKey, c.key, resp.Body, flowInfoResp)
|
||||
err = encryption.DecryptMessage(*serverKey, c.key, resp.Body, flowInfoResp)
|
||||
if err != nil {
|
||||
errWithMSG := fmt.Errorf("failed to decrypt pkce authorization flow message: %s", err)
|
||||
log.Error(errWithMSG)
|
||||
@@ -681,7 +705,7 @@ func (c *GrpcClient) SyncMeta(sysInfo *system.Info) error {
|
||||
return errors.New(errMsgNoMgmtConnection)
|
||||
}
|
||||
|
||||
serverPubKey, err := c.GetServerPublicKey()
|
||||
serverPubKey, err := c.getServerPublicKey()
|
||||
if err != nil {
|
||||
log.Debugf(errMsgMgmtPublicKey, err)
|
||||
return err
|
||||
@@ -724,7 +748,7 @@ func (c *GrpcClient) notifyConnected() {
|
||||
}
|
||||
|
||||
func (c *GrpcClient) Logout() error {
|
||||
serverKey, err := c.GetServerPublicKey()
|
||||
serverKey, err := c.getServerPublicKey()
|
||||
if err != nil {
|
||||
return fmt.Errorf("get server public key: %w", err)
|
||||
}
|
||||
@@ -751,7 +775,7 @@ func (c *GrpcClient) Logout() error {
|
||||
|
||||
// CreateExpose calls the management server to create a new expose service.
|
||||
func (c *GrpcClient) CreateExpose(ctx context.Context, req ExposeRequest) (*ExposeResponse, error) {
|
||||
serverPubKey, err := c.GetServerPublicKey()
|
||||
serverPubKey, err := c.getServerPublicKey()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -787,7 +811,7 @@ func (c *GrpcClient) CreateExpose(ctx context.Context, req ExposeRequest) (*Expo
|
||||
|
||||
// RenewExpose extends the TTL of an active expose session on the management server.
|
||||
func (c *GrpcClient) RenewExpose(ctx context.Context, domain string) error {
|
||||
serverPubKey, err := c.GetServerPublicKey()
|
||||
serverPubKey, err := c.getServerPublicKey()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -810,7 +834,7 @@ func (c *GrpcClient) RenewExpose(ctx context.Context, domain string) error {
|
||||
|
||||
// StopExpose terminates an active expose session on the management server.
|
||||
func (c *GrpcClient) StopExpose(ctx context.Context, domain string) error {
|
||||
serverPubKey, err := c.GetServerPublicKey()
|
||||
serverPubKey, err := c.getServerPublicKey()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -3,8 +3,6 @@ package client
|
||||
import (
|
||||
"context"
|
||||
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
@@ -14,12 +12,12 @@ import (
|
||||
type MockClient struct {
|
||||
CloseFunc func() error
|
||||
SyncFunc func(ctx context.Context, sysInfo *system.Info, msgHandler func(msg *proto.SyncResponse) error) error
|
||||
GetServerPublicKeyFunc func() (*wgtypes.Key, error)
|
||||
RegisterFunc func(serverKey wgtypes.Key, setupKey string, jwtToken string, info *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error)
|
||||
LoginFunc func(serverKey wgtypes.Key, info *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error)
|
||||
GetDeviceAuthorizationFlowFunc func(serverKey wgtypes.Key) (*proto.DeviceAuthorizationFlow, error)
|
||||
GetPKCEAuthorizationFlowFunc func(serverKey wgtypes.Key) (*proto.PKCEAuthorizationFlow, error)
|
||||
RegisterFunc func(setupKey string, jwtToken string, info *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error)
|
||||
LoginFunc func(info *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error)
|
||||
GetDeviceAuthorizationFlowFunc func() (*proto.DeviceAuthorizationFlow, error)
|
||||
GetPKCEAuthorizationFlowFunc func() (*proto.PKCEAuthorizationFlow, error)
|
||||
GetServerURLFunc func() string
|
||||
HealthCheckFunc func() error
|
||||
SyncMetaFunc func(sysInfo *system.Info) error
|
||||
LogoutFunc func() error
|
||||
JobFunc func(ctx context.Context, msgHandler func(msg *proto.JobRequest) *proto.JobResponse) error
|
||||
@@ -53,39 +51,39 @@ func (m *MockClient) Job(ctx context.Context, msgHandler func(msg *proto.JobRequ
|
||||
return m.JobFunc(ctx, msgHandler)
|
||||
}
|
||||
|
||||
func (m *MockClient) GetServerPublicKey() (*wgtypes.Key, error) {
|
||||
if m.GetServerPublicKeyFunc == nil {
|
||||
return nil, nil
|
||||
}
|
||||
return m.GetServerPublicKeyFunc()
|
||||
}
|
||||
|
||||
func (m *MockClient) Register(serverKey wgtypes.Key, setupKey string, jwtToken string, info *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error) {
|
||||
func (m *MockClient) Register(setupKey string, jwtToken string, info *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error) {
|
||||
if m.RegisterFunc == nil {
|
||||
return nil, nil
|
||||
}
|
||||
return m.RegisterFunc(serverKey, setupKey, jwtToken, info, sshKey, dnsLabels)
|
||||
return m.RegisterFunc(setupKey, jwtToken, info, sshKey, dnsLabels)
|
||||
}
|
||||
|
||||
func (m *MockClient) Login(serverKey wgtypes.Key, info *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error) {
|
||||
func (m *MockClient) Login(info *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error) {
|
||||
if m.LoginFunc == nil {
|
||||
return nil, nil
|
||||
}
|
||||
return m.LoginFunc(serverKey, info, sshKey, dnsLabels)
|
||||
return m.LoginFunc(info, sshKey, dnsLabels)
|
||||
}
|
||||
|
||||
func (m *MockClient) GetDeviceAuthorizationFlow(serverKey wgtypes.Key) (*proto.DeviceAuthorizationFlow, error) {
|
||||
func (m *MockClient) GetDeviceAuthorizationFlow() (*proto.DeviceAuthorizationFlow, error) {
|
||||
if m.GetDeviceAuthorizationFlowFunc == nil {
|
||||
return nil, nil
|
||||
}
|
||||
return m.GetDeviceAuthorizationFlowFunc(serverKey)
|
||||
return m.GetDeviceAuthorizationFlowFunc()
|
||||
}
|
||||
|
||||
func (m *MockClient) GetPKCEAuthorizationFlow(serverKey wgtypes.Key) (*proto.PKCEAuthorizationFlow, error) {
|
||||
func (m *MockClient) GetPKCEAuthorizationFlow() (*proto.PKCEAuthorizationFlow, error) {
|
||||
if m.GetPKCEAuthorizationFlowFunc == nil {
|
||||
return nil, nil
|
||||
}
|
||||
return m.GetPKCEAuthorizationFlowFunc(serverKey)
|
||||
return m.GetPKCEAuthorizationFlowFunc()
|
||||
}
|
||||
|
||||
func (m *MockClient) HealthCheck() error {
|
||||
if m.HealthCheckFunc == nil {
|
||||
return nil
|
||||
}
|
||||
return m.HealthCheckFunc()
|
||||
}
|
||||
|
||||
// GetNetworkMap mock implementation of GetNetworkMap from Client interface.
|
||||
|
||||
Reference in New Issue
Block a user