mirror of
https://github.com/netbirdio/netbird.git
synced 2026-05-12 19:59:56 +00:00
Compare commits
11 Commits
userspace-
...
v0.36.7
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c4a6dafd27 | ||
|
|
a930c2aecf | ||
|
|
d48edb9837 | ||
|
|
b41de7fcd1 | ||
|
|
18f84f0df5 | ||
|
|
44407a158a | ||
|
|
488b697479 | ||
|
|
5953b43ead | ||
|
|
58b2eb4b92 | ||
|
|
05415f72ec | ||
|
|
b7af53ea40 |
@@ -9,6 +9,7 @@ USER netbird:netbird
|
|||||||
|
|
||||||
ENV NB_FOREGROUND_MODE=true
|
ENV NB_FOREGROUND_MODE=true
|
||||||
ENV NB_USE_NETSTACK_MODE=true
|
ENV NB_USE_NETSTACK_MODE=true
|
||||||
|
ENV NB_ENABLE_NETSTACK_LOCAL_FORWARDING=true
|
||||||
ENV NB_CONFIG=config.json
|
ENV NB_CONFIG=config.json
|
||||||
ENV NB_DAEMON_ADDR=unix://netbird.sock
|
ENV NB_DAEMON_ADDR=unix://netbird.sock
|
||||||
ENV NB_DISABLE_DNS=true
|
ENV NB_DISABLE_DNS=true
|
||||||
|
|||||||
137
client/cmd/trace.go
Normal file
137
client/cmd/trace.go
Normal file
@@ -0,0 +1,137 @@
|
|||||||
|
package cmd
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"math/rand"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/spf13/cobra"
|
||||||
|
"google.golang.org/grpc/status"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/proto"
|
||||||
|
)
|
||||||
|
|
||||||
|
var traceCmd = &cobra.Command{
|
||||||
|
Use: "trace <direction> <source-ip> <dest-ip>",
|
||||||
|
Short: "Trace a packet through the firewall",
|
||||||
|
Example: `
|
||||||
|
netbird debug trace in 192.168.1.10 10.10.0.2 -p tcp --sport 12345 --dport 443 --syn --ack
|
||||||
|
netbird debug trace out 10.10.0.1 8.8.8.8 -p udp --dport 53
|
||||||
|
netbird debug trace in 10.10.0.2 10.10.0.1 -p icmp --type 8 --code 0
|
||||||
|
netbird debug trace in 100.64.1.1 self -p tcp --dport 80`,
|
||||||
|
Args: cobra.ExactArgs(3),
|
||||||
|
RunE: tracePacket,
|
||||||
|
}
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
debugCmd.AddCommand(traceCmd)
|
||||||
|
|
||||||
|
traceCmd.Flags().StringP("protocol", "p", "tcp", "Protocol (tcp/udp/icmp)")
|
||||||
|
traceCmd.Flags().Uint16("sport", 0, "Source port")
|
||||||
|
traceCmd.Flags().Uint16("dport", 0, "Destination port")
|
||||||
|
traceCmd.Flags().Uint8("icmp-type", 0, "ICMP type")
|
||||||
|
traceCmd.Flags().Uint8("icmp-code", 0, "ICMP code")
|
||||||
|
traceCmd.Flags().Bool("syn", false, "TCP SYN flag")
|
||||||
|
traceCmd.Flags().Bool("ack", false, "TCP ACK flag")
|
||||||
|
traceCmd.Flags().Bool("fin", false, "TCP FIN flag")
|
||||||
|
traceCmd.Flags().Bool("rst", false, "TCP RST flag")
|
||||||
|
traceCmd.Flags().Bool("psh", false, "TCP PSH flag")
|
||||||
|
traceCmd.Flags().Bool("urg", false, "TCP URG flag")
|
||||||
|
}
|
||||||
|
|
||||||
|
func tracePacket(cmd *cobra.Command, args []string) error {
|
||||||
|
direction := strings.ToLower(args[0])
|
||||||
|
if direction != "in" && direction != "out" {
|
||||||
|
return fmt.Errorf("invalid direction: use 'in' or 'out'")
|
||||||
|
}
|
||||||
|
|
||||||
|
protocol := cmd.Flag("protocol").Value.String()
|
||||||
|
if protocol != "tcp" && protocol != "udp" && protocol != "icmp" {
|
||||||
|
return fmt.Errorf("invalid protocol: use tcp/udp/icmp")
|
||||||
|
}
|
||||||
|
|
||||||
|
sport, err := cmd.Flags().GetUint16("sport")
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("invalid source port: %v", err)
|
||||||
|
}
|
||||||
|
dport, err := cmd.Flags().GetUint16("dport")
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("invalid destination port: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// For TCP/UDP, generate random ephemeral port (49152-65535) if not specified
|
||||||
|
if protocol != "icmp" {
|
||||||
|
if sport == 0 {
|
||||||
|
sport = uint16(rand.Intn(16383) + 49152)
|
||||||
|
}
|
||||||
|
if dport == 0 {
|
||||||
|
dport = uint16(rand.Intn(16383) + 49152)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var tcpFlags *proto.TCPFlags
|
||||||
|
if protocol == "tcp" {
|
||||||
|
syn, _ := cmd.Flags().GetBool("syn")
|
||||||
|
ack, _ := cmd.Flags().GetBool("ack")
|
||||||
|
fin, _ := cmd.Flags().GetBool("fin")
|
||||||
|
rst, _ := cmd.Flags().GetBool("rst")
|
||||||
|
psh, _ := cmd.Flags().GetBool("psh")
|
||||||
|
urg, _ := cmd.Flags().GetBool("urg")
|
||||||
|
|
||||||
|
tcpFlags = &proto.TCPFlags{
|
||||||
|
Syn: syn,
|
||||||
|
Ack: ack,
|
||||||
|
Fin: fin,
|
||||||
|
Rst: rst,
|
||||||
|
Psh: psh,
|
||||||
|
Urg: urg,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
icmpType, _ := cmd.Flags().GetUint32("icmp-type")
|
||||||
|
icmpCode, _ := cmd.Flags().GetUint32("icmp-code")
|
||||||
|
|
||||||
|
conn, err := getClient(cmd)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
client := proto.NewDaemonServiceClient(conn)
|
||||||
|
resp, err := client.TracePacket(cmd.Context(), &proto.TracePacketRequest{
|
||||||
|
SourceIp: args[1],
|
||||||
|
DestinationIp: args[2],
|
||||||
|
Protocol: protocol,
|
||||||
|
SourcePort: uint32(sport),
|
||||||
|
DestinationPort: uint32(dport),
|
||||||
|
Direction: direction,
|
||||||
|
TcpFlags: tcpFlags,
|
||||||
|
IcmpType: &icmpType,
|
||||||
|
IcmpCode: &icmpCode,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("trace failed: %v", status.Convert(err).Message())
|
||||||
|
}
|
||||||
|
|
||||||
|
printTrace(cmd, args[1], args[2], protocol, sport, dport, resp)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func printTrace(cmd *cobra.Command, src, dst, proto string, sport, dport uint16, resp *proto.TracePacketResponse) {
|
||||||
|
cmd.Printf("Packet trace %s:%d -> %s:%d (%s)\n\n", src, sport, dst, dport, strings.ToUpper(proto))
|
||||||
|
|
||||||
|
for _, stage := range resp.Stages {
|
||||||
|
if stage.ForwardingDetails != nil {
|
||||||
|
cmd.Printf("%s: %s [%s]\n", stage.Name, stage.Message, *stage.ForwardingDetails)
|
||||||
|
} else {
|
||||||
|
cmd.Printf("%s: %s\n", stage.Name, stage.Message)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
disposition := map[bool]string{
|
||||||
|
true: "\033[32mALLOWED\033[0m", // Green
|
||||||
|
false: "\033[31mDENIED\033[0m", // Red
|
||||||
|
}[resp.FinalDisposition]
|
||||||
|
|
||||||
|
cmd.Printf("\nFinal disposition: %s\n", disposition)
|
||||||
|
}
|
||||||
@@ -14,13 +14,13 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
// NewFirewall creates a firewall manager instance
|
// NewFirewall creates a firewall manager instance
|
||||||
func NewFirewall(iface IFaceMapper, _ *statemanager.Manager) (firewall.Manager, error) {
|
func NewFirewall(iface IFaceMapper, _ *statemanager.Manager, disableServerRoutes bool) (firewall.Manager, error) {
|
||||||
if !iface.IsUserspaceBind() {
|
if !iface.IsUserspaceBind() {
|
||||||
return nil, fmt.Errorf("not implemented for this OS: %s", runtime.GOOS)
|
return nil, fmt.Errorf("not implemented for this OS: %s", runtime.GOOS)
|
||||||
}
|
}
|
||||||
|
|
||||||
// use userspace packet filtering firewall
|
// use userspace packet filtering firewall
|
||||||
fm, err := uspfilter.Create(iface)
|
fm, err := uspfilter.Create(iface, disableServerRoutes)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -33,12 +33,12 @@ const SKIP_NFTABLES_ENV = "NB_SKIP_NFTABLES_CHECK"
|
|||||||
// FWType is the type for the firewall type
|
// FWType is the type for the firewall type
|
||||||
type FWType int
|
type FWType int
|
||||||
|
|
||||||
func NewFirewall(iface IFaceMapper, stateManager *statemanager.Manager) (firewall.Manager, error) {
|
func NewFirewall(iface IFaceMapper, stateManager *statemanager.Manager, disableServerRoutes bool) (firewall.Manager, error) {
|
||||||
// on the linux system we try to user nftables or iptables
|
// on the linux system we try to user nftables or iptables
|
||||||
// in any case, because we need to allow netbird interface traffic
|
// in any case, because we need to allow netbird interface traffic
|
||||||
// so we use AllowNetbird traffic from these firewall managers
|
// so we use AllowNetbird traffic from these firewall managers
|
||||||
// for the userspace packet filtering firewall
|
// for the userspace packet filtering firewall
|
||||||
fm, err := createNativeFirewall(iface, stateManager)
|
fm, err := createNativeFirewall(iface, stateManager, disableServerRoutes)
|
||||||
|
|
||||||
if !iface.IsUserspaceBind() {
|
if !iface.IsUserspaceBind() {
|
||||||
return fm, err
|
return fm, err
|
||||||
@@ -47,10 +47,10 @@ func NewFirewall(iface IFaceMapper, stateManager *statemanager.Manager) (firewal
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warnf("failed to create native firewall: %v. Proceeding with userspace", err)
|
log.Warnf("failed to create native firewall: %v. Proceeding with userspace", err)
|
||||||
}
|
}
|
||||||
return createUserspaceFirewall(iface, fm)
|
return createUserspaceFirewall(iface, fm, disableServerRoutes)
|
||||||
}
|
}
|
||||||
|
|
||||||
func createNativeFirewall(iface IFaceMapper, stateManager *statemanager.Manager) (firewall.Manager, error) {
|
func createNativeFirewall(iface IFaceMapper, stateManager *statemanager.Manager, routes bool) (firewall.Manager, error) {
|
||||||
fm, err := createFW(iface)
|
fm, err := createFW(iface)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("create firewall: %s", err)
|
return nil, fmt.Errorf("create firewall: %s", err)
|
||||||
@@ -77,12 +77,12 @@ func createFW(iface IFaceMapper) (firewall.Manager, error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func createUserspaceFirewall(iface IFaceMapper, fm firewall.Manager) (firewall.Manager, error) {
|
func createUserspaceFirewall(iface IFaceMapper, fm firewall.Manager, disableServerRoutes bool) (firewall.Manager, error) {
|
||||||
var errUsp error
|
var errUsp error
|
||||||
if fm != nil {
|
if fm != nil {
|
||||||
fm, errUsp = uspfilter.CreateWithNativeFirewall(iface, fm)
|
fm, errUsp = uspfilter.CreateWithNativeFirewall(iface, fm, disableServerRoutes)
|
||||||
} else {
|
} else {
|
||||||
fm, errUsp = uspfilter.Create(iface)
|
fm, errUsp = uspfilter.Create(iface, disableServerRoutes)
|
||||||
}
|
}
|
||||||
|
|
||||||
if errUsp != nil {
|
if errUsp != nil {
|
||||||
|
|||||||
@@ -1,6 +1,8 @@
|
|||||||
package firewall
|
package firewall
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
wgdevice "golang.zx2c4.com/wireguard/device"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface/device"
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -10,4 +12,6 @@ type IFaceMapper interface {
|
|||||||
Address() device.WGAddress
|
Address() device.WGAddress
|
||||||
IsUserspaceBind() bool
|
IsUserspaceBind() bool
|
||||||
SetFilter(device.PacketFilter) error
|
SetFilter(device.PacketFilter) error
|
||||||
|
GetDevice() *device.FilteredDevice
|
||||||
|
GetWGDevice() *wgdevice.Device
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -213,6 +213,19 @@ func (m *Manager) AllowNetbird() error {
|
|||||||
// Flush doesn't need to be implemented for this manager
|
// Flush doesn't need to be implemented for this manager
|
||||||
func (m *Manager) Flush() error { return nil }
|
func (m *Manager) Flush() error { return nil }
|
||||||
|
|
||||||
|
// SetLogLevel sets the log level for the firewall manager
|
||||||
|
func (m *Manager) SetLogLevel(log.Level) {
|
||||||
|
// not supported
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) EnableRouting() error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) DisableRouting() error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func getConntrackEstablished() []string {
|
func getConntrackEstablished() []string {
|
||||||
return []string{"-m", "conntrack", "--ctstate", "RELATED,ESTABLISHED", "-j", "ACCEPT"}
|
return []string{"-m", "conntrack", "--ctstate", "RELATED,ESTABLISHED", "-j", "ACCEPT"}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -135,7 +135,16 @@ func (r *router) AddRouteFiltering(
|
|||||||
}
|
}
|
||||||
|
|
||||||
rule := genRouteFilteringRuleSpec(params)
|
rule := genRouteFilteringRuleSpec(params)
|
||||||
if err := r.iptablesClient.Append(tableFilter, chainRTFWD, rule...); err != nil {
|
// Insert DROP rules at the beginning, append ACCEPT rules at the end
|
||||||
|
var err error
|
||||||
|
if action == firewall.ActionDrop {
|
||||||
|
// after the established rule
|
||||||
|
err = r.iptablesClient.Insert(tableFilter, chainRTFWD, 2, rule...)
|
||||||
|
} else {
|
||||||
|
err = r.iptablesClient.Append(tableFilter, chainRTFWD, rule...)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
return nil, fmt.Errorf("add route rule: %v", err)
|
return nil, fmt.Errorf("add route rule: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -99,6 +99,12 @@ type Manager interface {
|
|||||||
|
|
||||||
// Flush the changes to firewall controller
|
// Flush the changes to firewall controller
|
||||||
Flush() error
|
Flush() error
|
||||||
|
|
||||||
|
SetLogLevel(log.Level)
|
||||||
|
|
||||||
|
EnableRouting() error
|
||||||
|
|
||||||
|
DisableRouting() error
|
||||||
}
|
}
|
||||||
|
|
||||||
func GenKey(format string, pair RouterPair) string {
|
func GenKey(format string, pair RouterPair) string {
|
||||||
|
|||||||
@@ -318,6 +318,19 @@ func (m *Manager) cleanupNetbirdTables() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetLogLevel sets the log level for the firewall manager
|
||||||
|
func (m *Manager) SetLogLevel(log.Level) {
|
||||||
|
// not supported
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) EnableRouting() error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) DisableRouting() error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// Flush rule/chain/set operations from the buffer
|
// Flush rule/chain/set operations from the buffer
|
||||||
//
|
//
|
||||||
// Method also get all rules after flush and refreshes handle values in the rulesets
|
// Method also get all rules after flush and refreshes handle values in the rulesets
|
||||||
|
|||||||
@@ -107,7 +107,7 @@ func TestNftablesManager(t *testing.T) {
|
|||||||
Kind: expr.VerdictAccept,
|
Kind: expr.VerdictAccept,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
require.ElementsMatch(t, rules[0].Exprs, expectedExprs1, "expected the same expressions")
|
compareExprsIgnoringCounters(t, rules[0].Exprs, expectedExprs1)
|
||||||
|
|
||||||
ipToAdd, _ := netip.AddrFromSlice(ip)
|
ipToAdd, _ := netip.AddrFromSlice(ip)
|
||||||
add := ipToAdd.Unmap()
|
add := ipToAdd.Unmap()
|
||||||
@@ -307,3 +307,18 @@ func TestNftablesManagerCompatibilityWithIptables(t *testing.T) {
|
|||||||
stdout, stderr = runIptablesSave(t)
|
stdout, stderr = runIptablesSave(t)
|
||||||
verifyIptablesOutput(t, stdout, stderr)
|
verifyIptablesOutput(t, stdout, stderr)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func compareExprsIgnoringCounters(t *testing.T, got, want []expr.Any) {
|
||||||
|
t.Helper()
|
||||||
|
require.Equal(t, len(got), len(want), "expression count mismatch")
|
||||||
|
|
||||||
|
for i := range got {
|
||||||
|
if _, isCounter := got[i].(*expr.Counter); isCounter {
|
||||||
|
_, wantIsCounter := want[i].(*expr.Counter)
|
||||||
|
require.True(t, wantIsCounter, "expected Counter at index %d", i)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
require.Equal(t, got[i], want[i], "expression mismatch at index %d", i)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -233,7 +233,13 @@ func (r *router) AddRouteFiltering(
|
|||||||
UserData: []byte(ruleKey),
|
UserData: []byte(ruleKey),
|
||||||
}
|
}
|
||||||
|
|
||||||
rule = r.conn.AddRule(rule)
|
// Insert DROP rules at the beginning, append ACCEPT rules at the end
|
||||||
|
if action == firewall.ActionDrop {
|
||||||
|
// TODO: Insert after the established rule
|
||||||
|
rule = r.conn.InsertRule(rule)
|
||||||
|
} else {
|
||||||
|
rule = r.conn.AddRule(rule)
|
||||||
|
}
|
||||||
|
|
||||||
log.Tracef("Adding route rule %s", spew.Sdump(rule))
|
log.Tracef("Adding route rule %s", spew.Sdump(rule))
|
||||||
if err := r.conn.Flush(); err != nil {
|
if err := r.conn.Flush(); err != nil {
|
||||||
|
|||||||
@@ -3,6 +3,11 @@
|
|||||||
package uspfilter
|
package uspfilter
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
|
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
|
||||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
)
|
)
|
||||||
@@ -17,17 +22,29 @@ func (m *Manager) Reset(stateManager *statemanager.Manager) error {
|
|||||||
|
|
||||||
if m.udpTracker != nil {
|
if m.udpTracker != nil {
|
||||||
m.udpTracker.Close()
|
m.udpTracker.Close()
|
||||||
m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout)
|
m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout, m.logger)
|
||||||
}
|
}
|
||||||
|
|
||||||
if m.icmpTracker != nil {
|
if m.icmpTracker != nil {
|
||||||
m.icmpTracker.Close()
|
m.icmpTracker.Close()
|
||||||
m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout)
|
m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout, m.logger)
|
||||||
}
|
}
|
||||||
|
|
||||||
if m.tcpTracker != nil {
|
if m.tcpTracker != nil {
|
||||||
m.tcpTracker.Close()
|
m.tcpTracker.Close()
|
||||||
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout)
|
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout, m.logger)
|
||||||
|
}
|
||||||
|
|
||||||
|
if m.forwarder != nil {
|
||||||
|
m.forwarder.Stop()
|
||||||
|
}
|
||||||
|
|
||||||
|
if m.logger != nil {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
if err := m.logger.Stop(ctx); err != nil {
|
||||||
|
log.Errorf("failed to shutdown logger: %v", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if m.nativeFirewall != nil {
|
if m.nativeFirewall != nil {
|
||||||
|
|||||||
@@ -1,9 +1,11 @@
|
|||||||
package uspfilter
|
package uspfilter
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"os/exec"
|
"os/exec"
|
||||||
"syscall"
|
"syscall"
|
||||||
|
"time"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
@@ -29,17 +31,29 @@ func (m *Manager) Reset(*statemanager.Manager) error {
|
|||||||
|
|
||||||
if m.udpTracker != nil {
|
if m.udpTracker != nil {
|
||||||
m.udpTracker.Close()
|
m.udpTracker.Close()
|
||||||
m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout)
|
m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout, m.logger)
|
||||||
}
|
}
|
||||||
|
|
||||||
if m.icmpTracker != nil {
|
if m.icmpTracker != nil {
|
||||||
m.icmpTracker.Close()
|
m.icmpTracker.Close()
|
||||||
m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout)
|
m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout, m.logger)
|
||||||
}
|
}
|
||||||
|
|
||||||
if m.tcpTracker != nil {
|
if m.tcpTracker != nil {
|
||||||
m.tcpTracker.Close()
|
m.tcpTracker.Close()
|
||||||
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout)
|
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout, m.logger)
|
||||||
|
}
|
||||||
|
|
||||||
|
if m.forwarder != nil {
|
||||||
|
m.forwarder.Stop()
|
||||||
|
}
|
||||||
|
|
||||||
|
if m.logger != nil {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
if err := m.logger.Stop(ctx); err != nil {
|
||||||
|
log.Errorf("failed to shutdown logger: %v", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if !isWindowsFirewallReachable() {
|
if !isWindowsFirewallReachable() {
|
||||||
|
|||||||
16
client/firewall/uspfilter/common/iface.go
Normal file
16
client/firewall/uspfilter/common/iface.go
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
package common
|
||||||
|
|
||||||
|
import (
|
||||||
|
wgdevice "golang.zx2c4.com/wireguard/device"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/iface"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
|
)
|
||||||
|
|
||||||
|
// IFaceMapper defines subset methods of interface required for manager
|
||||||
|
type IFaceMapper interface {
|
||||||
|
SetFilter(device.PacketFilter) error
|
||||||
|
Address() iface.WGAddress
|
||||||
|
GetWGDevice() *wgdevice.Device
|
||||||
|
GetDevice() *device.FilteredDevice
|
||||||
|
}
|
||||||
@@ -10,12 +10,11 @@ import (
|
|||||||
|
|
||||||
// BaseConnTrack provides common fields and locking for all connection types
|
// BaseConnTrack provides common fields and locking for all connection types
|
||||||
type BaseConnTrack struct {
|
type BaseConnTrack struct {
|
||||||
SourceIP net.IP
|
SourceIP net.IP
|
||||||
DestIP net.IP
|
DestIP net.IP
|
||||||
SourcePort uint16
|
SourcePort uint16
|
||||||
DestPort uint16
|
DestPort uint16
|
||||||
lastSeen atomic.Int64 // Unix nano for atomic access
|
lastSeen atomic.Int64 // Unix nano for atomic access
|
||||||
established atomic.Bool
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// these small methods will be inlined by the compiler
|
// these small methods will be inlined by the compiler
|
||||||
@@ -25,16 +24,6 @@ func (b *BaseConnTrack) UpdateLastSeen() {
|
|||||||
b.lastSeen.Store(time.Now().UnixNano())
|
b.lastSeen.Store(time.Now().UnixNano())
|
||||||
}
|
}
|
||||||
|
|
||||||
// IsEstablished safely checks if connection is established
|
|
||||||
func (b *BaseConnTrack) IsEstablished() bool {
|
|
||||||
return b.established.Load()
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetEstablished safely sets the established state
|
|
||||||
func (b *BaseConnTrack) SetEstablished(state bool) {
|
|
||||||
b.established.Store(state)
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetLastSeen safely gets the last seen timestamp
|
// GetLastSeen safely gets the last seen timestamp
|
||||||
func (b *BaseConnTrack) GetLastSeen() time.Time {
|
func (b *BaseConnTrack) GetLastSeen() time.Time {
|
||||||
return time.Unix(0, b.lastSeen.Load())
|
return time.Unix(0, b.lastSeen.Load())
|
||||||
|
|||||||
@@ -3,8 +3,14 @@ package conntrack
|
|||||||
import (
|
import (
|
||||||
"net"
|
"net"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var logger = log.NewFromLogrus(logrus.StandardLogger())
|
||||||
|
|
||||||
func BenchmarkIPOperations(b *testing.B) {
|
func BenchmarkIPOperations(b *testing.B) {
|
||||||
b.Run("MakeIPAddr", func(b *testing.B) {
|
b.Run("MakeIPAddr", func(b *testing.B) {
|
||||||
ip := net.ParseIP("192.168.1.1")
|
ip := net.ParseIP("192.168.1.1")
|
||||||
@@ -34,37 +40,11 @@ func BenchmarkIPOperations(b *testing.B) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
}
|
}
|
||||||
func BenchmarkAtomicOperations(b *testing.B) {
|
|
||||||
conn := &BaseConnTrack{}
|
|
||||||
b.Run("UpdateLastSeen", func(b *testing.B) {
|
|
||||||
for i := 0; i < b.N; i++ {
|
|
||||||
conn.UpdateLastSeen()
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
b.Run("IsEstablished", func(b *testing.B) {
|
|
||||||
for i := 0; i < b.N; i++ {
|
|
||||||
_ = conn.IsEstablished()
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
b.Run("SetEstablished", func(b *testing.B) {
|
|
||||||
for i := 0; i < b.N; i++ {
|
|
||||||
conn.SetEstablished(i%2 == 0)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
b.Run("GetLastSeen", func(b *testing.B) {
|
|
||||||
for i := 0; i < b.N; i++ {
|
|
||||||
_ = conn.GetLastSeen()
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// Memory pressure tests
|
// Memory pressure tests
|
||||||
func BenchmarkMemoryPressure(b *testing.B) {
|
func BenchmarkMemoryPressure(b *testing.B) {
|
||||||
b.Run("TCPHighLoad", func(b *testing.B) {
|
b.Run("TCPHighLoad", func(b *testing.B) {
|
||||||
tracker := NewTCPTracker(DefaultTCPTimeout)
|
tracker := NewTCPTracker(DefaultTCPTimeout, logger)
|
||||||
defer tracker.Close()
|
defer tracker.Close()
|
||||||
|
|
||||||
// Generate different IPs
|
// Generate different IPs
|
||||||
@@ -89,7 +69,7 @@ func BenchmarkMemoryPressure(b *testing.B) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
b.Run("UDPHighLoad", func(b *testing.B) {
|
b.Run("UDPHighLoad", func(b *testing.B) {
|
||||||
tracker := NewUDPTracker(DefaultUDPTimeout)
|
tracker := NewUDPTracker(DefaultUDPTimeout, logger)
|
||||||
defer tracker.Close()
|
defer tracker.Close()
|
||||||
|
|
||||||
// Generate different IPs
|
// Generate different IPs
|
||||||
|
|||||||
@@ -6,6 +6,8 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/google/gopacket/layers"
|
"github.com/google/gopacket/layers"
|
||||||
|
|
||||||
|
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -33,6 +35,7 @@ type ICMPConnTrack struct {
|
|||||||
|
|
||||||
// ICMPTracker manages ICMP connection states
|
// ICMPTracker manages ICMP connection states
|
||||||
type ICMPTracker struct {
|
type ICMPTracker struct {
|
||||||
|
logger *nblog.Logger
|
||||||
connections map[ICMPConnKey]*ICMPConnTrack
|
connections map[ICMPConnKey]*ICMPConnTrack
|
||||||
timeout time.Duration
|
timeout time.Duration
|
||||||
cleanupTicker *time.Ticker
|
cleanupTicker *time.Ticker
|
||||||
@@ -42,12 +45,13 @@ type ICMPTracker struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// NewICMPTracker creates a new ICMP connection tracker
|
// NewICMPTracker creates a new ICMP connection tracker
|
||||||
func NewICMPTracker(timeout time.Duration) *ICMPTracker {
|
func NewICMPTracker(timeout time.Duration, logger *nblog.Logger) *ICMPTracker {
|
||||||
if timeout == 0 {
|
if timeout == 0 {
|
||||||
timeout = DefaultICMPTimeout
|
timeout = DefaultICMPTimeout
|
||||||
}
|
}
|
||||||
|
|
||||||
tracker := &ICMPTracker{
|
tracker := &ICMPTracker{
|
||||||
|
logger: logger,
|
||||||
connections: make(map[ICMPConnKey]*ICMPConnTrack),
|
connections: make(map[ICMPConnKey]*ICMPConnTrack),
|
||||||
timeout: timeout,
|
timeout: timeout,
|
||||||
cleanupTicker: time.NewTicker(ICMPCleanupInterval),
|
cleanupTicker: time.NewTicker(ICMPCleanupInterval),
|
||||||
@@ -62,7 +66,6 @@ func NewICMPTracker(timeout time.Duration) *ICMPTracker {
|
|||||||
// TrackOutbound records an outbound ICMP Echo Request
|
// TrackOutbound records an outbound ICMP Echo Request
|
||||||
func (t *ICMPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, id uint16, seq uint16) {
|
func (t *ICMPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, id uint16, seq uint16) {
|
||||||
key := makeICMPKey(srcIP, dstIP, id, seq)
|
key := makeICMPKey(srcIP, dstIP, id, seq)
|
||||||
now := time.Now().UnixNano()
|
|
||||||
|
|
||||||
t.mutex.Lock()
|
t.mutex.Lock()
|
||||||
conn, exists := t.connections[key]
|
conn, exists := t.connections[key]
|
||||||
@@ -80,24 +83,19 @@ func (t *ICMPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, id uint16, seq u
|
|||||||
ID: id,
|
ID: id,
|
||||||
Sequence: seq,
|
Sequence: seq,
|
||||||
}
|
}
|
||||||
conn.lastSeen.Store(now)
|
conn.UpdateLastSeen()
|
||||||
conn.established.Store(true)
|
|
||||||
t.connections[key] = conn
|
t.connections[key] = conn
|
||||||
|
|
||||||
|
t.logger.Trace("New ICMP connection %v", key)
|
||||||
}
|
}
|
||||||
t.mutex.Unlock()
|
t.mutex.Unlock()
|
||||||
|
|
||||||
conn.lastSeen.Store(now)
|
conn.UpdateLastSeen()
|
||||||
}
|
}
|
||||||
|
|
||||||
// IsValidInbound checks if an inbound ICMP Echo Reply matches a tracked request
|
// IsValidInbound checks if an inbound ICMP Echo Reply matches a tracked request
|
||||||
func (t *ICMPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, id uint16, seq uint16, icmpType uint8) bool {
|
func (t *ICMPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, id uint16, seq uint16, icmpType uint8) bool {
|
||||||
switch icmpType {
|
if icmpType != uint8(layers.ICMPv4TypeEchoReply) {
|
||||||
case uint8(layers.ICMPv4TypeDestinationUnreachable),
|
|
||||||
uint8(layers.ICMPv4TypeTimeExceeded):
|
|
||||||
return true
|
|
||||||
case uint8(layers.ICMPv4TypeEchoReply):
|
|
||||||
// continue processing
|
|
||||||
default:
|
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -115,8 +113,7 @@ func (t *ICMPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, id uint16, seq
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
return conn.IsEstablished() &&
|
return ValidateIPs(MakeIPAddr(srcIP), conn.DestIP) &&
|
||||||
ValidateIPs(MakeIPAddr(srcIP), conn.DestIP) &&
|
|
||||||
ValidateIPs(MakeIPAddr(dstIP), conn.SourceIP) &&
|
ValidateIPs(MakeIPAddr(dstIP), conn.SourceIP) &&
|
||||||
conn.ID == id &&
|
conn.ID == id &&
|
||||||
conn.Sequence == seq
|
conn.Sequence == seq
|
||||||
@@ -141,6 +138,8 @@ func (t *ICMPTracker) cleanup() {
|
|||||||
t.ipPool.Put(conn.SourceIP)
|
t.ipPool.Put(conn.SourceIP)
|
||||||
t.ipPool.Put(conn.DestIP)
|
t.ipPool.Put(conn.DestIP)
|
||||||
delete(t.connections, key)
|
delete(t.connections, key)
|
||||||
|
|
||||||
|
t.logger.Debug("Removed ICMP connection %v (timeout)", key)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ import (
|
|||||||
|
|
||||||
func BenchmarkICMPTracker(b *testing.B) {
|
func BenchmarkICMPTracker(b *testing.B) {
|
||||||
b.Run("TrackOutbound", func(b *testing.B) {
|
b.Run("TrackOutbound", func(b *testing.B) {
|
||||||
tracker := NewICMPTracker(DefaultICMPTimeout)
|
tracker := NewICMPTracker(DefaultICMPTimeout, logger)
|
||||||
defer tracker.Close()
|
defer tracker.Close()
|
||||||
|
|
||||||
srcIP := net.ParseIP("192.168.1.1")
|
srcIP := net.ParseIP("192.168.1.1")
|
||||||
@@ -20,7 +20,7 @@ func BenchmarkICMPTracker(b *testing.B) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
b.Run("IsValidInbound", func(b *testing.B) {
|
b.Run("IsValidInbound", func(b *testing.B) {
|
||||||
tracker := NewICMPTracker(DefaultICMPTimeout)
|
tracker := NewICMPTracker(DefaultICMPTimeout, logger)
|
||||||
defer tracker.Close()
|
defer tracker.Close()
|
||||||
|
|
||||||
srcIP := net.ParseIP("192.168.1.1")
|
srcIP := net.ParseIP("192.168.1.1")
|
||||||
|
|||||||
@@ -5,7 +5,10 @@ package conntrack
|
|||||||
import (
|
import (
|
||||||
"net"
|
"net"
|
||||||
"sync"
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -61,12 +64,24 @@ type TCPConnKey struct {
|
|||||||
// TCPConnTrack represents a TCP connection state
|
// TCPConnTrack represents a TCP connection state
|
||||||
type TCPConnTrack struct {
|
type TCPConnTrack struct {
|
||||||
BaseConnTrack
|
BaseConnTrack
|
||||||
State TCPState
|
State TCPState
|
||||||
|
established atomic.Bool
|
||||||
sync.RWMutex
|
sync.RWMutex
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// IsEstablished safely checks if connection is established
|
||||||
|
func (t *TCPConnTrack) IsEstablished() bool {
|
||||||
|
return t.established.Load()
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetEstablished safely sets the established state
|
||||||
|
func (t *TCPConnTrack) SetEstablished(state bool) {
|
||||||
|
t.established.Store(state)
|
||||||
|
}
|
||||||
|
|
||||||
// TCPTracker manages TCP connection states
|
// TCPTracker manages TCP connection states
|
||||||
type TCPTracker struct {
|
type TCPTracker struct {
|
||||||
|
logger *nblog.Logger
|
||||||
connections map[ConnKey]*TCPConnTrack
|
connections map[ConnKey]*TCPConnTrack
|
||||||
mutex sync.RWMutex
|
mutex sync.RWMutex
|
||||||
cleanupTicker *time.Ticker
|
cleanupTicker *time.Ticker
|
||||||
@@ -76,8 +91,9 @@ type TCPTracker struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// NewTCPTracker creates a new TCP connection tracker
|
// NewTCPTracker creates a new TCP connection tracker
|
||||||
func NewTCPTracker(timeout time.Duration) *TCPTracker {
|
func NewTCPTracker(timeout time.Duration, logger *nblog.Logger) *TCPTracker {
|
||||||
tracker := &TCPTracker{
|
tracker := &TCPTracker{
|
||||||
|
logger: logger,
|
||||||
connections: make(map[ConnKey]*TCPConnTrack),
|
connections: make(map[ConnKey]*TCPConnTrack),
|
||||||
cleanupTicker: time.NewTicker(TCPCleanupInterval),
|
cleanupTicker: time.NewTicker(TCPCleanupInterval),
|
||||||
done: make(chan struct{}),
|
done: make(chan struct{}),
|
||||||
@@ -93,7 +109,6 @@ func NewTCPTracker(timeout time.Duration) *TCPTracker {
|
|||||||
func (t *TCPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16, flags uint8) {
|
func (t *TCPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16, flags uint8) {
|
||||||
// Create key before lock
|
// Create key before lock
|
||||||
key := makeConnKey(srcIP, dstIP, srcPort, dstPort)
|
key := makeConnKey(srcIP, dstIP, srcPort, dstPort)
|
||||||
now := time.Now().UnixNano()
|
|
||||||
|
|
||||||
t.mutex.Lock()
|
t.mutex.Lock()
|
||||||
conn, exists := t.connections[key]
|
conn, exists := t.connections[key]
|
||||||
@@ -113,9 +128,11 @@ func (t *TCPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, srcPort uint16, d
|
|||||||
},
|
},
|
||||||
State: TCPStateNew,
|
State: TCPStateNew,
|
||||||
}
|
}
|
||||||
conn.lastSeen.Store(now)
|
conn.UpdateLastSeen()
|
||||||
conn.established.Store(false)
|
conn.established.Store(false)
|
||||||
t.connections[key] = conn
|
t.connections[key] = conn
|
||||||
|
|
||||||
|
t.logger.Trace("New TCP connection: %s:%d -> %s:%d", srcIP, srcPort, dstIP, dstPort)
|
||||||
}
|
}
|
||||||
t.mutex.Unlock()
|
t.mutex.Unlock()
|
||||||
|
|
||||||
@@ -123,7 +140,7 @@ func (t *TCPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, srcPort uint16, d
|
|||||||
conn.Lock()
|
conn.Lock()
|
||||||
t.updateState(conn, flags, true)
|
t.updateState(conn, flags, true)
|
||||||
conn.Unlock()
|
conn.Unlock()
|
||||||
conn.lastSeen.Store(now)
|
conn.UpdateLastSeen()
|
||||||
}
|
}
|
||||||
|
|
||||||
// IsValidInbound checks if an inbound TCP packet matches a tracked connection
|
// IsValidInbound checks if an inbound TCP packet matches a tracked connection
|
||||||
@@ -171,6 +188,9 @@ func (t *TCPTracker) updateState(conn *TCPConnTrack, flags uint8, isOutbound boo
|
|||||||
if flags&TCPRst != 0 {
|
if flags&TCPRst != 0 {
|
||||||
conn.State = TCPStateClosed
|
conn.State = TCPStateClosed
|
||||||
conn.SetEstablished(false)
|
conn.SetEstablished(false)
|
||||||
|
|
||||||
|
t.logger.Trace("TCP connection reset: %s:%d -> %s:%d",
|
||||||
|
conn.SourceIP, conn.SourcePort, conn.DestIP, conn.DestPort)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -227,6 +247,9 @@ func (t *TCPTracker) updateState(conn *TCPConnTrack, flags uint8, isOutbound boo
|
|||||||
if flags&TCPAck != 0 {
|
if flags&TCPAck != 0 {
|
||||||
conn.State = TCPStateTimeWait
|
conn.State = TCPStateTimeWait
|
||||||
// Keep established = false from previous state
|
// Keep established = false from previous state
|
||||||
|
|
||||||
|
t.logger.Trace("TCP connection closed (simultaneous) - %s:%d -> %s:%d",
|
||||||
|
conn.SourceIP, conn.SourcePort, conn.DestIP, conn.DestPort)
|
||||||
}
|
}
|
||||||
|
|
||||||
case TCPStateCloseWait:
|
case TCPStateCloseWait:
|
||||||
@@ -237,11 +260,17 @@ func (t *TCPTracker) updateState(conn *TCPConnTrack, flags uint8, isOutbound boo
|
|||||||
case TCPStateLastAck:
|
case TCPStateLastAck:
|
||||||
if flags&TCPAck != 0 {
|
if flags&TCPAck != 0 {
|
||||||
conn.State = TCPStateClosed
|
conn.State = TCPStateClosed
|
||||||
|
|
||||||
|
t.logger.Trace("TCP connection gracefully closed: %s:%d -> %s:%d",
|
||||||
|
conn.SourceIP, conn.SourcePort, conn.DestIP, conn.DestPort)
|
||||||
}
|
}
|
||||||
|
|
||||||
case TCPStateTimeWait:
|
case TCPStateTimeWait:
|
||||||
// Stay in TIME-WAIT for 2MSL before transitioning to closed
|
// Stay in TIME-WAIT for 2MSL before transitioning to closed
|
||||||
// This is handled by the cleanup routine
|
// This is handled by the cleanup routine
|
||||||
|
|
||||||
|
t.logger.Trace("TCP connection completed - %s:%d -> %s:%d",
|
||||||
|
conn.SourceIP, conn.SourcePort, conn.DestIP, conn.DestPort)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -318,6 +347,8 @@ func (t *TCPTracker) cleanup() {
|
|||||||
t.ipPool.Put(conn.SourceIP)
|
t.ipPool.Put(conn.SourceIP)
|
||||||
t.ipPool.Put(conn.DestIP)
|
t.ipPool.Put(conn.DestIP)
|
||||||
delete(t.connections, key)
|
delete(t.connections, key)
|
||||||
|
|
||||||
|
t.logger.Trace("Cleaned up TCP connection: %s:%d -> %s:%d", conn.SourceIP, conn.SourcePort, conn.DestIP, conn.DestPort)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func TestTCPStateMachine(t *testing.T) {
|
func TestTCPStateMachine(t *testing.T) {
|
||||||
tracker := NewTCPTracker(DefaultTCPTimeout)
|
tracker := NewTCPTracker(DefaultTCPTimeout, logger)
|
||||||
defer tracker.Close()
|
defer tracker.Close()
|
||||||
|
|
||||||
srcIP := net.ParseIP("100.64.0.1")
|
srcIP := net.ParseIP("100.64.0.1")
|
||||||
@@ -154,7 +154,7 @@ func TestTCPStateMachine(t *testing.T) {
|
|||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
|
|
||||||
tracker = NewTCPTracker(DefaultTCPTimeout)
|
tracker = NewTCPTracker(DefaultTCPTimeout, logger)
|
||||||
tt.test(t)
|
tt.test(t)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -162,7 +162,7 @@ func TestTCPStateMachine(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestRSTHandling(t *testing.T) {
|
func TestRSTHandling(t *testing.T) {
|
||||||
tracker := NewTCPTracker(DefaultTCPTimeout)
|
tracker := NewTCPTracker(DefaultTCPTimeout, logger)
|
||||||
defer tracker.Close()
|
defer tracker.Close()
|
||||||
|
|
||||||
srcIP := net.ParseIP("100.64.0.1")
|
srcIP := net.ParseIP("100.64.0.1")
|
||||||
@@ -233,7 +233,7 @@ func establishConnection(t *testing.T, tracker *TCPTracker, srcIP, dstIP net.IP,
|
|||||||
|
|
||||||
func BenchmarkTCPTracker(b *testing.B) {
|
func BenchmarkTCPTracker(b *testing.B) {
|
||||||
b.Run("TrackOutbound", func(b *testing.B) {
|
b.Run("TrackOutbound", func(b *testing.B) {
|
||||||
tracker := NewTCPTracker(DefaultTCPTimeout)
|
tracker := NewTCPTracker(DefaultTCPTimeout, logger)
|
||||||
defer tracker.Close()
|
defer tracker.Close()
|
||||||
|
|
||||||
srcIP := net.ParseIP("192.168.1.1")
|
srcIP := net.ParseIP("192.168.1.1")
|
||||||
@@ -246,7 +246,7 @@ func BenchmarkTCPTracker(b *testing.B) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
b.Run("IsValidInbound", func(b *testing.B) {
|
b.Run("IsValidInbound", func(b *testing.B) {
|
||||||
tracker := NewTCPTracker(DefaultTCPTimeout)
|
tracker := NewTCPTracker(DefaultTCPTimeout, logger)
|
||||||
defer tracker.Close()
|
defer tracker.Close()
|
||||||
|
|
||||||
srcIP := net.ParseIP("192.168.1.1")
|
srcIP := net.ParseIP("192.168.1.1")
|
||||||
@@ -264,7 +264,7 @@ func BenchmarkTCPTracker(b *testing.B) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
b.Run("ConcurrentAccess", func(b *testing.B) {
|
b.Run("ConcurrentAccess", func(b *testing.B) {
|
||||||
tracker := NewTCPTracker(DefaultTCPTimeout)
|
tracker := NewTCPTracker(DefaultTCPTimeout, logger)
|
||||||
defer tracker.Close()
|
defer tracker.Close()
|
||||||
|
|
||||||
srcIP := net.ParseIP("192.168.1.1")
|
srcIP := net.ParseIP("192.168.1.1")
|
||||||
@@ -287,7 +287,7 @@ func BenchmarkTCPTracker(b *testing.B) {
|
|||||||
// Benchmark connection cleanup
|
// Benchmark connection cleanup
|
||||||
func BenchmarkCleanup(b *testing.B) {
|
func BenchmarkCleanup(b *testing.B) {
|
||||||
b.Run("TCPCleanup", func(b *testing.B) {
|
b.Run("TCPCleanup", func(b *testing.B) {
|
||||||
tracker := NewTCPTracker(100 * time.Millisecond) // Short timeout for testing
|
tracker := NewTCPTracker(100*time.Millisecond, logger) // Short timeout for testing
|
||||||
defer tracker.Close()
|
defer tracker.Close()
|
||||||
|
|
||||||
// Pre-populate with expired connections
|
// Pre-populate with expired connections
|
||||||
|
|||||||
@@ -4,6 +4,8 @@ import (
|
|||||||
"net"
|
"net"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -20,6 +22,7 @@ type UDPConnTrack struct {
|
|||||||
|
|
||||||
// UDPTracker manages UDP connection states
|
// UDPTracker manages UDP connection states
|
||||||
type UDPTracker struct {
|
type UDPTracker struct {
|
||||||
|
logger *nblog.Logger
|
||||||
connections map[ConnKey]*UDPConnTrack
|
connections map[ConnKey]*UDPConnTrack
|
||||||
timeout time.Duration
|
timeout time.Duration
|
||||||
cleanupTicker *time.Ticker
|
cleanupTicker *time.Ticker
|
||||||
@@ -29,12 +32,13 @@ type UDPTracker struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// NewUDPTracker creates a new UDP connection tracker
|
// NewUDPTracker creates a new UDP connection tracker
|
||||||
func NewUDPTracker(timeout time.Duration) *UDPTracker {
|
func NewUDPTracker(timeout time.Duration, logger *nblog.Logger) *UDPTracker {
|
||||||
if timeout == 0 {
|
if timeout == 0 {
|
||||||
timeout = DefaultUDPTimeout
|
timeout = DefaultUDPTimeout
|
||||||
}
|
}
|
||||||
|
|
||||||
tracker := &UDPTracker{
|
tracker := &UDPTracker{
|
||||||
|
logger: logger,
|
||||||
connections: make(map[ConnKey]*UDPConnTrack),
|
connections: make(map[ConnKey]*UDPConnTrack),
|
||||||
timeout: timeout,
|
timeout: timeout,
|
||||||
cleanupTicker: time.NewTicker(UDPCleanupInterval),
|
cleanupTicker: time.NewTicker(UDPCleanupInterval),
|
||||||
@@ -49,7 +53,6 @@ func NewUDPTracker(timeout time.Duration) *UDPTracker {
|
|||||||
// TrackOutbound records an outbound UDP connection
|
// TrackOutbound records an outbound UDP connection
|
||||||
func (t *UDPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16) {
|
func (t *UDPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16) {
|
||||||
key := makeConnKey(srcIP, dstIP, srcPort, dstPort)
|
key := makeConnKey(srcIP, dstIP, srcPort, dstPort)
|
||||||
now := time.Now().UnixNano()
|
|
||||||
|
|
||||||
t.mutex.Lock()
|
t.mutex.Lock()
|
||||||
conn, exists := t.connections[key]
|
conn, exists := t.connections[key]
|
||||||
@@ -67,13 +70,14 @@ func (t *UDPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, srcPort uint16, d
|
|||||||
DestPort: dstPort,
|
DestPort: dstPort,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
conn.lastSeen.Store(now)
|
conn.UpdateLastSeen()
|
||||||
conn.established.Store(true)
|
|
||||||
t.connections[key] = conn
|
t.connections[key] = conn
|
||||||
|
|
||||||
|
t.logger.Trace("New UDP connection: %v", conn)
|
||||||
}
|
}
|
||||||
t.mutex.Unlock()
|
t.mutex.Unlock()
|
||||||
|
|
||||||
conn.lastSeen.Store(now)
|
conn.UpdateLastSeen()
|
||||||
}
|
}
|
||||||
|
|
||||||
// IsValidInbound checks if an inbound packet matches a tracked connection
|
// IsValidInbound checks if an inbound packet matches a tracked connection
|
||||||
@@ -92,8 +96,7 @@ func (t *UDPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, srcPort uint16,
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
return conn.IsEstablished() &&
|
return ValidateIPs(MakeIPAddr(srcIP), conn.DestIP) &&
|
||||||
ValidateIPs(MakeIPAddr(srcIP), conn.DestIP) &&
|
|
||||||
ValidateIPs(MakeIPAddr(dstIP), conn.SourceIP) &&
|
ValidateIPs(MakeIPAddr(dstIP), conn.SourceIP) &&
|
||||||
conn.DestPort == srcPort &&
|
conn.DestPort == srcPort &&
|
||||||
conn.SourcePort == dstPort
|
conn.SourcePort == dstPort
|
||||||
@@ -120,6 +123,8 @@ func (t *UDPTracker) cleanup() {
|
|||||||
t.ipPool.Put(conn.SourceIP)
|
t.ipPool.Put(conn.SourceIP)
|
||||||
t.ipPool.Put(conn.DestIP)
|
t.ipPool.Put(conn.DestIP)
|
||||||
delete(t.connections, key)
|
delete(t.connections, key)
|
||||||
|
|
||||||
|
t.logger.Trace("Removed UDP connection %v (timeout)", conn)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -29,7 +29,7 @@ func TestNewUDPTracker(t *testing.T) {
|
|||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
tracker := NewUDPTracker(tt.timeout)
|
tracker := NewUDPTracker(tt.timeout, logger)
|
||||||
assert.NotNil(t, tracker)
|
assert.NotNil(t, tracker)
|
||||||
assert.Equal(t, tt.wantTimeout, tracker.timeout)
|
assert.Equal(t, tt.wantTimeout, tracker.timeout)
|
||||||
assert.NotNil(t, tracker.connections)
|
assert.NotNil(t, tracker.connections)
|
||||||
@@ -40,7 +40,7 @@ func TestNewUDPTracker(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestUDPTracker_TrackOutbound(t *testing.T) {
|
func TestUDPTracker_TrackOutbound(t *testing.T) {
|
||||||
tracker := NewUDPTracker(DefaultUDPTimeout)
|
tracker := NewUDPTracker(DefaultUDPTimeout, logger)
|
||||||
defer tracker.Close()
|
defer tracker.Close()
|
||||||
|
|
||||||
srcIP := net.ParseIP("192.168.1.2")
|
srcIP := net.ParseIP("192.168.1.2")
|
||||||
@@ -58,12 +58,11 @@ func TestUDPTracker_TrackOutbound(t *testing.T) {
|
|||||||
assert.True(t, conn.DestIP.Equal(dstIP))
|
assert.True(t, conn.DestIP.Equal(dstIP))
|
||||||
assert.Equal(t, srcPort, conn.SourcePort)
|
assert.Equal(t, srcPort, conn.SourcePort)
|
||||||
assert.Equal(t, dstPort, conn.DestPort)
|
assert.Equal(t, dstPort, conn.DestPort)
|
||||||
assert.True(t, conn.IsEstablished())
|
|
||||||
assert.WithinDuration(t, time.Now(), conn.GetLastSeen(), 1*time.Second)
|
assert.WithinDuration(t, time.Now(), conn.GetLastSeen(), 1*time.Second)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestUDPTracker_IsValidInbound(t *testing.T) {
|
func TestUDPTracker_IsValidInbound(t *testing.T) {
|
||||||
tracker := NewUDPTracker(1 * time.Second)
|
tracker := NewUDPTracker(1*time.Second, logger)
|
||||||
defer tracker.Close()
|
defer tracker.Close()
|
||||||
|
|
||||||
srcIP := net.ParseIP("192.168.1.2")
|
srcIP := net.ParseIP("192.168.1.2")
|
||||||
@@ -162,6 +161,7 @@ func TestUDPTracker_Cleanup(t *testing.T) {
|
|||||||
cleanupTicker: time.NewTicker(cleanupInterval),
|
cleanupTicker: time.NewTicker(cleanupInterval),
|
||||||
done: make(chan struct{}),
|
done: make(chan struct{}),
|
||||||
ipPool: NewPreallocatedIPs(),
|
ipPool: NewPreallocatedIPs(),
|
||||||
|
logger: logger,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Start cleanup routine
|
// Start cleanup routine
|
||||||
@@ -211,7 +211,7 @@ func TestUDPTracker_Cleanup(t *testing.T) {
|
|||||||
|
|
||||||
func BenchmarkUDPTracker(b *testing.B) {
|
func BenchmarkUDPTracker(b *testing.B) {
|
||||||
b.Run("TrackOutbound", func(b *testing.B) {
|
b.Run("TrackOutbound", func(b *testing.B) {
|
||||||
tracker := NewUDPTracker(DefaultUDPTimeout)
|
tracker := NewUDPTracker(DefaultUDPTimeout, logger)
|
||||||
defer tracker.Close()
|
defer tracker.Close()
|
||||||
|
|
||||||
srcIP := net.ParseIP("192.168.1.1")
|
srcIP := net.ParseIP("192.168.1.1")
|
||||||
@@ -224,7 +224,7 @@ func BenchmarkUDPTracker(b *testing.B) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
b.Run("IsValidInbound", func(b *testing.B) {
|
b.Run("IsValidInbound", func(b *testing.B) {
|
||||||
tracker := NewUDPTracker(DefaultUDPTimeout)
|
tracker := NewUDPTracker(DefaultUDPTimeout, logger)
|
||||||
defer tracker.Close()
|
defer tracker.Close()
|
||||||
|
|
||||||
srcIP := net.ParseIP("192.168.1.1")
|
srcIP := net.ParseIP("192.168.1.1")
|
||||||
|
|||||||
81
client/firewall/uspfilter/forwarder/endpoint.go
Normal file
81
client/firewall/uspfilter/forwarder/endpoint.go
Normal file
@@ -0,0 +1,81 @@
|
|||||||
|
package forwarder
|
||||||
|
|
||||||
|
import (
|
||||||
|
wgdevice "golang.zx2c4.com/wireguard/device"
|
||||||
|
"gvisor.dev/gvisor/pkg/tcpip"
|
||||||
|
"gvisor.dev/gvisor/pkg/tcpip/header"
|
||||||
|
"gvisor.dev/gvisor/pkg/tcpip/stack"
|
||||||
|
|
||||||
|
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
||||||
|
)
|
||||||
|
|
||||||
|
// endpoint implements stack.LinkEndpoint and handles integration with the wireguard device
|
||||||
|
type endpoint struct {
|
||||||
|
logger *nblog.Logger
|
||||||
|
dispatcher stack.NetworkDispatcher
|
||||||
|
device *wgdevice.Device
|
||||||
|
mtu uint32
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *endpoint) Attach(dispatcher stack.NetworkDispatcher) {
|
||||||
|
e.dispatcher = dispatcher
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *endpoint) IsAttached() bool {
|
||||||
|
return e.dispatcher != nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *endpoint) MTU() uint32 {
|
||||||
|
return e.mtu
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *endpoint) Capabilities() stack.LinkEndpointCapabilities {
|
||||||
|
return stack.CapabilityNone
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *endpoint) MaxHeaderLength() uint16 {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *endpoint) LinkAddress() tcpip.LinkAddress {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *endpoint) WritePackets(pkts stack.PacketBufferList) (int, tcpip.Error) {
|
||||||
|
var written int
|
||||||
|
for _, pkt := range pkts.AsSlice() {
|
||||||
|
netHeader := header.IPv4(pkt.NetworkHeader().View().AsSlice())
|
||||||
|
|
||||||
|
data := stack.PayloadSince(pkt.NetworkHeader())
|
||||||
|
if data == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send the packet through WireGuard
|
||||||
|
address := netHeader.DestinationAddress()
|
||||||
|
err := e.device.CreateOutboundPacket(data.AsSlice(), address.AsSlice())
|
||||||
|
if err != nil {
|
||||||
|
e.logger.Error("CreateOutboundPacket: %v", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
written++
|
||||||
|
}
|
||||||
|
|
||||||
|
return written, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *endpoint) Wait() {
|
||||||
|
// not required
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *endpoint) ARPHardwareType() header.ARPHardwareType {
|
||||||
|
return header.ARPHardwareNone
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *endpoint) AddHeader(*stack.PacketBuffer) {
|
||||||
|
// not required
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *endpoint) ParseHeader(*stack.PacketBuffer) bool {
|
||||||
|
return true
|
||||||
|
}
|
||||||
166
client/firewall/uspfilter/forwarder/forwarder.go
Normal file
166
client/firewall/uspfilter/forwarder/forwarder.go
Normal file
@@ -0,0 +1,166 @@
|
|||||||
|
package forwarder
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"runtime"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"gvisor.dev/gvisor/pkg/buffer"
|
||||||
|
"gvisor.dev/gvisor/pkg/tcpip"
|
||||||
|
"gvisor.dev/gvisor/pkg/tcpip/header"
|
||||||
|
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
|
||||||
|
"gvisor.dev/gvisor/pkg/tcpip/stack"
|
||||||
|
"gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
|
||||||
|
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
|
||||||
|
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/firewall/uspfilter/common"
|
||||||
|
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
defaultReceiveWindow = 32768
|
||||||
|
defaultMaxInFlight = 1024
|
||||||
|
iosReceiveWindow = 16384
|
||||||
|
iosMaxInFlight = 256
|
||||||
|
)
|
||||||
|
|
||||||
|
type Forwarder struct {
|
||||||
|
logger *nblog.Logger
|
||||||
|
stack *stack.Stack
|
||||||
|
endpoint *endpoint
|
||||||
|
udpForwarder *udpForwarder
|
||||||
|
ctx context.Context
|
||||||
|
cancel context.CancelFunc
|
||||||
|
ip net.IP
|
||||||
|
netstack bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func New(iface common.IFaceMapper, logger *nblog.Logger, netstack bool) (*Forwarder, error) {
|
||||||
|
s := stack.New(stack.Options{
|
||||||
|
NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
|
||||||
|
TransportProtocols: []stack.TransportProtocolFactory{
|
||||||
|
tcp.NewProtocol,
|
||||||
|
udp.NewProtocol,
|
||||||
|
icmp.NewProtocol4,
|
||||||
|
},
|
||||||
|
HandleLocal: false,
|
||||||
|
})
|
||||||
|
|
||||||
|
mtu, err := iface.GetDevice().MTU()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("get MTU: %w", err)
|
||||||
|
}
|
||||||
|
nicID := tcpip.NICID(1)
|
||||||
|
endpoint := &endpoint{
|
||||||
|
logger: logger,
|
||||||
|
device: iface.GetWGDevice(),
|
||||||
|
mtu: uint32(mtu),
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := s.CreateNIC(nicID, endpoint); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to create NIC: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ones, _ := iface.Address().Network.Mask.Size()
|
||||||
|
protoAddr := tcpip.ProtocolAddress{
|
||||||
|
Protocol: ipv4.ProtocolNumber,
|
||||||
|
AddressWithPrefix: tcpip.AddressWithPrefix{
|
||||||
|
Address: tcpip.AddrFromSlice(iface.Address().IP.To4()),
|
||||||
|
PrefixLen: ones,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := s.AddProtocolAddress(nicID, protoAddr, stack.AddressProperties{}); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to add protocol address: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
defaultSubnet, err := tcpip.NewSubnet(
|
||||||
|
tcpip.AddrFrom4([4]byte{0, 0, 0, 0}),
|
||||||
|
tcpip.MaskFromBytes([]byte{0, 0, 0, 0}),
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("creating default subnet: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := s.SetPromiscuousMode(nicID, true); err != nil {
|
||||||
|
return nil, fmt.Errorf("set promiscuous mode: %s", err)
|
||||||
|
}
|
||||||
|
if err := s.SetSpoofing(nicID, true); err != nil {
|
||||||
|
return nil, fmt.Errorf("set spoofing: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
s.SetRouteTable([]tcpip.Route{
|
||||||
|
{
|
||||||
|
Destination: defaultSubnet,
|
||||||
|
NIC: nicID,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
f := &Forwarder{
|
||||||
|
logger: logger,
|
||||||
|
stack: s,
|
||||||
|
endpoint: endpoint,
|
||||||
|
udpForwarder: newUDPForwarder(mtu, logger),
|
||||||
|
ctx: ctx,
|
||||||
|
cancel: cancel,
|
||||||
|
netstack: netstack,
|
||||||
|
ip: iface.Address().IP,
|
||||||
|
}
|
||||||
|
|
||||||
|
receiveWindow := defaultReceiveWindow
|
||||||
|
maxInFlight := defaultMaxInFlight
|
||||||
|
if runtime.GOOS == "ios" {
|
||||||
|
receiveWindow = iosReceiveWindow
|
||||||
|
maxInFlight = iosMaxInFlight
|
||||||
|
}
|
||||||
|
|
||||||
|
tcpForwarder := tcp.NewForwarder(s, receiveWindow, maxInFlight, f.handleTCP)
|
||||||
|
s.SetTransportProtocolHandler(tcp.ProtocolNumber, tcpForwarder.HandlePacket)
|
||||||
|
|
||||||
|
udpForwarder := udp.NewForwarder(s, f.handleUDP)
|
||||||
|
s.SetTransportProtocolHandler(udp.ProtocolNumber, udpForwarder.HandlePacket)
|
||||||
|
|
||||||
|
s.SetTransportProtocolHandler(icmp.ProtocolNumber4, f.handleICMP)
|
||||||
|
|
||||||
|
log.Debugf("forwarder: Initialization complete with NIC %d", nicID)
|
||||||
|
return f, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *Forwarder) InjectIncomingPacket(payload []byte) error {
|
||||||
|
if len(payload) < header.IPv4MinimumSize {
|
||||||
|
return fmt.Errorf("packet too small: %d bytes", len(payload))
|
||||||
|
}
|
||||||
|
|
||||||
|
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
|
||||||
|
Payload: buffer.MakeWithData(payload),
|
||||||
|
})
|
||||||
|
defer pkt.DecRef()
|
||||||
|
|
||||||
|
if f.endpoint.dispatcher != nil {
|
||||||
|
f.endpoint.dispatcher.DeliverNetworkPacket(ipv4.ProtocolNumber, pkt)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stop gracefully shuts down the forwarder
|
||||||
|
func (f *Forwarder) Stop() {
|
||||||
|
f.cancel()
|
||||||
|
|
||||||
|
if f.udpForwarder != nil {
|
||||||
|
f.udpForwarder.Stop()
|
||||||
|
}
|
||||||
|
|
||||||
|
f.stack.Close()
|
||||||
|
f.stack.Wait()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *Forwarder) determineDialAddr(addr tcpip.Address) net.IP {
|
||||||
|
if f.netstack && f.ip.Equal(addr.AsSlice()) {
|
||||||
|
return net.IPv4(127, 0, 0, 1)
|
||||||
|
}
|
||||||
|
return addr.AsSlice()
|
||||||
|
}
|
||||||
109
client/firewall/uspfilter/forwarder/icmp.go
Normal file
109
client/firewall/uspfilter/forwarder/icmp.go
Normal file
@@ -0,0 +1,109 @@
|
|||||||
|
package forwarder
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"gvisor.dev/gvisor/pkg/tcpip/header"
|
||||||
|
"gvisor.dev/gvisor/pkg/tcpip/stack"
|
||||||
|
)
|
||||||
|
|
||||||
|
// handleICMP handles ICMP packets from the network stack
|
||||||
|
func (f *Forwarder) handleICMP(id stack.TransportEndpointID, pkt stack.PacketBufferPtr) bool {
|
||||||
|
ctx, cancel := context.WithTimeout(f.ctx, 5*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
lc := net.ListenConfig{}
|
||||||
|
// TODO: support non-root
|
||||||
|
conn, err := lc.ListenPacket(ctx, "ip4:icmp", "0.0.0.0")
|
||||||
|
if err != nil {
|
||||||
|
f.logger.Error("Failed to create ICMP socket for %v: %v", id, err)
|
||||||
|
|
||||||
|
// This will make netstack reply on behalf of the original destination, that's ok for now
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if err := conn.Close(); err != nil {
|
||||||
|
f.logger.Debug("Failed to close ICMP socket: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
dstIP := f.determineDialAddr(id.LocalAddress)
|
||||||
|
dst := &net.IPAddr{IP: dstIP}
|
||||||
|
|
||||||
|
// Get the complete ICMP message (header + data)
|
||||||
|
fullPacket := stack.PayloadSince(pkt.TransportHeader())
|
||||||
|
payload := fullPacket.AsSlice()
|
||||||
|
|
||||||
|
icmpHdr := header.ICMPv4(pkt.TransportHeader().View().AsSlice())
|
||||||
|
|
||||||
|
// For Echo Requests, send and handle response
|
||||||
|
switch icmpHdr.Type() {
|
||||||
|
case header.ICMPv4Echo:
|
||||||
|
return f.handleEchoResponse(icmpHdr, payload, dst, conn, id)
|
||||||
|
case header.ICMPv4EchoReply:
|
||||||
|
// dont process our own replies
|
||||||
|
return true
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
|
||||||
|
// For other ICMP types (Time Exceeded, Destination Unreachable, etc)
|
||||||
|
_, err = conn.WriteTo(payload, dst)
|
||||||
|
if err != nil {
|
||||||
|
f.logger.Error("Failed to write ICMP packet for %v: %v", id, err)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
f.logger.Trace("Forwarded ICMP packet %v type=%v code=%v",
|
||||||
|
id, icmpHdr.Type(), icmpHdr.Code())
|
||||||
|
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *Forwarder) handleEchoResponse(icmpHdr header.ICMPv4, payload []byte, dst *net.IPAddr, conn net.PacketConn, id stack.TransportEndpointID) bool {
|
||||||
|
if _, err := conn.WriteTo(payload, dst); err != nil {
|
||||||
|
f.logger.Error("Failed to write ICMP packet for %v: %v", id, err)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
f.logger.Trace("Forwarded ICMP packet %v type=%v code=%v",
|
||||||
|
id, icmpHdr.Type(), icmpHdr.Code())
|
||||||
|
|
||||||
|
if err := conn.SetReadDeadline(time.Now().Add(5 * time.Second)); err != nil {
|
||||||
|
f.logger.Error("Failed to set read deadline for ICMP response: %v", err)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
response := make([]byte, f.endpoint.mtu)
|
||||||
|
n, _, err := conn.ReadFrom(response)
|
||||||
|
if err != nil {
|
||||||
|
if !isTimeout(err) {
|
||||||
|
f.logger.Error("Failed to read ICMP response: %v", err)
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
ipHdr := make([]byte, header.IPv4MinimumSize)
|
||||||
|
ip := header.IPv4(ipHdr)
|
||||||
|
ip.Encode(&header.IPv4Fields{
|
||||||
|
TotalLength: uint16(header.IPv4MinimumSize + n),
|
||||||
|
TTL: 64,
|
||||||
|
Protocol: uint8(header.ICMPv4ProtocolNumber),
|
||||||
|
SrcAddr: id.LocalAddress,
|
||||||
|
DstAddr: id.RemoteAddress,
|
||||||
|
})
|
||||||
|
ip.SetChecksum(^ip.CalculateChecksum())
|
||||||
|
|
||||||
|
fullPacket := make([]byte, 0, len(ipHdr)+n)
|
||||||
|
fullPacket = append(fullPacket, ipHdr...)
|
||||||
|
fullPacket = append(fullPacket, response[:n]...)
|
||||||
|
|
||||||
|
if err := f.InjectIncomingPacket(fullPacket); err != nil {
|
||||||
|
f.logger.Error("Failed to inject ICMP response: %v", err)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
f.logger.Trace("Forwarded ICMP echo reply for %v", id)
|
||||||
|
return true
|
||||||
|
}
|
||||||
90
client/firewall/uspfilter/forwarder/tcp.go
Normal file
90
client/firewall/uspfilter/forwarder/tcp.go
Normal file
@@ -0,0 +1,90 @@
|
|||||||
|
package forwarder
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
|
||||||
|
"gvisor.dev/gvisor/pkg/tcpip"
|
||||||
|
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
|
||||||
|
"gvisor.dev/gvisor/pkg/tcpip/stack"
|
||||||
|
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
|
||||||
|
"gvisor.dev/gvisor/pkg/waiter"
|
||||||
|
)
|
||||||
|
|
||||||
|
// handleTCP is called by the TCP forwarder for new connections.
|
||||||
|
func (f *Forwarder) handleTCP(r *tcp.ForwarderRequest) {
|
||||||
|
id := r.ID()
|
||||||
|
|
||||||
|
dialAddr := fmt.Sprintf("%s:%d", f.determineDialAddr(id.LocalAddress), id.LocalPort)
|
||||||
|
|
||||||
|
outConn, err := (&net.Dialer{}).DialContext(f.ctx, "tcp", dialAddr)
|
||||||
|
if err != nil {
|
||||||
|
r.Complete(true)
|
||||||
|
f.logger.Trace("forwarder: dial error for %v: %v", id, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create wait queue for blocking syscalls
|
||||||
|
wq := waiter.Queue{}
|
||||||
|
|
||||||
|
ep, epErr := r.CreateEndpoint(&wq)
|
||||||
|
if epErr != nil {
|
||||||
|
f.logger.Error("forwarder: failed to create TCP endpoint: %v", epErr)
|
||||||
|
if err := outConn.Close(); err != nil {
|
||||||
|
f.logger.Debug("forwarder: outConn close error: %v", err)
|
||||||
|
}
|
||||||
|
r.Complete(true)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Complete the handshake
|
||||||
|
r.Complete(false)
|
||||||
|
|
||||||
|
inConn := gonet.NewTCPConn(&wq, ep)
|
||||||
|
|
||||||
|
f.logger.Trace("forwarder: established TCP connection %v", id)
|
||||||
|
|
||||||
|
go f.proxyTCP(id, inConn, outConn, ep)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *Forwarder) proxyTCP(id stack.TransportEndpointID, inConn *gonet.TCPConn, outConn net.Conn, ep tcpip.Endpoint) {
|
||||||
|
defer func() {
|
||||||
|
if err := inConn.Close(); err != nil {
|
||||||
|
f.logger.Debug("forwarder: inConn close error: %v", err)
|
||||||
|
}
|
||||||
|
if err := outConn.Close(); err != nil {
|
||||||
|
f.logger.Debug("forwarder: outConn close error: %v", err)
|
||||||
|
}
|
||||||
|
ep.Close()
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Create context for managing the proxy goroutines
|
||||||
|
ctx, cancel := context.WithCancel(f.ctx)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
errChan := make(chan error, 2)
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
_, err := io.Copy(outConn, inConn)
|
||||||
|
errChan <- err
|
||||||
|
}()
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
_, err := io.Copy(inConn, outConn)
|
||||||
|
errChan <- err
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
f.logger.Trace("forwarder: tearing down TCP connection %v due to context done", id)
|
||||||
|
return
|
||||||
|
case err := <-errChan:
|
||||||
|
if err != nil && !isClosedError(err) {
|
||||||
|
f.logger.Error("proxyTCP: copy error: %v", err)
|
||||||
|
}
|
||||||
|
f.logger.Trace("forwarder: tearing down TCP connection %v", id)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
288
client/firewall/uspfilter/forwarder/udp.go
Normal file
288
client/firewall/uspfilter/forwarder/udp.go
Normal file
@@ -0,0 +1,288 @@
|
|||||||
|
package forwarder
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"gvisor.dev/gvisor/pkg/tcpip"
|
||||||
|
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
|
||||||
|
"gvisor.dev/gvisor/pkg/tcpip/stack"
|
||||||
|
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
|
||||||
|
"gvisor.dev/gvisor/pkg/waiter"
|
||||||
|
|
||||||
|
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
udpTimeout = 30 * time.Second
|
||||||
|
)
|
||||||
|
|
||||||
|
type udpPacketConn struct {
|
||||||
|
conn *gonet.UDPConn
|
||||||
|
outConn net.Conn
|
||||||
|
lastSeen atomic.Int64
|
||||||
|
cancel context.CancelFunc
|
||||||
|
ep tcpip.Endpoint
|
||||||
|
}
|
||||||
|
|
||||||
|
type udpForwarder struct {
|
||||||
|
sync.RWMutex
|
||||||
|
logger *nblog.Logger
|
||||||
|
conns map[stack.TransportEndpointID]*udpPacketConn
|
||||||
|
bufPool sync.Pool
|
||||||
|
ctx context.Context
|
||||||
|
cancel context.CancelFunc
|
||||||
|
}
|
||||||
|
|
||||||
|
type idleConn struct {
|
||||||
|
id stack.TransportEndpointID
|
||||||
|
conn *udpPacketConn
|
||||||
|
}
|
||||||
|
|
||||||
|
func newUDPForwarder(mtu int, logger *nblog.Logger) *udpForwarder {
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
f := &udpForwarder{
|
||||||
|
logger: logger,
|
||||||
|
conns: make(map[stack.TransportEndpointID]*udpPacketConn),
|
||||||
|
ctx: ctx,
|
||||||
|
cancel: cancel,
|
||||||
|
bufPool: sync.Pool{
|
||||||
|
New: func() any {
|
||||||
|
b := make([]byte, mtu)
|
||||||
|
return &b
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
go f.cleanup()
|
||||||
|
return f
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stop stops the UDP forwarder and all active connections
|
||||||
|
func (f *udpForwarder) Stop() {
|
||||||
|
f.cancel()
|
||||||
|
|
||||||
|
f.Lock()
|
||||||
|
defer f.Unlock()
|
||||||
|
|
||||||
|
for id, conn := range f.conns {
|
||||||
|
conn.cancel()
|
||||||
|
if err := conn.conn.Close(); err != nil {
|
||||||
|
f.logger.Debug("forwarder: UDP conn close error for %v: %v", id, err)
|
||||||
|
}
|
||||||
|
if err := conn.outConn.Close(); err != nil {
|
||||||
|
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", id, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
conn.ep.Close()
|
||||||
|
delete(f.conns, id)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// cleanup periodically removes idle UDP connections
|
||||||
|
func (f *udpForwarder) cleanup() {
|
||||||
|
ticker := time.NewTicker(time.Minute)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-f.ctx.Done():
|
||||||
|
return
|
||||||
|
case <-ticker.C:
|
||||||
|
var idleConns []idleConn
|
||||||
|
|
||||||
|
f.RLock()
|
||||||
|
for id, conn := range f.conns {
|
||||||
|
if conn.getIdleDuration() > udpTimeout {
|
||||||
|
idleConns = append(idleConns, idleConn{id, conn})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
f.RUnlock()
|
||||||
|
|
||||||
|
for _, idle := range idleConns {
|
||||||
|
idle.conn.cancel()
|
||||||
|
if err := idle.conn.conn.Close(); err != nil {
|
||||||
|
f.logger.Debug("forwarder: UDP conn close error for %v: %v", idle.id, err)
|
||||||
|
}
|
||||||
|
if err := idle.conn.outConn.Close(); err != nil {
|
||||||
|
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", idle.id, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
idle.conn.ep.Close()
|
||||||
|
|
||||||
|
f.Lock()
|
||||||
|
delete(f.conns, idle.id)
|
||||||
|
f.Unlock()
|
||||||
|
|
||||||
|
f.logger.Trace("forwarder: cleaned up idle UDP connection %v", idle.id)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleUDP is called by the UDP forwarder for new packets
|
||||||
|
func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) {
|
||||||
|
if f.ctx.Err() != nil {
|
||||||
|
f.logger.Trace("forwarder: context done, dropping UDP packet")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
id := r.ID()
|
||||||
|
|
||||||
|
f.udpForwarder.RLock()
|
||||||
|
_, exists := f.udpForwarder.conns[id]
|
||||||
|
f.udpForwarder.RUnlock()
|
||||||
|
if exists {
|
||||||
|
f.logger.Trace("forwarder: existing UDP connection for %v", id)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
dstAddr := fmt.Sprintf("%s:%d", f.determineDialAddr(id.LocalAddress), id.LocalPort)
|
||||||
|
outConn, err := (&net.Dialer{}).DialContext(f.ctx, "udp", dstAddr)
|
||||||
|
if err != nil {
|
||||||
|
f.logger.Debug("forwarder: UDP dial error for %v: %v", id, err)
|
||||||
|
// TODO: Send ICMP error message
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create wait queue for blocking syscalls
|
||||||
|
wq := waiter.Queue{}
|
||||||
|
ep, epErr := r.CreateEndpoint(&wq)
|
||||||
|
if epErr != nil {
|
||||||
|
f.logger.Debug("forwarder: failed to create UDP endpoint: %v", epErr)
|
||||||
|
if err := outConn.Close(); err != nil {
|
||||||
|
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", id, err)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
inConn := gonet.NewUDPConn(f.stack, &wq, ep)
|
||||||
|
connCtx, connCancel := context.WithCancel(f.ctx)
|
||||||
|
|
||||||
|
pConn := &udpPacketConn{
|
||||||
|
conn: inConn,
|
||||||
|
outConn: outConn,
|
||||||
|
cancel: connCancel,
|
||||||
|
ep: ep,
|
||||||
|
}
|
||||||
|
pConn.updateLastSeen()
|
||||||
|
|
||||||
|
f.udpForwarder.Lock()
|
||||||
|
// Double-check no connection was created while we were setting up
|
||||||
|
if _, exists := f.udpForwarder.conns[id]; exists {
|
||||||
|
f.udpForwarder.Unlock()
|
||||||
|
pConn.cancel()
|
||||||
|
if err := inConn.Close(); err != nil {
|
||||||
|
f.logger.Debug("forwarder: UDP inConn close error for %v: %v", id, err)
|
||||||
|
}
|
||||||
|
if err := outConn.Close(); err != nil {
|
||||||
|
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", id, err)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
f.udpForwarder.conns[id] = pConn
|
||||||
|
f.udpForwarder.Unlock()
|
||||||
|
|
||||||
|
f.logger.Trace("forwarder: established UDP connection to %v", id)
|
||||||
|
go f.proxyUDP(connCtx, pConn, id, ep)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *Forwarder) proxyUDP(ctx context.Context, pConn *udpPacketConn, id stack.TransportEndpointID, ep tcpip.Endpoint) {
|
||||||
|
defer func() {
|
||||||
|
pConn.cancel()
|
||||||
|
if err := pConn.conn.Close(); err != nil {
|
||||||
|
f.logger.Debug("forwarder: UDP inConn close error for %v: %v", id, err)
|
||||||
|
}
|
||||||
|
if err := pConn.outConn.Close(); err != nil {
|
||||||
|
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", id, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ep.Close()
|
||||||
|
|
||||||
|
f.udpForwarder.Lock()
|
||||||
|
delete(f.udpForwarder.conns, id)
|
||||||
|
f.udpForwarder.Unlock()
|
||||||
|
}()
|
||||||
|
|
||||||
|
errChan := make(chan error, 2)
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
errChan <- pConn.copy(ctx, pConn.conn, pConn.outConn, &f.udpForwarder.bufPool, "outbound->inbound")
|
||||||
|
}()
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
errChan <- pConn.copy(ctx, pConn.outConn, pConn.conn, &f.udpForwarder.bufPool, "inbound->outbound")
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
f.logger.Trace("forwarder: tearing down UDP connection %v due to context done", id)
|
||||||
|
return
|
||||||
|
case err := <-errChan:
|
||||||
|
if err != nil && !isClosedError(err) {
|
||||||
|
f.logger.Error("proxyUDP: copy error: %v", err)
|
||||||
|
}
|
||||||
|
f.logger.Trace("forwarder: tearing down UDP connection %v", id)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *udpPacketConn) updateLastSeen() {
|
||||||
|
c.lastSeen.Store(time.Now().UnixNano())
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *udpPacketConn) getIdleDuration() time.Duration {
|
||||||
|
lastSeen := time.Unix(0, c.lastSeen.Load())
|
||||||
|
return time.Since(lastSeen)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *udpPacketConn) copy(ctx context.Context, dst net.Conn, src net.Conn, bufPool *sync.Pool, direction string) error {
|
||||||
|
bufp := bufPool.Get().(*[]byte)
|
||||||
|
defer bufPool.Put(bufp)
|
||||||
|
buffer := *bufp
|
||||||
|
|
||||||
|
if err := src.SetReadDeadline(time.Now().Add(udpTimeout)); err != nil {
|
||||||
|
return fmt.Errorf("set read deadline: %w", err)
|
||||||
|
}
|
||||||
|
if err := src.SetWriteDeadline(time.Now().Add(udpTimeout)); err != nil {
|
||||||
|
return fmt.Errorf("set write deadline: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return ctx.Err()
|
||||||
|
default:
|
||||||
|
n, err := src.Read(buffer)
|
||||||
|
if err != nil {
|
||||||
|
if isTimeout(err) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
return fmt.Errorf("read from %s: %w", direction, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = dst.Write(buffer[:n])
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("write to %s: %w", direction, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
c.updateLastSeen()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func isClosedError(err error) bool {
|
||||||
|
return errors.Is(err, net.ErrClosed) || errors.Is(err, context.Canceled)
|
||||||
|
}
|
||||||
|
|
||||||
|
func isTimeout(err error) bool {
|
||||||
|
var netErr net.Error
|
||||||
|
if errors.As(err, &netErr) {
|
||||||
|
return netErr.Timeout()
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
134
client/firewall/uspfilter/localip.go
Normal file
134
client/firewall/uspfilter/localip.go
Normal file
@@ -0,0 +1,134 @@
|
|||||||
|
package uspfilter
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/firewall/uspfilter/common"
|
||||||
|
)
|
||||||
|
|
||||||
|
type localIPManager struct {
|
||||||
|
mu sync.RWMutex
|
||||||
|
|
||||||
|
// Use bitmap for IPv4 (32 bits * 2^16 = 256KB memory)
|
||||||
|
ipv4Bitmap [1 << 16]uint32
|
||||||
|
}
|
||||||
|
|
||||||
|
func newLocalIPManager() *localIPManager {
|
||||||
|
return &localIPManager{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *localIPManager) setBitmapBit(ip net.IP) {
|
||||||
|
ipv4 := ip.To4()
|
||||||
|
if ipv4 == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
high := (uint16(ipv4[0]) << 8) | uint16(ipv4[1])
|
||||||
|
low := (uint16(ipv4[2]) << 8) | uint16(ipv4[3])
|
||||||
|
m.ipv4Bitmap[high] |= 1 << (low % 32)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *localIPManager) checkBitmapBit(ip net.IP) bool {
|
||||||
|
ipv4 := ip.To4()
|
||||||
|
if ipv4 == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
high := (uint16(ipv4[0]) << 8) | uint16(ipv4[1])
|
||||||
|
low := (uint16(ipv4[2]) << 8) | uint16(ipv4[3])
|
||||||
|
return (m.ipv4Bitmap[high] & (1 << (low % 32))) != 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *localIPManager) processIP(ip net.IP, newIPv4Bitmap *[1 << 16]uint32, ipv4Set map[string]struct{}, ipv4Addresses *[]string) error {
|
||||||
|
if ipv4 := ip.To4(); ipv4 != nil {
|
||||||
|
high := (uint16(ipv4[0]) << 8) | uint16(ipv4[1])
|
||||||
|
low := (uint16(ipv4[2]) << 8) | uint16(ipv4[3])
|
||||||
|
if int(high) >= len(*newIPv4Bitmap) {
|
||||||
|
return fmt.Errorf("invalid IPv4 address: %s", ip)
|
||||||
|
}
|
||||||
|
ipStr := ip.String()
|
||||||
|
if _, exists := ipv4Set[ipStr]; !exists {
|
||||||
|
ipv4Set[ipStr] = struct{}{}
|
||||||
|
*ipv4Addresses = append(*ipv4Addresses, ipStr)
|
||||||
|
newIPv4Bitmap[high] |= 1 << (low % 32)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *localIPManager) processInterface(iface net.Interface, newIPv4Bitmap *[1 << 16]uint32, ipv4Set map[string]struct{}, ipv4Addresses *[]string) {
|
||||||
|
addrs, err := iface.Addrs()
|
||||||
|
if err != nil {
|
||||||
|
log.Debugf("get addresses for interface %s failed: %v", iface.Name, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, addr := range addrs {
|
||||||
|
var ip net.IP
|
||||||
|
switch v := addr.(type) {
|
||||||
|
case *net.IPNet:
|
||||||
|
ip = v.IP
|
||||||
|
case *net.IPAddr:
|
||||||
|
ip = v.IP
|
||||||
|
default:
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := m.processIP(ip, newIPv4Bitmap, ipv4Set, ipv4Addresses); err != nil {
|
||||||
|
log.Debugf("process IP failed: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *localIPManager) UpdateLocalIPs(iface common.IFaceMapper) (err error) {
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
err = fmt.Errorf("panic: %v", r)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
var newIPv4Bitmap [1 << 16]uint32
|
||||||
|
ipv4Set := make(map[string]struct{})
|
||||||
|
var ipv4Addresses []string
|
||||||
|
|
||||||
|
// 127.0.0.0/8
|
||||||
|
high := uint16(127) << 8
|
||||||
|
for i := uint16(0); i < 256; i++ {
|
||||||
|
newIPv4Bitmap[high|i] = 0xffffffff
|
||||||
|
}
|
||||||
|
|
||||||
|
if iface != nil {
|
||||||
|
if err := m.processIP(iface.Address().IP, &newIPv4Bitmap, ipv4Set, &ipv4Addresses); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
interfaces, err := net.Interfaces()
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("failed to get interfaces: %v", err)
|
||||||
|
} else {
|
||||||
|
for _, intf := range interfaces {
|
||||||
|
m.processInterface(intf, &newIPv4Bitmap, ipv4Set, &ipv4Addresses)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
m.mu.Lock()
|
||||||
|
m.ipv4Bitmap = newIPv4Bitmap
|
||||||
|
m.mu.Unlock()
|
||||||
|
|
||||||
|
log.Debugf("Local IPv4 addresses: %v", ipv4Addresses)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *localIPManager) IsLocalIP(ip net.IP) bool {
|
||||||
|
m.mu.RLock()
|
||||||
|
defer m.mu.RUnlock()
|
||||||
|
|
||||||
|
if ipv4 := ip.To4(); ipv4 != nil {
|
||||||
|
return m.checkBitmapBit(ipv4)
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
270
client/firewall/uspfilter/localip_test.go
Normal file
270
client/firewall/uspfilter/localip_test.go
Normal file
@@ -0,0 +1,270 @@
|
|||||||
|
package uspfilter
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/iface"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestLocalIPManager(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
setupAddr iface.WGAddress
|
||||||
|
testIP net.IP
|
||||||
|
expected bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Localhost range",
|
||||||
|
setupAddr: iface.WGAddress{
|
||||||
|
IP: net.ParseIP("192.168.1.1"),
|
||||||
|
Network: &net.IPNet{
|
||||||
|
IP: net.ParseIP("192.168.1.0"),
|
||||||
|
Mask: net.CIDRMask(24, 32),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
testIP: net.ParseIP("127.0.0.2"),
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Localhost standard address",
|
||||||
|
setupAddr: iface.WGAddress{
|
||||||
|
IP: net.ParseIP("192.168.1.1"),
|
||||||
|
Network: &net.IPNet{
|
||||||
|
IP: net.ParseIP("192.168.1.0"),
|
||||||
|
Mask: net.CIDRMask(24, 32),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
testIP: net.ParseIP("127.0.0.1"),
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Localhost range edge",
|
||||||
|
setupAddr: iface.WGAddress{
|
||||||
|
IP: net.ParseIP("192.168.1.1"),
|
||||||
|
Network: &net.IPNet{
|
||||||
|
IP: net.ParseIP("192.168.1.0"),
|
||||||
|
Mask: net.CIDRMask(24, 32),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
testIP: net.ParseIP("127.255.255.255"),
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Local IP matches",
|
||||||
|
setupAddr: iface.WGAddress{
|
||||||
|
IP: net.ParseIP("192.168.1.1"),
|
||||||
|
Network: &net.IPNet{
|
||||||
|
IP: net.ParseIP("192.168.1.0"),
|
||||||
|
Mask: net.CIDRMask(24, 32),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
testIP: net.ParseIP("192.168.1.1"),
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Local IP doesn't match",
|
||||||
|
setupAddr: iface.WGAddress{
|
||||||
|
IP: net.ParseIP("192.168.1.1"),
|
||||||
|
Network: &net.IPNet{
|
||||||
|
IP: net.ParseIP("192.168.1.0"),
|
||||||
|
Mask: net.CIDRMask(24, 32),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
testIP: net.ParseIP("192.168.1.2"),
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "IPv6 address",
|
||||||
|
setupAddr: iface.WGAddress{
|
||||||
|
IP: net.ParseIP("fe80::1"),
|
||||||
|
Network: &net.IPNet{
|
||||||
|
IP: net.ParseIP("fe80::"),
|
||||||
|
Mask: net.CIDRMask(64, 128),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
testIP: net.ParseIP("fe80::1"),
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
manager := newLocalIPManager()
|
||||||
|
|
||||||
|
mock := &IFaceMock{
|
||||||
|
AddressFunc: func() iface.WGAddress {
|
||||||
|
return tt.setupAddr
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
err := manager.UpdateLocalIPs(mock)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
result := manager.IsLocalIP(tt.testIP)
|
||||||
|
require.Equal(t, tt.expected, result)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLocalIPManager_AllInterfaces(t *testing.T) {
|
||||||
|
manager := newLocalIPManager()
|
||||||
|
mock := &IFaceMock{}
|
||||||
|
|
||||||
|
// Get actual local interfaces
|
||||||
|
interfaces, err := net.Interfaces()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
var tests []struct {
|
||||||
|
ip string
|
||||||
|
expected bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add all local interface IPs to test cases
|
||||||
|
for _, iface := range interfaces {
|
||||||
|
addrs, err := iface.Addrs()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
for _, addr := range addrs {
|
||||||
|
var ip net.IP
|
||||||
|
switch v := addr.(type) {
|
||||||
|
case *net.IPNet:
|
||||||
|
ip = v.IP
|
||||||
|
case *net.IPAddr:
|
||||||
|
ip = v.IP
|
||||||
|
default:
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if ip4 := ip.To4(); ip4 != nil {
|
||||||
|
tests = append(tests, struct {
|
||||||
|
ip string
|
||||||
|
expected bool
|
||||||
|
}{
|
||||||
|
ip: ip4.String(),
|
||||||
|
expected: true,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add some external IPs as negative test cases
|
||||||
|
externalIPs := []string{
|
||||||
|
"8.8.8.8",
|
||||||
|
"1.1.1.1",
|
||||||
|
"208.67.222.222",
|
||||||
|
}
|
||||||
|
for _, ip := range externalIPs {
|
||||||
|
tests = append(tests, struct {
|
||||||
|
ip string
|
||||||
|
expected bool
|
||||||
|
}{
|
||||||
|
ip: ip,
|
||||||
|
expected: false,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
require.NotEmpty(t, tests, "No test cases generated")
|
||||||
|
|
||||||
|
err = manager.UpdateLocalIPs(mock)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
t.Logf("Testing %d IPs", len(tests))
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.ip, func(t *testing.T) {
|
||||||
|
result := manager.IsLocalIP(net.ParseIP(tt.ip))
|
||||||
|
require.Equal(t, tt.expected, result, "IP: %s", tt.ip)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// MapImplementation is a version using map[string]struct{}
|
||||||
|
type MapImplementation struct {
|
||||||
|
localIPs map[string]struct{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkIPChecks(b *testing.B) {
|
||||||
|
interfaces := make([]net.IP, 16)
|
||||||
|
for i := range interfaces {
|
||||||
|
interfaces[i] = net.IPv4(10, 0, byte(i>>8), byte(i))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Setup bitmap version
|
||||||
|
bitmapManager := &localIPManager{
|
||||||
|
ipv4Bitmap: [1 << 16]uint32{},
|
||||||
|
}
|
||||||
|
for _, ip := range interfaces[:8] { // Add half of IPs
|
||||||
|
bitmapManager.setBitmapBit(ip)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Setup map version
|
||||||
|
mapManager := &MapImplementation{
|
||||||
|
localIPs: make(map[string]struct{}),
|
||||||
|
}
|
||||||
|
for _, ip := range interfaces[:8] {
|
||||||
|
mapManager.localIPs[ip.String()] = struct{}{}
|
||||||
|
}
|
||||||
|
|
||||||
|
b.Run("Bitmap_Hit", func(b *testing.B) {
|
||||||
|
ip := interfaces[4]
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
bitmapManager.checkBitmapBit(ip)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
b.Run("Bitmap_Miss", func(b *testing.B) {
|
||||||
|
ip := interfaces[12]
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
bitmapManager.checkBitmapBit(ip)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
b.Run("Map_Hit", func(b *testing.B) {
|
||||||
|
ip := interfaces[4]
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
// nolint:gosimple
|
||||||
|
_, _ = mapManager.localIPs[ip.String()]
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
b.Run("Map_Miss", func(b *testing.B) {
|
||||||
|
ip := interfaces[12]
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
// nolint:gosimple
|
||||||
|
_, _ = mapManager.localIPs[ip.String()]
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkWGPosition(b *testing.B) {
|
||||||
|
wgIP := net.ParseIP("10.10.0.1")
|
||||||
|
|
||||||
|
// Create two managers - one checks WG IP first, other checks it last
|
||||||
|
b.Run("WG_First", func(b *testing.B) {
|
||||||
|
bm := &localIPManager{ipv4Bitmap: [1 << 16]uint32{}}
|
||||||
|
bm.setBitmapBit(wgIP)
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
bm.checkBitmapBit(wgIP)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
b.Run("WG_Last", func(b *testing.B) {
|
||||||
|
bm := &localIPManager{ipv4Bitmap: [1 << 16]uint32{}}
|
||||||
|
// Fill with other IPs first
|
||||||
|
for i := 0; i < 15; i++ {
|
||||||
|
bm.setBitmapBit(net.IPv4(10, 0, byte(i>>8), byte(i)))
|
||||||
|
}
|
||||||
|
bm.setBitmapBit(wgIP) // Add WG IP last
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
bm.checkBitmapBit(wgIP)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
196
client/firewall/uspfilter/log/log.go
Normal file
196
client/firewall/uspfilter/log/log.go
Normal file
@@ -0,0 +1,196 @@
|
|||||||
|
// Package logger provides a high-performance, non-blocking logger for userspace networking
|
||||||
|
package log
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
maxBatchSize = 1024 * 16 // 16KB max batch size
|
||||||
|
maxMessageSize = 1024 * 2 // 2KB per message
|
||||||
|
bufferSize = 1024 * 256 // 256KB ring buffer
|
||||||
|
defaultFlushInterval = 2 * time.Second
|
||||||
|
)
|
||||||
|
|
||||||
|
// Level represents log severity
|
||||||
|
type Level uint32
|
||||||
|
|
||||||
|
const (
|
||||||
|
LevelPanic Level = iota
|
||||||
|
LevelFatal
|
||||||
|
LevelError
|
||||||
|
LevelWarn
|
||||||
|
LevelInfo
|
||||||
|
LevelDebug
|
||||||
|
LevelTrace
|
||||||
|
)
|
||||||
|
|
||||||
|
var levelStrings = map[Level]string{
|
||||||
|
LevelPanic: "PANC",
|
||||||
|
LevelFatal: "FATL",
|
||||||
|
LevelError: "ERRO",
|
||||||
|
LevelWarn: "WARN",
|
||||||
|
LevelInfo: "INFO",
|
||||||
|
LevelDebug: "DEBG",
|
||||||
|
LevelTrace: "TRAC",
|
||||||
|
}
|
||||||
|
|
||||||
|
// Logger is a high-performance, non-blocking logger
|
||||||
|
type Logger struct {
|
||||||
|
output io.Writer
|
||||||
|
level atomic.Uint32
|
||||||
|
buffer *ringBuffer
|
||||||
|
shutdown chan struct{}
|
||||||
|
closeOnce sync.Once
|
||||||
|
wg sync.WaitGroup
|
||||||
|
|
||||||
|
// Reusable buffer pool for formatting messages
|
||||||
|
bufPool sync.Pool
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewFromLogrus(logrusLogger *log.Logger) *Logger {
|
||||||
|
l := &Logger{
|
||||||
|
output: logrusLogger.Out,
|
||||||
|
buffer: newRingBuffer(bufferSize),
|
||||||
|
shutdown: make(chan struct{}),
|
||||||
|
bufPool: sync.Pool{
|
||||||
|
New: func() interface{} {
|
||||||
|
// Pre-allocate buffer for message formatting
|
||||||
|
b := make([]byte, 0, maxMessageSize)
|
||||||
|
return &b
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
logrusLevel := logrusLogger.GetLevel()
|
||||||
|
l.level.Store(uint32(logrusLevel))
|
||||||
|
level := levelStrings[Level(logrusLevel)]
|
||||||
|
log.Debugf("New uspfilter logger created with loglevel %v", level)
|
||||||
|
|
||||||
|
l.wg.Add(1)
|
||||||
|
go l.worker()
|
||||||
|
|
||||||
|
return l
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *Logger) SetLevel(level Level) {
|
||||||
|
l.level.Store(uint32(level))
|
||||||
|
|
||||||
|
log.Debugf("Set uspfilter logger loglevel to %v", levelStrings[level])
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *Logger) formatMessage(buf *[]byte, level Level, format string, args ...interface{}) {
|
||||||
|
*buf = (*buf)[:0]
|
||||||
|
|
||||||
|
// Timestamp
|
||||||
|
*buf = time.Now().AppendFormat(*buf, "2006-01-02T15:04:05-07:00")
|
||||||
|
*buf = append(*buf, ' ')
|
||||||
|
|
||||||
|
// Level
|
||||||
|
*buf = append(*buf, levelStrings[level]...)
|
||||||
|
*buf = append(*buf, ' ')
|
||||||
|
|
||||||
|
// Message
|
||||||
|
if len(args) > 0 {
|
||||||
|
*buf = append(*buf, fmt.Sprintf(format, args...)...)
|
||||||
|
} else {
|
||||||
|
*buf = append(*buf, format...)
|
||||||
|
}
|
||||||
|
|
||||||
|
*buf = append(*buf, '\n')
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *Logger) log(level Level, format string, args ...interface{}) {
|
||||||
|
bufp := l.bufPool.Get().(*[]byte)
|
||||||
|
l.formatMessage(bufp, level, format, args...)
|
||||||
|
|
||||||
|
if len(*bufp) > maxMessageSize {
|
||||||
|
*bufp = (*bufp)[:maxMessageSize]
|
||||||
|
}
|
||||||
|
_, _ = l.buffer.Write(*bufp)
|
||||||
|
|
||||||
|
l.bufPool.Put(bufp)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *Logger) Error(format string, args ...interface{}) {
|
||||||
|
if l.level.Load() >= uint32(LevelError) {
|
||||||
|
l.log(LevelError, format, args...)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *Logger) Warn(format string, args ...interface{}) {
|
||||||
|
if l.level.Load() >= uint32(LevelWarn) {
|
||||||
|
l.log(LevelWarn, format, args...)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *Logger) Info(format string, args ...interface{}) {
|
||||||
|
if l.level.Load() >= uint32(LevelInfo) {
|
||||||
|
l.log(LevelInfo, format, args...)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *Logger) Debug(format string, args ...interface{}) {
|
||||||
|
if l.level.Load() >= uint32(LevelDebug) {
|
||||||
|
l.log(LevelDebug, format, args...)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *Logger) Trace(format string, args ...interface{}) {
|
||||||
|
if l.level.Load() >= uint32(LevelTrace) {
|
||||||
|
l.log(LevelTrace, format, args...)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// worker periodically flushes the buffer
|
||||||
|
func (l *Logger) worker() {
|
||||||
|
defer l.wg.Done()
|
||||||
|
|
||||||
|
ticker := time.NewTicker(defaultFlushInterval)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
buf := make([]byte, 0, maxBatchSize)
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-l.shutdown:
|
||||||
|
return
|
||||||
|
case <-ticker.C:
|
||||||
|
// Read accumulated messages
|
||||||
|
n, _ := l.buffer.Read(buf[:cap(buf)])
|
||||||
|
if n == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write batch
|
||||||
|
_, _ = l.output.Write(buf[:n])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stop gracefully shuts down the logger
|
||||||
|
func (l *Logger) Stop(ctx context.Context) error {
|
||||||
|
done := make(chan struct{})
|
||||||
|
|
||||||
|
l.closeOnce.Do(func() {
|
||||||
|
close(l.shutdown)
|
||||||
|
})
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
l.wg.Wait()
|
||||||
|
close(done)
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return ctx.Err()
|
||||||
|
case <-done:
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
85
client/firewall/uspfilter/log/ringbuffer.go
Normal file
85
client/firewall/uspfilter/log/ringbuffer.go
Normal file
@@ -0,0 +1,85 @@
|
|||||||
|
package log
|
||||||
|
|
||||||
|
import "sync"
|
||||||
|
|
||||||
|
// ringBuffer is a simple ring buffer implementation
|
||||||
|
type ringBuffer struct {
|
||||||
|
buf []byte
|
||||||
|
size int
|
||||||
|
r, w int64 // Read and write positions
|
||||||
|
mu sync.Mutex
|
||||||
|
}
|
||||||
|
|
||||||
|
func newRingBuffer(size int) *ringBuffer {
|
||||||
|
return &ringBuffer{
|
||||||
|
buf: make([]byte, size),
|
||||||
|
size: size,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *ringBuffer) Write(p []byte) (n int, err error) {
|
||||||
|
if len(p) == 0 {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
r.mu.Lock()
|
||||||
|
defer r.mu.Unlock()
|
||||||
|
|
||||||
|
if len(p) > r.size {
|
||||||
|
p = p[:r.size]
|
||||||
|
}
|
||||||
|
|
||||||
|
n = len(p)
|
||||||
|
|
||||||
|
// Write data, handling wrap-around
|
||||||
|
pos := int(r.w % int64(r.size))
|
||||||
|
writeLen := min(len(p), r.size-pos)
|
||||||
|
copy(r.buf[pos:], p[:writeLen])
|
||||||
|
|
||||||
|
// If we have more data and need to wrap around
|
||||||
|
if writeLen < len(p) {
|
||||||
|
copy(r.buf, p[writeLen:])
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update write position
|
||||||
|
r.w += int64(n)
|
||||||
|
|
||||||
|
return n, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *ringBuffer) Read(p []byte) (n int, err error) {
|
||||||
|
r.mu.Lock()
|
||||||
|
defer r.mu.Unlock()
|
||||||
|
|
||||||
|
if r.w == r.r {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Calculate available data accounting for wraparound
|
||||||
|
available := int(r.w - r.r)
|
||||||
|
if available < 0 {
|
||||||
|
available += r.size
|
||||||
|
}
|
||||||
|
available = min(available, r.size)
|
||||||
|
|
||||||
|
// Limit read to buffer size
|
||||||
|
toRead := min(available, len(p))
|
||||||
|
if toRead == 0 {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read data, handling wrap-around
|
||||||
|
pos := int(r.r % int64(r.size))
|
||||||
|
readLen := min(toRead, r.size-pos)
|
||||||
|
n = copy(p, r.buf[pos:pos+readLen])
|
||||||
|
|
||||||
|
// If we need more data and need to wrap around
|
||||||
|
if readLen < toRead {
|
||||||
|
n += copy(p[readLen:toRead], r.buf[:toRead-readLen])
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update read position
|
||||||
|
r.r += int64(n)
|
||||||
|
|
||||||
|
return n, nil
|
||||||
|
}
|
||||||
@@ -2,14 +2,15 @@ package uspfilter
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"net"
|
"net"
|
||||||
|
"net/netip"
|
||||||
|
|
||||||
"github.com/google/gopacket"
|
"github.com/google/gopacket"
|
||||||
|
|
||||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Rule to handle management of rules
|
// PeerRule to handle management of rules
|
||||||
type Rule struct {
|
type PeerRule struct {
|
||||||
id string
|
id string
|
||||||
ip net.IP
|
ip net.IP
|
||||||
ipLayer gopacket.LayerType
|
ipLayer gopacket.LayerType
|
||||||
@@ -24,6 +25,21 @@ type Rule struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// GetRuleID returns the rule id
|
// GetRuleID returns the rule id
|
||||||
func (r *Rule) GetRuleID() string {
|
func (r *PeerRule) GetRuleID() string {
|
||||||
|
return r.id
|
||||||
|
}
|
||||||
|
|
||||||
|
type RouteRule struct {
|
||||||
|
id string
|
||||||
|
sources []netip.Prefix
|
||||||
|
destination netip.Prefix
|
||||||
|
proto firewall.Protocol
|
||||||
|
srcPort *firewall.Port
|
||||||
|
dstPort *firewall.Port
|
||||||
|
action firewall.Action
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetRuleID returns the rule id
|
||||||
|
func (r *RouteRule) GetRuleID() string {
|
||||||
return r.id
|
return r.id
|
||||||
}
|
}
|
||||||
|
|||||||
390
client/firewall/uspfilter/tracer.go
Normal file
390
client/firewall/uspfilter/tracer.go
Normal file
@@ -0,0 +1,390 @@
|
|||||||
|
package uspfilter
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/google/gopacket"
|
||||||
|
"github.com/google/gopacket/layers"
|
||||||
|
|
||||||
|
fw "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
|
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
|
||||||
|
)
|
||||||
|
|
||||||
|
type PacketStage int
|
||||||
|
|
||||||
|
const (
|
||||||
|
StageReceived PacketStage = iota
|
||||||
|
StageConntrack
|
||||||
|
StagePeerACL
|
||||||
|
StageRouting
|
||||||
|
StageRouteACL
|
||||||
|
StageForwarding
|
||||||
|
StageCompleted
|
||||||
|
)
|
||||||
|
|
||||||
|
const msgProcessingCompleted = "Processing completed"
|
||||||
|
|
||||||
|
func (s PacketStage) String() string {
|
||||||
|
return map[PacketStage]string{
|
||||||
|
StageReceived: "Received",
|
||||||
|
StageConntrack: "Connection Tracking",
|
||||||
|
StagePeerACL: "Peer ACL",
|
||||||
|
StageRouting: "Routing",
|
||||||
|
StageRouteACL: "Route ACL",
|
||||||
|
StageForwarding: "Forwarding",
|
||||||
|
StageCompleted: "Completed",
|
||||||
|
}[s]
|
||||||
|
}
|
||||||
|
|
||||||
|
type ForwarderAction struct {
|
||||||
|
Action string
|
||||||
|
RemoteAddr string
|
||||||
|
Error error
|
||||||
|
}
|
||||||
|
|
||||||
|
type TraceResult struct {
|
||||||
|
Timestamp time.Time
|
||||||
|
Stage PacketStage
|
||||||
|
Message string
|
||||||
|
Allowed bool
|
||||||
|
ForwarderAction *ForwarderAction
|
||||||
|
}
|
||||||
|
|
||||||
|
type PacketTrace struct {
|
||||||
|
SourceIP net.IP
|
||||||
|
DestinationIP net.IP
|
||||||
|
Protocol string
|
||||||
|
SourcePort uint16
|
||||||
|
DestinationPort uint16
|
||||||
|
Direction fw.RuleDirection
|
||||||
|
Results []TraceResult
|
||||||
|
}
|
||||||
|
|
||||||
|
type TCPState struct {
|
||||||
|
SYN bool
|
||||||
|
ACK bool
|
||||||
|
FIN bool
|
||||||
|
RST bool
|
||||||
|
PSH bool
|
||||||
|
URG bool
|
||||||
|
}
|
||||||
|
|
||||||
|
type PacketBuilder struct {
|
||||||
|
SrcIP net.IP
|
||||||
|
DstIP net.IP
|
||||||
|
Protocol fw.Protocol
|
||||||
|
SrcPort uint16
|
||||||
|
DstPort uint16
|
||||||
|
ICMPType uint8
|
||||||
|
ICMPCode uint8
|
||||||
|
Direction fw.RuleDirection
|
||||||
|
PayloadSize int
|
||||||
|
TCPState *TCPState
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *PacketTrace) AddResult(stage PacketStage, message string, allowed bool) {
|
||||||
|
t.Results = append(t.Results, TraceResult{
|
||||||
|
Timestamp: time.Now(),
|
||||||
|
Stage: stage,
|
||||||
|
Message: message,
|
||||||
|
Allowed: allowed,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *PacketTrace) AddResultWithForwarder(stage PacketStage, message string, allowed bool, action *ForwarderAction) {
|
||||||
|
t.Results = append(t.Results, TraceResult{
|
||||||
|
Timestamp: time.Now(),
|
||||||
|
Stage: stage,
|
||||||
|
Message: message,
|
||||||
|
Allowed: allowed,
|
||||||
|
ForwarderAction: action,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *PacketBuilder) Build() ([]byte, error) {
|
||||||
|
ip := p.buildIPLayer()
|
||||||
|
pktLayers := []gopacket.SerializableLayer{ip}
|
||||||
|
|
||||||
|
transportLayer, err := p.buildTransportLayer(ip)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
pktLayers = append(pktLayers, transportLayer...)
|
||||||
|
|
||||||
|
if p.PayloadSize > 0 {
|
||||||
|
payload := make([]byte, p.PayloadSize)
|
||||||
|
pktLayers = append(pktLayers, gopacket.Payload(payload))
|
||||||
|
}
|
||||||
|
|
||||||
|
return serializePacket(pktLayers)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *PacketBuilder) buildIPLayer() *layers.IPv4 {
|
||||||
|
return &layers.IPv4{
|
||||||
|
Version: 4,
|
||||||
|
TTL: 64,
|
||||||
|
Protocol: layers.IPProtocol(getIPProtocolNumber(p.Protocol)),
|
||||||
|
SrcIP: p.SrcIP,
|
||||||
|
DstIP: p.DstIP,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *PacketBuilder) buildTransportLayer(ip *layers.IPv4) ([]gopacket.SerializableLayer, error) {
|
||||||
|
switch p.Protocol {
|
||||||
|
case "tcp":
|
||||||
|
return p.buildTCPLayer(ip)
|
||||||
|
case "udp":
|
||||||
|
return p.buildUDPLayer(ip)
|
||||||
|
case "icmp":
|
||||||
|
return p.buildICMPLayer()
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("unsupported protocol: %s", p.Protocol)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *PacketBuilder) buildTCPLayer(ip *layers.IPv4) ([]gopacket.SerializableLayer, error) {
|
||||||
|
tcp := &layers.TCP{
|
||||||
|
SrcPort: layers.TCPPort(p.SrcPort),
|
||||||
|
DstPort: layers.TCPPort(p.DstPort),
|
||||||
|
Window: 65535,
|
||||||
|
SYN: p.TCPState != nil && p.TCPState.SYN,
|
||||||
|
ACK: p.TCPState != nil && p.TCPState.ACK,
|
||||||
|
FIN: p.TCPState != nil && p.TCPState.FIN,
|
||||||
|
RST: p.TCPState != nil && p.TCPState.RST,
|
||||||
|
PSH: p.TCPState != nil && p.TCPState.PSH,
|
||||||
|
URG: p.TCPState != nil && p.TCPState.URG,
|
||||||
|
}
|
||||||
|
if err := tcp.SetNetworkLayerForChecksum(ip); err != nil {
|
||||||
|
return nil, fmt.Errorf("set network layer for TCP checksum: %w", err)
|
||||||
|
}
|
||||||
|
return []gopacket.SerializableLayer{tcp}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *PacketBuilder) buildUDPLayer(ip *layers.IPv4) ([]gopacket.SerializableLayer, error) {
|
||||||
|
udp := &layers.UDP{
|
||||||
|
SrcPort: layers.UDPPort(p.SrcPort),
|
||||||
|
DstPort: layers.UDPPort(p.DstPort),
|
||||||
|
}
|
||||||
|
if err := udp.SetNetworkLayerForChecksum(ip); err != nil {
|
||||||
|
return nil, fmt.Errorf("set network layer for UDP checksum: %w", err)
|
||||||
|
}
|
||||||
|
return []gopacket.SerializableLayer{udp}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *PacketBuilder) buildICMPLayer() ([]gopacket.SerializableLayer, error) {
|
||||||
|
icmp := &layers.ICMPv4{
|
||||||
|
TypeCode: layers.CreateICMPv4TypeCode(p.ICMPType, p.ICMPCode),
|
||||||
|
}
|
||||||
|
if p.ICMPType == layers.ICMPv4TypeEchoRequest || p.ICMPType == layers.ICMPv4TypeEchoReply {
|
||||||
|
icmp.Id = uint16(1)
|
||||||
|
icmp.Seq = uint16(1)
|
||||||
|
}
|
||||||
|
return []gopacket.SerializableLayer{icmp}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func serializePacket(layers []gopacket.SerializableLayer) ([]byte, error) {
|
||||||
|
buf := gopacket.NewSerializeBuffer()
|
||||||
|
opts := gopacket.SerializeOptions{
|
||||||
|
ComputeChecksums: true,
|
||||||
|
FixLengths: true,
|
||||||
|
}
|
||||||
|
if err := gopacket.SerializeLayers(buf, opts, layers...); err != nil {
|
||||||
|
return nil, fmt.Errorf("serialize packet: %w", err)
|
||||||
|
}
|
||||||
|
return buf.Bytes(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func getIPProtocolNumber(protocol fw.Protocol) int {
|
||||||
|
switch protocol {
|
||||||
|
case fw.ProtocolTCP:
|
||||||
|
return int(layers.IPProtocolTCP)
|
||||||
|
case fw.ProtocolUDP:
|
||||||
|
return int(layers.IPProtocolUDP)
|
||||||
|
case fw.ProtocolICMP:
|
||||||
|
return int(layers.IPProtocolICMPv4)
|
||||||
|
default:
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) TracePacketFromBuilder(builder *PacketBuilder) (*PacketTrace, error) {
|
||||||
|
packetData, err := builder.Build()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("build packet: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return m.TracePacket(packetData, builder.Direction), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) TracePacket(packetData []byte, direction fw.RuleDirection) *PacketTrace {
|
||||||
|
|
||||||
|
d := m.decoders.Get().(*decoder)
|
||||||
|
defer m.decoders.Put(d)
|
||||||
|
|
||||||
|
trace := &PacketTrace{Direction: direction}
|
||||||
|
|
||||||
|
// Initial packet decoding
|
||||||
|
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
|
||||||
|
trace.AddResult(StageReceived, fmt.Sprintf("Failed to decode packet: %v", err), false)
|
||||||
|
return trace
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extract base packet info
|
||||||
|
srcIP, dstIP := m.extractIPs(d)
|
||||||
|
trace.SourceIP = srcIP
|
||||||
|
trace.DestinationIP = dstIP
|
||||||
|
|
||||||
|
// Determine protocol and ports
|
||||||
|
switch d.decoded[1] {
|
||||||
|
case layers.LayerTypeTCP:
|
||||||
|
trace.Protocol = "TCP"
|
||||||
|
trace.SourcePort = uint16(d.tcp.SrcPort)
|
||||||
|
trace.DestinationPort = uint16(d.tcp.DstPort)
|
||||||
|
case layers.LayerTypeUDP:
|
||||||
|
trace.Protocol = "UDP"
|
||||||
|
trace.SourcePort = uint16(d.udp.SrcPort)
|
||||||
|
trace.DestinationPort = uint16(d.udp.DstPort)
|
||||||
|
case layers.LayerTypeICMPv4:
|
||||||
|
trace.Protocol = "ICMP"
|
||||||
|
}
|
||||||
|
|
||||||
|
trace.AddResult(StageReceived, fmt.Sprintf("Received %s packet: %s:%d -> %s:%d",
|
||||||
|
trace.Protocol, srcIP, trace.SourcePort, dstIP, trace.DestinationPort), true)
|
||||||
|
|
||||||
|
if direction == fw.RuleDirectionOUT {
|
||||||
|
return m.traceOutbound(packetData, trace)
|
||||||
|
}
|
||||||
|
|
||||||
|
return m.traceInbound(packetData, trace, d, srcIP, dstIP)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) traceInbound(packetData []byte, trace *PacketTrace, d *decoder, srcIP net.IP, dstIP net.IP) *PacketTrace {
|
||||||
|
if m.stateful && m.handleConntrackState(trace, d, srcIP, dstIP) {
|
||||||
|
return trace
|
||||||
|
}
|
||||||
|
|
||||||
|
if m.handleLocalDelivery(trace, packetData, d, srcIP, dstIP) {
|
||||||
|
return trace
|
||||||
|
}
|
||||||
|
|
||||||
|
if !m.handleRouting(trace) {
|
||||||
|
return trace
|
||||||
|
}
|
||||||
|
|
||||||
|
if m.nativeRouter {
|
||||||
|
return m.handleNativeRouter(trace)
|
||||||
|
}
|
||||||
|
|
||||||
|
return m.handleRouteACLs(trace, d, srcIP, dstIP)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) handleConntrackState(trace *PacketTrace, d *decoder, srcIP, dstIP net.IP) bool {
|
||||||
|
allowed := m.isValidTrackedConnection(d, srcIP, dstIP)
|
||||||
|
msg := "No existing connection found"
|
||||||
|
if allowed {
|
||||||
|
msg = m.buildConntrackStateMessage(d)
|
||||||
|
trace.AddResult(StageConntrack, msg, true)
|
||||||
|
trace.AddResult(StageCompleted, "Packet allowed by connection tracking", true)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
trace.AddResult(StageConntrack, msg, false)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) buildConntrackStateMessage(d *decoder) string {
|
||||||
|
msg := "Matched existing connection state"
|
||||||
|
switch d.decoded[1] {
|
||||||
|
case layers.LayerTypeTCP:
|
||||||
|
flags := getTCPFlags(&d.tcp)
|
||||||
|
msg += fmt.Sprintf(" (TCP Flags: SYN=%v ACK=%v RST=%v FIN=%v)",
|
||||||
|
flags&conntrack.TCPSyn != 0,
|
||||||
|
flags&conntrack.TCPAck != 0,
|
||||||
|
flags&conntrack.TCPRst != 0,
|
||||||
|
flags&conntrack.TCPFin != 0)
|
||||||
|
case layers.LayerTypeICMPv4:
|
||||||
|
msg += fmt.Sprintf(" (ICMP ID=%d, Seq=%d)", d.icmp4.Id, d.icmp4.Seq)
|
||||||
|
}
|
||||||
|
return msg
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) handleLocalDelivery(trace *PacketTrace, packetData []byte, d *decoder, srcIP, dstIP net.IP) bool {
|
||||||
|
if !m.localForwarding {
|
||||||
|
trace.AddResult(StageRouting, "Local forwarding disabled", false)
|
||||||
|
trace.AddResult(StageCompleted, "Packet dropped - local forwarding disabled", false)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
trace.AddResult(StageRouting, "Packet destined for local delivery", true)
|
||||||
|
blocked := m.peerACLsBlock(srcIP, packetData, m.incomingRules, d)
|
||||||
|
|
||||||
|
msg := "Allowed by peer ACL rules"
|
||||||
|
if blocked {
|
||||||
|
msg = "Blocked by peer ACL rules"
|
||||||
|
}
|
||||||
|
trace.AddResult(StagePeerACL, msg, !blocked)
|
||||||
|
|
||||||
|
if m.netstack {
|
||||||
|
m.addForwardingResult(trace, "proxy-local", "127.0.0.1", !blocked)
|
||||||
|
}
|
||||||
|
|
||||||
|
trace.AddResult(StageCompleted, msgProcessingCompleted, !blocked)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) handleRouting(trace *PacketTrace) bool {
|
||||||
|
if !m.routingEnabled {
|
||||||
|
trace.AddResult(StageRouting, "Routing disabled", false)
|
||||||
|
trace.AddResult(StageCompleted, "Packet dropped - routing disabled", false)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
trace.AddResult(StageRouting, "Routing enabled, checking ACLs", true)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) handleNativeRouter(trace *PacketTrace) *PacketTrace {
|
||||||
|
trace.AddResult(StageRouteACL, "Using native router, skipping ACL checks", true)
|
||||||
|
trace.AddResult(StageForwarding, "Forwarding via native router", true)
|
||||||
|
trace.AddResult(StageCompleted, msgProcessingCompleted, true)
|
||||||
|
return trace
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) handleRouteACLs(trace *PacketTrace, d *decoder, srcIP, dstIP net.IP) *PacketTrace {
|
||||||
|
proto := getProtocolFromPacket(d)
|
||||||
|
srcPort, dstPort := getPortsFromPacket(d)
|
||||||
|
allowed := m.routeACLsPass(srcIP, dstIP, proto, srcPort, dstPort)
|
||||||
|
|
||||||
|
msg := "Allowed by route ACLs"
|
||||||
|
if !allowed {
|
||||||
|
msg = "Blocked by route ACLs"
|
||||||
|
}
|
||||||
|
trace.AddResult(StageRouteACL, msg, allowed)
|
||||||
|
|
||||||
|
if allowed && m.forwarder != nil {
|
||||||
|
m.addForwardingResult(trace, "proxy-remote", fmt.Sprintf("%s:%d", dstIP, dstPort), true)
|
||||||
|
}
|
||||||
|
|
||||||
|
trace.AddResult(StageCompleted, msgProcessingCompleted, allowed)
|
||||||
|
return trace
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) addForwardingResult(trace *PacketTrace, action, remoteAddr string, allowed bool) {
|
||||||
|
fwdAction := &ForwarderAction{
|
||||||
|
Action: action,
|
||||||
|
RemoteAddr: remoteAddr,
|
||||||
|
}
|
||||||
|
trace.AddResultWithForwarder(StageForwarding,
|
||||||
|
fmt.Sprintf("Forwarding to %s", fwdAction.Action), allowed, fwdAction)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) traceOutbound(packetData []byte, trace *PacketTrace) *PacketTrace {
|
||||||
|
// will create or update the connection state
|
||||||
|
dropped := m.processOutgoingHooks(packetData)
|
||||||
|
if dropped {
|
||||||
|
trace.AddResult(StageCompleted, "Packet dropped by outgoing hook", false)
|
||||||
|
} else {
|
||||||
|
trace.AddResult(StageCompleted, "Packet allowed (outgoing)", true)
|
||||||
|
}
|
||||||
|
return trace
|
||||||
|
}
|
||||||
@@ -1,11 +1,14 @@
|
|||||||
package uspfilter
|
package uspfilter
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"os"
|
"os"
|
||||||
|
"slices"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
"github.com/google/gopacket"
|
"github.com/google/gopacket"
|
||||||
@@ -14,28 +17,48 @@ import (
|
|||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
|
"github.com/netbirdio/netbird/client/firewall/uspfilter/common"
|
||||||
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
|
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
|
||||||
"github.com/netbirdio/netbird/client/iface"
|
"github.com/netbirdio/netbird/client/firewall/uspfilter/forwarder"
|
||||||
"github.com/netbirdio/netbird/client/iface/device"
|
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
)
|
)
|
||||||
|
|
||||||
const layerTypeAll = 0
|
const layerTypeAll = 0
|
||||||
|
|
||||||
const EnvDisableConntrack = "NB_DISABLE_CONNTRACK"
|
const (
|
||||||
|
// EnvDisableConntrack disables the stateful filter, replies to outbound traffic won't be allowed.
|
||||||
|
EnvDisableConntrack = "NB_DISABLE_CONNTRACK"
|
||||||
|
|
||||||
var (
|
// EnvDisableUserspaceRouting disables userspace routing, to-be-routed packets will be dropped.
|
||||||
errRouteNotSupported = fmt.Errorf("route not supported with userspace firewall")
|
EnvDisableUserspaceRouting = "NB_DISABLE_USERSPACE_ROUTING"
|
||||||
|
|
||||||
|
// EnvForceUserspaceRouter forces userspace routing even if native routing is available.
|
||||||
|
EnvForceUserspaceRouter = "NB_FORCE_USERSPACE_ROUTER"
|
||||||
|
|
||||||
|
// EnvEnableNetstackLocalForwarding enables forwarding of local traffic to the native stack when running netstack
|
||||||
|
// Leaving this on by default introduces a security risk as sockets on listening on localhost only will be accessible
|
||||||
|
EnvEnableNetstackLocalForwarding = "NB_ENABLE_NETSTACK_LOCAL_FORWARDING"
|
||||||
)
|
)
|
||||||
|
|
||||||
// IFaceMapper defines subset methods of interface required for manager
|
|
||||||
type IFaceMapper interface {
|
|
||||||
SetFilter(device.PacketFilter) error
|
|
||||||
Address() iface.WGAddress
|
|
||||||
}
|
|
||||||
|
|
||||||
// RuleSet is a set of rules grouped by a string key
|
// RuleSet is a set of rules grouped by a string key
|
||||||
type RuleSet map[string]Rule
|
type RuleSet map[string]PeerRule
|
||||||
|
|
||||||
|
type RouteRules []RouteRule
|
||||||
|
|
||||||
|
func (r RouteRules) Sort() {
|
||||||
|
slices.SortStableFunc(r, func(a, b RouteRule) int {
|
||||||
|
// Deny rules come first
|
||||||
|
if a.action == firewall.ActionDrop && b.action != firewall.ActionDrop {
|
||||||
|
return -1
|
||||||
|
}
|
||||||
|
if a.action != firewall.ActionDrop && b.action == firewall.ActionDrop {
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
return strings.Compare(a.id, b.id)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
// Manager userspace firewall manager
|
// Manager userspace firewall manager
|
||||||
type Manager struct {
|
type Manager struct {
|
||||||
@@ -43,17 +66,34 @@ type Manager struct {
|
|||||||
outgoingRules map[string]RuleSet
|
outgoingRules map[string]RuleSet
|
||||||
// incomingRules is used for filtering and hooks
|
// incomingRules is used for filtering and hooks
|
||||||
incomingRules map[string]RuleSet
|
incomingRules map[string]RuleSet
|
||||||
|
routeRules RouteRules
|
||||||
wgNetwork *net.IPNet
|
wgNetwork *net.IPNet
|
||||||
decoders sync.Pool
|
decoders sync.Pool
|
||||||
wgIface IFaceMapper
|
wgIface common.IFaceMapper
|
||||||
nativeFirewall firewall.Manager
|
nativeFirewall firewall.Manager
|
||||||
|
|
||||||
mutex sync.RWMutex
|
mutex sync.RWMutex
|
||||||
|
|
||||||
stateful bool
|
// indicates whether server routes are disabled
|
||||||
|
disableServerRoutes bool
|
||||||
|
// indicates whether we forward packets not destined for ourselves
|
||||||
|
routingEnabled bool
|
||||||
|
// indicates whether we leave forwarding and filtering to the native firewall
|
||||||
|
nativeRouter bool
|
||||||
|
// indicates whether we track outbound connections
|
||||||
|
stateful bool
|
||||||
|
// indicates whether wireguards runs in netstack mode
|
||||||
|
netstack bool
|
||||||
|
// indicates whether we forward local traffic to the native stack
|
||||||
|
localForwarding bool
|
||||||
|
|
||||||
|
localipmanager *localIPManager
|
||||||
|
|
||||||
udpTracker *conntrack.UDPTracker
|
udpTracker *conntrack.UDPTracker
|
||||||
icmpTracker *conntrack.ICMPTracker
|
icmpTracker *conntrack.ICMPTracker
|
||||||
tcpTracker *conntrack.TCPTracker
|
tcpTracker *conntrack.TCPTracker
|
||||||
|
forwarder *forwarder.Forwarder
|
||||||
|
logger *nblog.Logger
|
||||||
}
|
}
|
||||||
|
|
||||||
// decoder for packages
|
// decoder for packages
|
||||||
@@ -70,22 +110,44 @@ type decoder struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Create userspace firewall manager constructor
|
// Create userspace firewall manager constructor
|
||||||
func Create(iface IFaceMapper) (*Manager, error) {
|
func Create(iface common.IFaceMapper, disableServerRoutes bool) (*Manager, error) {
|
||||||
return create(iface)
|
return create(iface, nil, disableServerRoutes)
|
||||||
}
|
}
|
||||||
|
|
||||||
func CreateWithNativeFirewall(iface IFaceMapper, nativeFirewall firewall.Manager) (*Manager, error) {
|
func CreateWithNativeFirewall(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableServerRoutes bool) (*Manager, error) {
|
||||||
mgr, err := create(iface)
|
if nativeFirewall == nil {
|
||||||
|
return nil, errors.New("native firewall is nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
mgr, err := create(iface, nativeFirewall, disableServerRoutes)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
mgr.nativeFirewall = nativeFirewall
|
|
||||||
return mgr, nil
|
return mgr, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func create(iface IFaceMapper) (*Manager, error) {
|
func parseCreateEnv() (bool, bool) {
|
||||||
disableConntrack, _ := strconv.ParseBool(os.Getenv(EnvDisableConntrack))
|
var disableConntrack, enableLocalForwarding bool
|
||||||
|
var err error
|
||||||
|
if val := os.Getenv(EnvDisableConntrack); val != "" {
|
||||||
|
disableConntrack, err = strconv.ParseBool(val)
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("failed to parse %s: %v", EnvDisableConntrack, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if val := os.Getenv(EnvEnableNetstackLocalForwarding); val != "" {
|
||||||
|
enableLocalForwarding, err = strconv.ParseBool(val)
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("failed to parse %s: %v", EnvEnableNetstackLocalForwarding, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return disableConntrack, enableLocalForwarding
|
||||||
|
}
|
||||||
|
|
||||||
|
func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableServerRoutes bool) (*Manager, error) {
|
||||||
|
disableConntrack, enableLocalForwarding := parseCreateEnv()
|
||||||
|
|
||||||
m := &Manager{
|
m := &Manager{
|
||||||
decoders: sync.Pool{
|
decoders: sync.Pool{
|
||||||
@@ -101,52 +163,183 @@ func create(iface IFaceMapper) (*Manager, error) {
|
|||||||
return d
|
return d
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
outgoingRules: make(map[string]RuleSet),
|
nativeFirewall: nativeFirewall,
|
||||||
incomingRules: make(map[string]RuleSet),
|
outgoingRules: make(map[string]RuleSet),
|
||||||
wgIface: iface,
|
incomingRules: make(map[string]RuleSet),
|
||||||
stateful: !disableConntrack,
|
wgIface: iface,
|
||||||
|
localipmanager: newLocalIPManager(),
|
||||||
|
disableServerRoutes: disableServerRoutes,
|
||||||
|
routingEnabled: false,
|
||||||
|
stateful: !disableConntrack,
|
||||||
|
logger: nblog.NewFromLogrus(log.StandardLogger()),
|
||||||
|
netstack: netstack.IsEnabled(),
|
||||||
|
// default true for non-netstack, for netstack only if explicitly enabled
|
||||||
|
localForwarding: !netstack.IsEnabled() || enableLocalForwarding,
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := m.localipmanager.UpdateLocalIPs(iface); err != nil {
|
||||||
|
return nil, fmt.Errorf("update local IPs: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Only initialize trackers if stateful mode is enabled
|
|
||||||
if disableConntrack {
|
if disableConntrack {
|
||||||
log.Info("conntrack is disabled")
|
log.Info("conntrack is disabled")
|
||||||
} else {
|
} else {
|
||||||
m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout)
|
m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout, m.logger)
|
||||||
m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout)
|
m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout, m.logger)
|
||||||
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout)
|
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout, m.logger)
|
||||||
|
}
|
||||||
|
|
||||||
|
// netstack needs the forwarder for local traffic
|
||||||
|
if m.netstack && m.localForwarding {
|
||||||
|
if err := m.initForwarder(); err != nil {
|
||||||
|
log.Errorf("failed to initialize forwarder: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := m.blockInvalidRouted(iface); err != nil {
|
||||||
|
log.Errorf("failed to block invalid routed traffic: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := iface.SetFilter(m); err != nil {
|
if err := iface.SetFilter(m); err != nil {
|
||||||
return nil, err
|
return nil, fmt.Errorf("set filter: %w", err)
|
||||||
}
|
}
|
||||||
return m, nil
|
return m, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *Manager) blockInvalidRouted(iface common.IFaceMapper) error {
|
||||||
|
if m.forwarder == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
wgPrefix, err := netip.ParsePrefix(iface.Address().Network.String())
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("parse wireguard network: %w", err)
|
||||||
|
}
|
||||||
|
log.Debugf("blocking invalid routed traffic for %s", wgPrefix)
|
||||||
|
|
||||||
|
if _, err := m.AddRouteFiltering(
|
||||||
|
[]netip.Prefix{netip.PrefixFrom(netip.IPv4Unspecified(), 0)},
|
||||||
|
wgPrefix,
|
||||||
|
firewall.ProtocolALL,
|
||||||
|
nil,
|
||||||
|
nil,
|
||||||
|
firewall.ActionDrop,
|
||||||
|
); err != nil {
|
||||||
|
return fmt.Errorf("block wg nte : %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: Block networks that we're a client of
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) determineRouting() error {
|
||||||
|
var disableUspRouting, forceUserspaceRouter bool
|
||||||
|
var err error
|
||||||
|
if val := os.Getenv(EnvDisableUserspaceRouting); val != "" {
|
||||||
|
disableUspRouting, err = strconv.ParseBool(val)
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("failed to parse %s: %v", EnvDisableUserspaceRouting, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if val := os.Getenv(EnvForceUserspaceRouter); val != "" {
|
||||||
|
forceUserspaceRouter, err = strconv.ParseBool(val)
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("failed to parse %s: %v", EnvForceUserspaceRouter, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
switch {
|
||||||
|
case disableUspRouting:
|
||||||
|
m.routingEnabled = false
|
||||||
|
m.nativeRouter = false
|
||||||
|
log.Info("userspace routing is disabled")
|
||||||
|
|
||||||
|
case m.disableServerRoutes:
|
||||||
|
// if server routes are disabled we will let packets pass to the native stack
|
||||||
|
m.routingEnabled = true
|
||||||
|
m.nativeRouter = true
|
||||||
|
|
||||||
|
log.Info("server routes are disabled")
|
||||||
|
|
||||||
|
case forceUserspaceRouter:
|
||||||
|
m.routingEnabled = true
|
||||||
|
m.nativeRouter = false
|
||||||
|
|
||||||
|
log.Info("userspace routing is forced")
|
||||||
|
|
||||||
|
case !m.netstack && m.nativeFirewall != nil && m.nativeFirewall.IsServerRouteSupported():
|
||||||
|
// if the OS supports routing natively, then we don't need to filter/route ourselves
|
||||||
|
// netstack mode won't support native routing as there is no interface
|
||||||
|
|
||||||
|
m.routingEnabled = true
|
||||||
|
m.nativeRouter = true
|
||||||
|
|
||||||
|
log.Info("native routing is enabled")
|
||||||
|
|
||||||
|
default:
|
||||||
|
m.routingEnabled = true
|
||||||
|
m.nativeRouter = false
|
||||||
|
|
||||||
|
log.Info("userspace routing enabled by default")
|
||||||
|
}
|
||||||
|
|
||||||
|
if m.routingEnabled && !m.nativeRouter {
|
||||||
|
return m.initForwarder()
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// initForwarder initializes the forwarder, it disables routing on errors
|
||||||
|
func (m *Manager) initForwarder() error {
|
||||||
|
if m.forwarder != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Only supported in userspace mode as we need to inject packets back into wireguard directly
|
||||||
|
intf := m.wgIface.GetWGDevice()
|
||||||
|
if intf == nil {
|
||||||
|
m.routingEnabled = false
|
||||||
|
return errors.New("forwarding not supported")
|
||||||
|
}
|
||||||
|
|
||||||
|
forwarder, err := forwarder.New(m.wgIface, m.logger, m.netstack)
|
||||||
|
if err != nil {
|
||||||
|
m.routingEnabled = false
|
||||||
|
return fmt.Errorf("create forwarder: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
m.forwarder = forwarder
|
||||||
|
|
||||||
|
log.Debug("forwarder initialized")
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (m *Manager) Init(*statemanager.Manager) error {
|
func (m *Manager) Init(*statemanager.Manager) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) IsServerRouteSupported() bool {
|
func (m *Manager) IsServerRouteSupported() bool {
|
||||||
if m.nativeFirewall == nil {
|
return true
|
||||||
return false
|
|
||||||
} else {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
|
func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
|
||||||
if m.nativeFirewall == nil {
|
if m.nativeRouter && m.nativeFirewall != nil {
|
||||||
return errRouteNotSupported
|
return m.nativeFirewall.AddNatRule(pair)
|
||||||
}
|
}
|
||||||
return m.nativeFirewall.AddNatRule(pair)
|
|
||||||
|
// userspace routed packets are always SNATed to the inbound direction
|
||||||
|
// TODO: implement outbound SNAT
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// RemoveNatRule removes a routing firewall rule
|
// RemoveNatRule removes a routing firewall rule
|
||||||
func (m *Manager) RemoveNatRule(pair firewall.RouterPair) error {
|
func (m *Manager) RemoveNatRule(pair firewall.RouterPair) error {
|
||||||
if m.nativeFirewall == nil {
|
if m.nativeRouter && m.nativeFirewall != nil {
|
||||||
return errRouteNotSupported
|
return m.nativeFirewall.RemoveNatRule(pair)
|
||||||
}
|
}
|
||||||
return m.nativeFirewall.RemoveNatRule(pair)
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// AddPeerFiltering rule to the firewall
|
// AddPeerFiltering rule to the firewall
|
||||||
@@ -162,7 +355,7 @@ func (m *Manager) AddPeerFiltering(
|
|||||||
_ string,
|
_ string,
|
||||||
comment string,
|
comment string,
|
||||||
) ([]firewall.Rule, error) {
|
) ([]firewall.Rule, error) {
|
||||||
r := Rule{
|
r := PeerRule{
|
||||||
id: uuid.New().String(),
|
id: uuid.New().String(),
|
||||||
ip: ip,
|
ip: ip,
|
||||||
ipLayer: layers.LayerTypeIPv6,
|
ipLayer: layers.LayerTypeIPv6,
|
||||||
@@ -205,18 +398,56 @@ func (m *Manager) AddPeerFiltering(
|
|||||||
return []firewall.Rule{&r}, nil
|
return []firewall.Rule{&r}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) AddRouteFiltering(sources []netip.Prefix, destination netip.Prefix, proto firewall.Protocol, sPort *firewall.Port, dPort *firewall.Port, action firewall.Action) (firewall.Rule, error) {
|
func (m *Manager) AddRouteFiltering(
|
||||||
if m.nativeFirewall == nil {
|
sources []netip.Prefix,
|
||||||
return nil, errRouteNotSupported
|
destination netip.Prefix,
|
||||||
|
proto firewall.Protocol,
|
||||||
|
sPort *firewall.Port,
|
||||||
|
dPort *firewall.Port,
|
||||||
|
action firewall.Action,
|
||||||
|
) (firewall.Rule, error) {
|
||||||
|
if m.nativeRouter && m.nativeFirewall != nil {
|
||||||
|
return m.nativeFirewall.AddRouteFiltering(sources, destination, proto, sPort, dPort, action)
|
||||||
}
|
}
|
||||||
return m.nativeFirewall.AddRouteFiltering(sources, destination, proto, sPort, dPort, action)
|
|
||||||
|
m.mutex.Lock()
|
||||||
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
|
ruleID := uuid.New().String()
|
||||||
|
rule := RouteRule{
|
||||||
|
id: ruleID,
|
||||||
|
sources: sources,
|
||||||
|
destination: destination,
|
||||||
|
proto: proto,
|
||||||
|
srcPort: sPort,
|
||||||
|
dstPort: dPort,
|
||||||
|
action: action,
|
||||||
|
}
|
||||||
|
|
||||||
|
m.routeRules = append(m.routeRules, rule)
|
||||||
|
m.routeRules.Sort()
|
||||||
|
|
||||||
|
return &rule, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) DeleteRouteRule(rule firewall.Rule) error {
|
func (m *Manager) DeleteRouteRule(rule firewall.Rule) error {
|
||||||
if m.nativeFirewall == nil {
|
if m.nativeRouter && m.nativeFirewall != nil {
|
||||||
return errRouteNotSupported
|
return m.nativeFirewall.DeleteRouteRule(rule)
|
||||||
}
|
}
|
||||||
return m.nativeFirewall.DeleteRouteRule(rule)
|
|
||||||
|
m.mutex.Lock()
|
||||||
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
|
ruleID := rule.GetRuleID()
|
||||||
|
idx := slices.IndexFunc(m.routeRules, func(r RouteRule) bool {
|
||||||
|
return r.id == ruleID
|
||||||
|
})
|
||||||
|
if idx < 0 {
|
||||||
|
return fmt.Errorf("route rule not found: %s", ruleID)
|
||||||
|
}
|
||||||
|
|
||||||
|
m.routeRules = slices.Delete(m.routeRules, idx, idx+1)
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeletePeerRule from the firewall by rule definition
|
// DeletePeerRule from the firewall by rule definition
|
||||||
@@ -224,7 +455,7 @@ func (m *Manager) DeletePeerRule(rule firewall.Rule) error {
|
|||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
r, ok := rule.(*Rule)
|
r, ok := rule.(*PeerRule)
|
||||||
if !ok {
|
if !ok {
|
||||||
return fmt.Errorf("delete rule: invalid rule type: %T", rule)
|
return fmt.Errorf("delete rule: invalid rule type: %T", rule)
|
||||||
}
|
}
|
||||||
@@ -255,10 +486,14 @@ func (m *Manager) DropOutgoing(packetData []byte) bool {
|
|||||||
|
|
||||||
// DropIncoming filter incoming packets
|
// DropIncoming filter incoming packets
|
||||||
func (m *Manager) DropIncoming(packetData []byte) bool {
|
func (m *Manager) DropIncoming(packetData []byte) bool {
|
||||||
return m.dropFilter(packetData, m.incomingRules)
|
return m.dropFilter(packetData)
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateLocalIPs updates the list of local IPs
|
||||||
|
func (m *Manager) UpdateLocalIPs() error {
|
||||||
|
return m.localipmanager.UpdateLocalIPs(m.wgIface)
|
||||||
}
|
}
|
||||||
|
|
||||||
// processOutgoingHooks processes UDP hooks for outgoing packets and tracks TCP/UDP/ICMP
|
|
||||||
func (m *Manager) processOutgoingHooks(packetData []byte) bool {
|
func (m *Manager) processOutgoingHooks(packetData []byte) bool {
|
||||||
m.mutex.RLock()
|
m.mutex.RLock()
|
||||||
defer m.mutex.RUnlock()
|
defer m.mutex.RUnlock()
|
||||||
@@ -279,18 +514,11 @@ func (m *Manager) processOutgoingHooks(packetData []byte) bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
// Always process UDP hooks
|
// Track all protocols if stateful mode is enabled
|
||||||
if d.decoded[1] == layers.LayerTypeUDP {
|
|
||||||
// Track UDP state only if enabled
|
|
||||||
if m.stateful {
|
|
||||||
m.trackUDPOutbound(d, srcIP, dstIP)
|
|
||||||
}
|
|
||||||
return m.checkUDPHooks(d, dstIP, packetData)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Track other protocols only if stateful mode is enabled
|
|
||||||
if m.stateful {
|
if m.stateful {
|
||||||
switch d.decoded[1] {
|
switch d.decoded[1] {
|
||||||
|
case layers.LayerTypeUDP:
|
||||||
|
m.trackUDPOutbound(d, srcIP, dstIP)
|
||||||
case layers.LayerTypeTCP:
|
case layers.LayerTypeTCP:
|
||||||
m.trackTCPOutbound(d, srcIP, dstIP)
|
m.trackTCPOutbound(d, srcIP, dstIP)
|
||||||
case layers.LayerTypeICMPv4:
|
case layers.LayerTypeICMPv4:
|
||||||
@@ -298,6 +526,11 @@ func (m *Manager) processOutgoingHooks(packetData []byte) bool {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Process UDP hooks even if stateful mode is disabled
|
||||||
|
if d.decoded[1] == layers.LayerTypeUDP {
|
||||||
|
return m.checkUDPHooks(d, dstIP, packetData)
|
||||||
|
}
|
||||||
|
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -379,10 +612,9 @@ func (m *Manager) trackICMPOutbound(d *decoder, srcIP, dstIP net.IP) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// dropFilter implements filtering logic for incoming packets
|
// dropFilter implements filtering logic for incoming packets.
|
||||||
func (m *Manager) dropFilter(packetData []byte, rules map[string]RuleSet) bool {
|
// If it returns true, the packet should be dropped.
|
||||||
// TODO: Disable router if --disable-server-router is set
|
func (m *Manager) dropFilter(packetData []byte) bool {
|
||||||
|
|
||||||
m.mutex.RLock()
|
m.mutex.RLock()
|
||||||
defer m.mutex.RUnlock()
|
defer m.mutex.RUnlock()
|
||||||
|
|
||||||
@@ -395,39 +627,127 @@ func (m *Manager) dropFilter(packetData []byte, rules map[string]RuleSet) bool {
|
|||||||
|
|
||||||
srcIP, dstIP := m.extractIPs(d)
|
srcIP, dstIP := m.extractIPs(d)
|
||||||
if srcIP == nil {
|
if srcIP == nil {
|
||||||
log.Errorf("unknown layer: %v", d.decoded[0])
|
m.logger.Error("Unknown network layer: %v", d.decoded[0])
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
if !m.isWireguardTraffic(srcIP, dstIP) {
|
// For all inbound traffic, first check if it matches a tracked connection.
|
||||||
return false
|
// This must happen before any other filtering because the packets are statefully tracked.
|
||||||
}
|
|
||||||
|
|
||||||
// Check connection state only if enabled
|
|
||||||
if m.stateful && m.isValidTrackedConnection(d, srcIP, dstIP) {
|
if m.stateful && m.isValidTrackedConnection(d, srcIP, dstIP) {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
return m.applyRules(srcIP, packetData, rules, d)
|
if m.localipmanager.IsLocalIP(dstIP) {
|
||||||
|
return m.handleLocalTraffic(d, srcIP, dstIP, packetData)
|
||||||
|
}
|
||||||
|
|
||||||
|
return m.handleRoutedTraffic(d, srcIP, dstIP, packetData)
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleLocalTraffic handles local traffic.
|
||||||
|
// If it returns true, the packet should be dropped.
|
||||||
|
func (m *Manager) handleLocalTraffic(d *decoder, srcIP, dstIP net.IP, packetData []byte) bool {
|
||||||
|
if !m.localForwarding {
|
||||||
|
m.logger.Trace("Dropping local packet (local forwarding disabled): src=%s dst=%s", srcIP, dstIP)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
if m.peerACLsBlock(srcIP, packetData, m.incomingRules, d) {
|
||||||
|
m.logger.Trace("Dropping local packet (ACL denied): src=%s dst=%s",
|
||||||
|
srcIP, dstIP)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// if running in netstack mode we need to pass this to the forwarder
|
||||||
|
if m.netstack {
|
||||||
|
m.handleNetstackLocalTraffic(packetData)
|
||||||
|
|
||||||
|
// don't process this packet further
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
func (m *Manager) handleNetstackLocalTraffic(packetData []byte) {
|
||||||
|
if m.forwarder == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := m.forwarder.InjectIncomingPacket(packetData); err != nil {
|
||||||
|
m.logger.Error("Failed to inject local packet: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleRoutedTraffic handles routed traffic.
|
||||||
|
// If it returns true, the packet should be dropped.
|
||||||
|
func (m *Manager) handleRoutedTraffic(d *decoder, srcIP, dstIP net.IP, packetData []byte) bool {
|
||||||
|
// Drop if routing is disabled
|
||||||
|
if !m.routingEnabled {
|
||||||
|
m.logger.Trace("Dropping routed packet (routing disabled): src=%s dst=%s",
|
||||||
|
srcIP, dstIP)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Pass to native stack if native router is enabled or forced
|
||||||
|
if m.nativeRouter {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
proto := getProtocolFromPacket(d)
|
||||||
|
srcPort, dstPort := getPortsFromPacket(d)
|
||||||
|
|
||||||
|
if !m.routeACLsPass(srcIP, dstIP, proto, srcPort, dstPort) {
|
||||||
|
m.logger.Trace("Dropping routed packet (ACL denied): src=%s:%d dst=%s:%d proto=%v",
|
||||||
|
srcIP, srcPort, dstIP, dstPort, proto)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Let forwarder handle the packet if it passed route ACLs
|
||||||
|
if err := m.forwarder.InjectIncomingPacket(packetData); err != nil {
|
||||||
|
m.logger.Error("Failed to inject incoming packet: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Forwarded packets shouldn't reach the native stack, hence they won't be visible in a packet capture
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func getProtocolFromPacket(d *decoder) firewall.Protocol {
|
||||||
|
switch d.decoded[1] {
|
||||||
|
case layers.LayerTypeTCP:
|
||||||
|
return firewall.ProtocolTCP
|
||||||
|
case layers.LayerTypeUDP:
|
||||||
|
return firewall.ProtocolUDP
|
||||||
|
case layers.LayerTypeICMPv4, layers.LayerTypeICMPv6:
|
||||||
|
return firewall.ProtocolICMP
|
||||||
|
default:
|
||||||
|
return firewall.ProtocolALL
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func getPortsFromPacket(d *decoder) (srcPort, dstPort uint16) {
|
||||||
|
switch d.decoded[1] {
|
||||||
|
case layers.LayerTypeTCP:
|
||||||
|
return uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort)
|
||||||
|
case layers.LayerTypeUDP:
|
||||||
|
return uint16(d.udp.SrcPort), uint16(d.udp.DstPort)
|
||||||
|
default:
|
||||||
|
return 0, 0
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) isValidPacket(d *decoder, packetData []byte) bool {
|
func (m *Manager) isValidPacket(d *decoder, packetData []byte) bool {
|
||||||
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
|
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
|
||||||
log.Tracef("couldn't decode layer, err: %s", err)
|
m.logger.Trace("couldn't decode packet, err: %s", err)
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(d.decoded) < 2 {
|
if len(d.decoded) < 2 {
|
||||||
log.Tracef("not enough levels in network packet")
|
m.logger.Trace("packet doesn't have network and transport layers")
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) isWireguardTraffic(srcIP, dstIP net.IP) bool {
|
|
||||||
return m.wgNetwork.Contains(srcIP) && m.wgNetwork.Contains(dstIP)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *Manager) isValidTrackedConnection(d *decoder, srcIP, dstIP net.IP) bool {
|
func (m *Manager) isValidTrackedConnection(d *decoder, srcIP, dstIP net.IP) bool {
|
||||||
switch d.decoded[1] {
|
switch d.decoded[1] {
|
||||||
case layers.LayerTypeTCP:
|
case layers.LayerTypeTCP:
|
||||||
@@ -462,7 +782,22 @@ func (m *Manager) isValidTrackedConnection(d *decoder, srcIP, dstIP net.IP) bool
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) applyRules(srcIP net.IP, packetData []byte, rules map[string]RuleSet, d *decoder) bool {
|
// isSpecialICMP returns true if the packet is a special ICMP packet that should be allowed
|
||||||
|
func (m *Manager) isSpecialICMP(d *decoder) bool {
|
||||||
|
if d.decoded[1] != layers.LayerTypeICMPv4 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
icmpType := d.icmp4.TypeCode.Type()
|
||||||
|
return icmpType == layers.ICMPv4TypeDestinationUnreachable ||
|
||||||
|
icmpType == layers.ICMPv4TypeTimeExceeded
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) peerACLsBlock(srcIP net.IP, packetData []byte, rules map[string]RuleSet, d *decoder) bool {
|
||||||
|
if m.isSpecialICMP(d) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
if filter, ok := validateRule(srcIP, packetData, rules[srcIP.String()], d); ok {
|
if filter, ok := validateRule(srcIP, packetData, rules[srcIP.String()], d); ok {
|
||||||
return filter
|
return filter
|
||||||
}
|
}
|
||||||
@@ -496,7 +831,7 @@ func portsMatch(rulePort *firewall.Port, packetPort uint16) bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func validateRule(ip net.IP, packetData []byte, rules map[string]Rule, d *decoder) (bool, bool) {
|
func validateRule(ip net.IP, packetData []byte, rules map[string]PeerRule, d *decoder) (bool, bool) {
|
||||||
payloadLayer := d.decoded[1]
|
payloadLayer := d.decoded[1]
|
||||||
for _, rule := range rules {
|
for _, rule := range rules {
|
||||||
if rule.matchByIP && !ip.Equal(rule.ip) {
|
if rule.matchByIP && !ip.Equal(rule.ip) {
|
||||||
@@ -533,6 +868,51 @@ func validateRule(ip net.IP, packetData []byte, rules map[string]Rule, d *decode
|
|||||||
return false, false
|
return false, false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// routeACLsPass returns treu if the packet is allowed by the route ACLs
|
||||||
|
func (m *Manager) routeACLsPass(srcIP, dstIP net.IP, proto firewall.Protocol, srcPort, dstPort uint16) bool {
|
||||||
|
m.mutex.RLock()
|
||||||
|
defer m.mutex.RUnlock()
|
||||||
|
|
||||||
|
srcAddr := netip.AddrFrom4([4]byte(srcIP.To4()))
|
||||||
|
dstAddr := netip.AddrFrom4([4]byte(dstIP.To4()))
|
||||||
|
|
||||||
|
for _, rule := range m.routeRules {
|
||||||
|
if m.ruleMatches(rule, srcAddr, dstAddr, proto, srcPort, dstPort) {
|
||||||
|
return rule.action == firewall.ActionAccept
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) ruleMatches(rule RouteRule, srcAddr, dstAddr netip.Addr, proto firewall.Protocol, srcPort, dstPort uint16) bool {
|
||||||
|
if !rule.destination.Contains(dstAddr) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
sourceMatched := false
|
||||||
|
for _, src := range rule.sources {
|
||||||
|
if src.Contains(srcAddr) {
|
||||||
|
sourceMatched = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !sourceMatched {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if rule.proto != firewall.ProtocolALL && rule.proto != proto {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if proto == firewall.ProtocolTCP || proto == firewall.ProtocolUDP {
|
||||||
|
if !portsMatch(rule.srcPort, srcPort) || !portsMatch(rule.dstPort, dstPort) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
// SetNetwork of the wireguard interface to which filtering applied
|
// SetNetwork of the wireguard interface to which filtering applied
|
||||||
func (m *Manager) SetNetwork(network *net.IPNet) {
|
func (m *Manager) SetNetwork(network *net.IPNet) {
|
||||||
m.wgNetwork = network
|
m.wgNetwork = network
|
||||||
@@ -544,7 +924,7 @@ func (m *Manager) SetNetwork(network *net.IPNet) {
|
|||||||
func (m *Manager) AddUDPPacketHook(
|
func (m *Manager) AddUDPPacketHook(
|
||||||
in bool, ip net.IP, dPort uint16, hook func([]byte) bool,
|
in bool, ip net.IP, dPort uint16, hook func([]byte) bool,
|
||||||
) string {
|
) string {
|
||||||
r := Rule{
|
r := PeerRule{
|
||||||
id: uuid.New().String(),
|
id: uuid.New().String(),
|
||||||
ip: ip,
|
ip: ip,
|
||||||
protoLayer: layers.LayerTypeUDP,
|
protoLayer: layers.LayerTypeUDP,
|
||||||
@@ -561,12 +941,12 @@ func (m *Manager) AddUDPPacketHook(
|
|||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
if in {
|
if in {
|
||||||
if _, ok := m.incomingRules[r.ip.String()]; !ok {
|
if _, ok := m.incomingRules[r.ip.String()]; !ok {
|
||||||
m.incomingRules[r.ip.String()] = make(map[string]Rule)
|
m.incomingRules[r.ip.String()] = make(map[string]PeerRule)
|
||||||
}
|
}
|
||||||
m.incomingRules[r.ip.String()][r.id] = r
|
m.incomingRules[r.ip.String()][r.id] = r
|
||||||
} else {
|
} else {
|
||||||
if _, ok := m.outgoingRules[r.ip.String()]; !ok {
|
if _, ok := m.outgoingRules[r.ip.String()]; !ok {
|
||||||
m.outgoingRules[r.ip.String()] = make(map[string]Rule)
|
m.outgoingRules[r.ip.String()] = make(map[string]PeerRule)
|
||||||
}
|
}
|
||||||
m.outgoingRules[r.ip.String()][r.id] = r
|
m.outgoingRules[r.ip.String()][r.id] = r
|
||||||
}
|
}
|
||||||
@@ -599,3 +979,41 @@ func (m *Manager) RemovePacketHook(hookID string) error {
|
|||||||
}
|
}
|
||||||
return fmt.Errorf("hook with given id not found")
|
return fmt.Errorf("hook with given id not found")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetLogLevel sets the log level for the firewall manager
|
||||||
|
func (m *Manager) SetLogLevel(level log.Level) {
|
||||||
|
if m.logger != nil {
|
||||||
|
m.logger.SetLevel(nblog.Level(level))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) EnableRouting() error {
|
||||||
|
m.mutex.Lock()
|
||||||
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
|
return m.determineRouting()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) DisableRouting() error {
|
||||||
|
m.mutex.Lock()
|
||||||
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
|
if m.forwarder == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
m.routingEnabled = false
|
||||||
|
m.nativeRouter = false
|
||||||
|
|
||||||
|
// don't stop forwarder if in use by netstack
|
||||||
|
if m.netstack && m.localForwarding {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
m.forwarder.Stop()
|
||||||
|
m.forwarder = nil
|
||||||
|
|
||||||
|
log.Debug("forwarder stopped")
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,9 +1,12 @@
|
|||||||
|
//go:build uspbench
|
||||||
|
|
||||||
package uspfilter
|
package uspfilter
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"math/rand"
|
"math/rand"
|
||||||
"net"
|
"net"
|
||||||
|
"net/netip"
|
||||||
"os"
|
"os"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
@@ -155,7 +158,7 @@ func BenchmarkCoreFiltering(b *testing.B) {
|
|||||||
// Create manager and basic setup
|
// Create manager and basic setup
|
||||||
manager, _ := Create(&IFaceMock{
|
manager, _ := Create(&IFaceMock{
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
})
|
}, false)
|
||||||
defer b.Cleanup(func() {
|
defer b.Cleanup(func() {
|
||||||
require.NoError(b, manager.Reset(nil))
|
require.NoError(b, manager.Reset(nil))
|
||||||
})
|
})
|
||||||
@@ -185,7 +188,7 @@ func BenchmarkCoreFiltering(b *testing.B) {
|
|||||||
// Measure inbound packet processing
|
// Measure inbound packet processing
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
manager.dropFilter(inbound, manager.incomingRules)
|
manager.dropFilter(inbound)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -200,7 +203,7 @@ func BenchmarkStateScaling(b *testing.B) {
|
|||||||
b.Run(fmt.Sprintf("conns_%d", count), func(b *testing.B) {
|
b.Run(fmt.Sprintf("conns_%d", count), func(b *testing.B) {
|
||||||
manager, _ := Create(&IFaceMock{
|
manager, _ := Create(&IFaceMock{
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
})
|
}, false)
|
||||||
b.Cleanup(func() {
|
b.Cleanup(func() {
|
||||||
require.NoError(b, manager.Reset(nil))
|
require.NoError(b, manager.Reset(nil))
|
||||||
})
|
})
|
||||||
@@ -228,7 +231,7 @@ func BenchmarkStateScaling(b *testing.B) {
|
|||||||
|
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
manager.dropFilter(testIn, manager.incomingRules)
|
manager.dropFilter(testIn)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -248,7 +251,7 @@ func BenchmarkEstablishmentOverhead(b *testing.B) {
|
|||||||
b.Run(sc.name, func(b *testing.B) {
|
b.Run(sc.name, func(b *testing.B) {
|
||||||
manager, _ := Create(&IFaceMock{
|
manager, _ := Create(&IFaceMock{
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
})
|
}, false)
|
||||||
b.Cleanup(func() {
|
b.Cleanup(func() {
|
||||||
require.NoError(b, manager.Reset(nil))
|
require.NoError(b, manager.Reset(nil))
|
||||||
})
|
})
|
||||||
@@ -269,7 +272,7 @@ func BenchmarkEstablishmentOverhead(b *testing.B) {
|
|||||||
|
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
manager.dropFilter(inbound, manager.incomingRules)
|
manager.dropFilter(inbound)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -447,7 +450,7 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
|
|||||||
b.Run(sc.name, func(b *testing.B) {
|
b.Run(sc.name, func(b *testing.B) {
|
||||||
manager, _ := Create(&IFaceMock{
|
manager, _ := Create(&IFaceMock{
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
})
|
}, false)
|
||||||
b.Cleanup(func() {
|
b.Cleanup(func() {
|
||||||
require.NoError(b, manager.Reset(nil))
|
require.NoError(b, manager.Reset(nil))
|
||||||
})
|
})
|
||||||
@@ -472,7 +475,7 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
|
|||||||
manager.processOutgoingHooks(syn)
|
manager.processOutgoingHooks(syn)
|
||||||
// SYN-ACK
|
// SYN-ACK
|
||||||
synack := generateTCPPacketWithFlags(b, dstIP, srcIP, 80, 1024, uint16(conntrack.TCPSyn|conntrack.TCPAck))
|
synack := generateTCPPacketWithFlags(b, dstIP, srcIP, 80, 1024, uint16(conntrack.TCPSyn|conntrack.TCPAck))
|
||||||
manager.dropFilter(synack, manager.incomingRules)
|
manager.dropFilter(synack)
|
||||||
// ACK
|
// ACK
|
||||||
ack := generateTCPPacketWithFlags(b, srcIP, dstIP, 1024, 80, uint16(conntrack.TCPAck))
|
ack := generateTCPPacketWithFlags(b, srcIP, dstIP, 1024, 80, uint16(conntrack.TCPAck))
|
||||||
manager.processOutgoingHooks(ack)
|
manager.processOutgoingHooks(ack)
|
||||||
@@ -481,7 +484,7 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
|
|||||||
|
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
manager.dropFilter(inbound, manager.incomingRules)
|
manager.dropFilter(inbound)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -574,7 +577,7 @@ func BenchmarkLongLivedConnections(b *testing.B) {
|
|||||||
|
|
||||||
manager, _ := Create(&IFaceMock{
|
manager, _ := Create(&IFaceMock{
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
})
|
}, false)
|
||||||
defer b.Cleanup(func() {
|
defer b.Cleanup(func() {
|
||||||
require.NoError(b, manager.Reset(nil))
|
require.NoError(b, manager.Reset(nil))
|
||||||
})
|
})
|
||||||
@@ -618,7 +621,7 @@ func BenchmarkLongLivedConnections(b *testing.B) {
|
|||||||
// SYN-ACK
|
// SYN-ACK
|
||||||
synack := generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i],
|
synack := generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i],
|
||||||
80, uint16(1024+i), uint16(conntrack.TCPSyn|conntrack.TCPAck))
|
80, uint16(1024+i), uint16(conntrack.TCPSyn|conntrack.TCPAck))
|
||||||
manager.dropFilter(synack, manager.incomingRules)
|
manager.dropFilter(synack)
|
||||||
|
|
||||||
// ACK
|
// ACK
|
||||||
ack := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
|
ack := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
|
||||||
@@ -646,7 +649,7 @@ func BenchmarkLongLivedConnections(b *testing.B) {
|
|||||||
// First outbound data
|
// First outbound data
|
||||||
manager.processOutgoingHooks(outPackets[connIdx])
|
manager.processOutgoingHooks(outPackets[connIdx])
|
||||||
// Then inbound response - this is what we're actually measuring
|
// Then inbound response - this is what we're actually measuring
|
||||||
manager.dropFilter(inPackets[connIdx], manager.incomingRules)
|
manager.dropFilter(inPackets[connIdx])
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -665,7 +668,7 @@ func BenchmarkShortLivedConnections(b *testing.B) {
|
|||||||
|
|
||||||
manager, _ := Create(&IFaceMock{
|
manager, _ := Create(&IFaceMock{
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
})
|
}, false)
|
||||||
defer b.Cleanup(func() {
|
defer b.Cleanup(func() {
|
||||||
require.NoError(b, manager.Reset(nil))
|
require.NoError(b, manager.Reset(nil))
|
||||||
})
|
})
|
||||||
@@ -754,17 +757,17 @@ func BenchmarkShortLivedConnections(b *testing.B) {
|
|||||||
|
|
||||||
// Connection establishment
|
// Connection establishment
|
||||||
manager.processOutgoingHooks(p.syn)
|
manager.processOutgoingHooks(p.syn)
|
||||||
manager.dropFilter(p.synAck, manager.incomingRules)
|
manager.dropFilter(p.synAck)
|
||||||
manager.processOutgoingHooks(p.ack)
|
manager.processOutgoingHooks(p.ack)
|
||||||
|
|
||||||
// Data transfer
|
// Data transfer
|
||||||
manager.processOutgoingHooks(p.request)
|
manager.processOutgoingHooks(p.request)
|
||||||
manager.dropFilter(p.response, manager.incomingRules)
|
manager.dropFilter(p.response)
|
||||||
|
|
||||||
// Connection teardown
|
// Connection teardown
|
||||||
manager.processOutgoingHooks(p.finClient)
|
manager.processOutgoingHooks(p.finClient)
|
||||||
manager.dropFilter(p.ackServer, manager.incomingRules)
|
manager.dropFilter(p.ackServer)
|
||||||
manager.dropFilter(p.finServer, manager.incomingRules)
|
manager.dropFilter(p.finServer)
|
||||||
manager.processOutgoingHooks(p.ackClient)
|
manager.processOutgoingHooks(p.ackClient)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
@@ -784,7 +787,7 @@ func BenchmarkParallelLongLivedConnections(b *testing.B) {
|
|||||||
|
|
||||||
manager, _ := Create(&IFaceMock{
|
manager, _ := Create(&IFaceMock{
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
})
|
}, false)
|
||||||
defer b.Cleanup(func() {
|
defer b.Cleanup(func() {
|
||||||
require.NoError(b, manager.Reset(nil))
|
require.NoError(b, manager.Reset(nil))
|
||||||
})
|
})
|
||||||
@@ -825,7 +828,7 @@ func BenchmarkParallelLongLivedConnections(b *testing.B) {
|
|||||||
|
|
||||||
synack := generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i],
|
synack := generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i],
|
||||||
80, uint16(1024+i), uint16(conntrack.TCPSyn|conntrack.TCPAck))
|
80, uint16(1024+i), uint16(conntrack.TCPSyn|conntrack.TCPAck))
|
||||||
manager.dropFilter(synack, manager.incomingRules)
|
manager.dropFilter(synack)
|
||||||
|
|
||||||
ack := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
|
ack := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
|
||||||
uint16(1024+i), 80, uint16(conntrack.TCPAck))
|
uint16(1024+i), 80, uint16(conntrack.TCPAck))
|
||||||
@@ -852,7 +855,7 @@ func BenchmarkParallelLongLivedConnections(b *testing.B) {
|
|||||||
|
|
||||||
// Simulate bidirectional traffic
|
// Simulate bidirectional traffic
|
||||||
manager.processOutgoingHooks(outPackets[connIdx])
|
manager.processOutgoingHooks(outPackets[connIdx])
|
||||||
manager.dropFilter(inPackets[connIdx], manager.incomingRules)
|
manager.dropFilter(inPackets[connIdx])
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
@@ -872,7 +875,7 @@ func BenchmarkParallelShortLivedConnections(b *testing.B) {
|
|||||||
|
|
||||||
manager, _ := Create(&IFaceMock{
|
manager, _ := Create(&IFaceMock{
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
})
|
}, false)
|
||||||
defer b.Cleanup(func() {
|
defer b.Cleanup(func() {
|
||||||
require.NoError(b, manager.Reset(nil))
|
require.NoError(b, manager.Reset(nil))
|
||||||
})
|
})
|
||||||
@@ -949,15 +952,15 @@ func BenchmarkParallelShortLivedConnections(b *testing.B) {
|
|||||||
|
|
||||||
// Full connection lifecycle
|
// Full connection lifecycle
|
||||||
manager.processOutgoingHooks(p.syn)
|
manager.processOutgoingHooks(p.syn)
|
||||||
manager.dropFilter(p.synAck, manager.incomingRules)
|
manager.dropFilter(p.synAck)
|
||||||
manager.processOutgoingHooks(p.ack)
|
manager.processOutgoingHooks(p.ack)
|
||||||
|
|
||||||
manager.processOutgoingHooks(p.request)
|
manager.processOutgoingHooks(p.request)
|
||||||
manager.dropFilter(p.response, manager.incomingRules)
|
manager.dropFilter(p.response)
|
||||||
|
|
||||||
manager.processOutgoingHooks(p.finClient)
|
manager.processOutgoingHooks(p.finClient)
|
||||||
manager.dropFilter(p.ackServer, manager.incomingRules)
|
manager.dropFilter(p.ackServer)
|
||||||
manager.dropFilter(p.finServer, manager.incomingRules)
|
manager.dropFilter(p.finServer)
|
||||||
manager.processOutgoingHooks(p.ackClient)
|
manager.processOutgoingHooks(p.ackClient)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
@@ -996,3 +999,72 @@ func generateTCPPacketWithFlags(b *testing.B, srcIP, dstIP net.IP, srcPort, dstP
|
|||||||
require.NoError(b, gopacket.SerializeLayers(buf, opts, ipv4, tcp, gopacket.Payload("test")))
|
require.NoError(b, gopacket.SerializeLayers(buf, opts, ipv4, tcp, gopacket.Payload("test")))
|
||||||
return buf.Bytes()
|
return buf.Bytes()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func BenchmarkRouteACLs(b *testing.B) {
|
||||||
|
manager := setupRoutedManager(b, "10.10.0.100/16")
|
||||||
|
|
||||||
|
// Add several route rules to simulate real-world scenario
|
||||||
|
rules := []struct {
|
||||||
|
sources []netip.Prefix
|
||||||
|
dest netip.Prefix
|
||||||
|
proto fw.Protocol
|
||||||
|
port *fw.Port
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
||||||
|
dest: netip.MustParsePrefix("192.168.1.0/24"),
|
||||||
|
proto: fw.ProtocolTCP,
|
||||||
|
port: &fw.Port{Values: []uint16{80, 443}},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
sources: []netip.Prefix{
|
||||||
|
netip.MustParsePrefix("172.16.0.0/12"),
|
||||||
|
netip.MustParsePrefix("10.0.0.0/8"),
|
||||||
|
},
|
||||||
|
dest: netip.MustParsePrefix("0.0.0.0/0"),
|
||||||
|
proto: fw.ProtocolICMP,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
sources: []netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")},
|
||||||
|
dest: netip.MustParsePrefix("192.168.0.0/16"),
|
||||||
|
proto: fw.ProtocolUDP,
|
||||||
|
port: &fw.Port{Values: []uint16{53}},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, r := range rules {
|
||||||
|
_, err := manager.AddRouteFiltering(
|
||||||
|
r.sources,
|
||||||
|
r.dest,
|
||||||
|
r.proto,
|
||||||
|
nil,
|
||||||
|
r.port,
|
||||||
|
fw.ActionAccept,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
b.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test cases that exercise different matching scenarios
|
||||||
|
cases := []struct {
|
||||||
|
srcIP string
|
||||||
|
dstIP string
|
||||||
|
proto fw.Protocol
|
||||||
|
dstPort uint16
|
||||||
|
}{
|
||||||
|
{"100.10.0.1", "192.168.1.100", fw.ProtocolTCP, 443}, // Match first rule
|
||||||
|
{"172.16.0.1", "8.8.8.8", fw.ProtocolICMP, 0}, // Match second rule
|
||||||
|
{"1.1.1.1", "192.168.1.53", fw.ProtocolUDP, 53}, // Match third rule
|
||||||
|
{"192.168.1.1", "10.0.0.1", fw.ProtocolTCP, 8080}, // No match
|
||||||
|
}
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
for _, tc := range cases {
|
||||||
|
srcIP := net.ParseIP(tc.srcIP)
|
||||||
|
dstIP := net.ParseIP(tc.dstIP)
|
||||||
|
manager.routeACLsPass(srcIP, dstIP, tc.proto, 0, tc.dstPort)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
1015
client/firewall/uspfilter/uspfilter_filter_test.go
Normal file
1015
client/firewall/uspfilter/uspfilter_filter_test.go
Normal file
File diff suppressed because it is too large
Load Diff
@@ -9,17 +9,38 @@ import (
|
|||||||
|
|
||||||
"github.com/google/gopacket"
|
"github.com/google/gopacket"
|
||||||
"github.com/google/gopacket/layers"
|
"github.com/google/gopacket/layers"
|
||||||
|
"github.com/sirupsen/logrus"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
wgdevice "golang.zx2c4.com/wireguard/device"
|
||||||
|
|
||||||
fw "github.com/netbirdio/netbird/client/firewall/manager"
|
fw "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
|
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
|
||||||
|
"github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
||||||
"github.com/netbirdio/netbird/client/iface"
|
"github.com/netbirdio/netbird/client/iface"
|
||||||
"github.com/netbirdio/netbird/client/iface/device"
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var logger = log.NewFromLogrus(logrus.StandardLogger())
|
||||||
|
|
||||||
type IFaceMock struct {
|
type IFaceMock struct {
|
||||||
SetFilterFunc func(device.PacketFilter) error
|
SetFilterFunc func(device.PacketFilter) error
|
||||||
AddressFunc func() iface.WGAddress
|
AddressFunc func() iface.WGAddress
|
||||||
|
GetWGDeviceFunc func() *wgdevice.Device
|
||||||
|
GetDeviceFunc func() *device.FilteredDevice
|
||||||
|
}
|
||||||
|
|
||||||
|
func (i *IFaceMock) GetWGDevice() *wgdevice.Device {
|
||||||
|
if i.GetWGDeviceFunc == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return i.GetWGDeviceFunc()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (i *IFaceMock) GetDevice() *device.FilteredDevice {
|
||||||
|
if i.GetDeviceFunc == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return i.GetDeviceFunc()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (i *IFaceMock) SetFilter(iface device.PacketFilter) error {
|
func (i *IFaceMock) SetFilter(iface device.PacketFilter) error {
|
||||||
@@ -41,7 +62,7 @@ func TestManagerCreate(t *testing.T) {
|
|||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
}
|
}
|
||||||
|
|
||||||
m, err := Create(ifaceMock)
|
m, err := Create(ifaceMock, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to create Manager: %v", err)
|
t.Errorf("failed to create Manager: %v", err)
|
||||||
return
|
return
|
||||||
@@ -61,7 +82,7 @@ func TestManagerAddPeerFiltering(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
m, err := Create(ifaceMock)
|
m, err := Create(ifaceMock, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to create Manager: %v", err)
|
t.Errorf("failed to create Manager: %v", err)
|
||||||
return
|
return
|
||||||
@@ -95,7 +116,7 @@ func TestManagerDeleteRule(t *testing.T) {
|
|||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
}
|
}
|
||||||
|
|
||||||
m, err := Create(ifaceMock)
|
m, err := Create(ifaceMock, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to create Manager: %v", err)
|
t.Errorf("failed to create Manager: %v", err)
|
||||||
return
|
return
|
||||||
@@ -166,12 +187,12 @@ func TestAddUDPPacketHook(t *testing.T) {
|
|||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
manager, err := Create(&IFaceMock{
|
manager, err := Create(&IFaceMock{
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
})
|
}, false)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
manager.AddUDPPacketHook(tt.in, tt.ip, tt.dPort, tt.hook)
|
manager.AddUDPPacketHook(tt.in, tt.ip, tt.dPort, tt.hook)
|
||||||
|
|
||||||
var addedRule Rule
|
var addedRule PeerRule
|
||||||
if tt.in {
|
if tt.in {
|
||||||
if len(manager.incomingRules[tt.ip.String()]) != 1 {
|
if len(manager.incomingRules[tt.ip.String()]) != 1 {
|
||||||
t.Errorf("expected 1 incoming rule, got %d", len(manager.incomingRules))
|
t.Errorf("expected 1 incoming rule, got %d", len(manager.incomingRules))
|
||||||
@@ -215,7 +236,7 @@ func TestManagerReset(t *testing.T) {
|
|||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
}
|
}
|
||||||
|
|
||||||
m, err := Create(ifaceMock)
|
m, err := Create(ifaceMock, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to create Manager: %v", err)
|
t.Errorf("failed to create Manager: %v", err)
|
||||||
return
|
return
|
||||||
@@ -247,9 +268,18 @@ func TestManagerReset(t *testing.T) {
|
|||||||
func TestNotMatchByIP(t *testing.T) {
|
func TestNotMatchByIP(t *testing.T) {
|
||||||
ifaceMock := &IFaceMock{
|
ifaceMock := &IFaceMock{
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
|
AddressFunc: func() iface.WGAddress {
|
||||||
|
return iface.WGAddress{
|
||||||
|
IP: net.ParseIP("100.10.0.100"),
|
||||||
|
Network: &net.IPNet{
|
||||||
|
IP: net.ParseIP("100.10.0.0"),
|
||||||
|
Mask: net.CIDRMask(16, 32),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
m, err := Create(ifaceMock)
|
m, err := Create(ifaceMock, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to create Manager: %v", err)
|
t.Errorf("failed to create Manager: %v", err)
|
||||||
return
|
return
|
||||||
@@ -298,7 +328,7 @@ func TestNotMatchByIP(t *testing.T) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if m.dropFilter(buf.Bytes(), m.incomingRules) {
|
if m.dropFilter(buf.Bytes()) {
|
||||||
t.Errorf("expected packet to be accepted")
|
t.Errorf("expected packet to be accepted")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -317,7 +347,7 @@ func TestRemovePacketHook(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// creating manager instance
|
// creating manager instance
|
||||||
manager, err := Create(iface)
|
manager, err := Create(iface, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Failed to create Manager: %s", err)
|
t.Fatalf("Failed to create Manager: %s", err)
|
||||||
}
|
}
|
||||||
@@ -363,7 +393,7 @@ func TestRemovePacketHook(t *testing.T) {
|
|||||||
func TestProcessOutgoingHooks(t *testing.T) {
|
func TestProcessOutgoingHooks(t *testing.T) {
|
||||||
manager, err := Create(&IFaceMock{
|
manager, err := Create(&IFaceMock{
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
})
|
}, false)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
manager.wgNetwork = &net.IPNet{
|
manager.wgNetwork = &net.IPNet{
|
||||||
@@ -371,7 +401,7 @@ func TestProcessOutgoingHooks(t *testing.T) {
|
|||||||
Mask: net.CIDRMask(16, 32),
|
Mask: net.CIDRMask(16, 32),
|
||||||
}
|
}
|
||||||
manager.udpTracker.Close()
|
manager.udpTracker.Close()
|
||||||
manager.udpTracker = conntrack.NewUDPTracker(100 * time.Millisecond)
|
manager.udpTracker = conntrack.NewUDPTracker(100*time.Millisecond, logger)
|
||||||
defer func() {
|
defer func() {
|
||||||
require.NoError(t, manager.Reset(nil))
|
require.NoError(t, manager.Reset(nil))
|
||||||
}()
|
}()
|
||||||
@@ -449,7 +479,7 @@ func TestUSPFilterCreatePerformance(t *testing.T) {
|
|||||||
ifaceMock := &IFaceMock{
|
ifaceMock := &IFaceMock{
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
}
|
}
|
||||||
manager, err := Create(ifaceMock)
|
manager, err := Create(ifaceMock, false)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
time.Sleep(time.Second)
|
time.Sleep(time.Second)
|
||||||
|
|
||||||
@@ -476,7 +506,7 @@ func TestUSPFilterCreatePerformance(t *testing.T) {
|
|||||||
func TestStatefulFirewall_UDPTracking(t *testing.T) {
|
func TestStatefulFirewall_UDPTracking(t *testing.T) {
|
||||||
manager, err := Create(&IFaceMock{
|
manager, err := Create(&IFaceMock{
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
})
|
}, false)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
manager.wgNetwork = &net.IPNet{
|
manager.wgNetwork = &net.IPNet{
|
||||||
@@ -485,7 +515,7 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
manager.udpTracker.Close() // Close the existing tracker
|
manager.udpTracker.Close() // Close the existing tracker
|
||||||
manager.udpTracker = conntrack.NewUDPTracker(200 * time.Millisecond)
|
manager.udpTracker = conntrack.NewUDPTracker(200*time.Millisecond, logger)
|
||||||
manager.decoders = sync.Pool{
|
manager.decoders = sync.Pool{
|
||||||
New: func() any {
|
New: func() any {
|
||||||
d := &decoder{
|
d := &decoder{
|
||||||
@@ -606,7 +636,7 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
|
|||||||
for _, cp := range checkPoints {
|
for _, cp := range checkPoints {
|
||||||
time.Sleep(cp.sleep)
|
time.Sleep(cp.sleep)
|
||||||
|
|
||||||
drop = manager.dropFilter(inboundBuf.Bytes(), manager.incomingRules)
|
drop = manager.dropFilter(inboundBuf.Bytes())
|
||||||
require.Equal(t, cp.shouldAllow, !drop, cp.description)
|
require.Equal(t, cp.shouldAllow, !drop, cp.description)
|
||||||
|
|
||||||
// If the connection should still be valid, verify it exists
|
// If the connection should still be valid, verify it exists
|
||||||
@@ -677,7 +707,7 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// Verify the invalid packet is dropped
|
// Verify the invalid packet is dropped
|
||||||
drop = manager.dropFilter(testBuf.Bytes(), manager.incomingRules)
|
drop = manager.dropFilter(testBuf.Bytes())
|
||||||
require.True(t, drop, tc.description)
|
require.True(t, drop, tc.description)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -362,7 +362,7 @@ func toWgUserspaceString(wgCfg wgtypes.Config) string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func getFwmark() int {
|
func getFwmark() int {
|
||||||
if runtime.GOOS == "linux" && !nbnet.CustomRoutingDisabled() {
|
if nbnet.AdvancedRouting() {
|
||||||
return nbnet.NetbirdFwmark
|
return nbnet.NetbirdFwmark
|
||||||
}
|
}
|
||||||
return 0
|
return 0
|
||||||
|
|||||||
@@ -3,6 +3,8 @@
|
|||||||
package iface
|
package iface
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
wgdevice "golang.zx2c4.com/wireguard/device"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface/bind"
|
"github.com/netbirdio/netbird/client/iface/bind"
|
||||||
"github.com/netbirdio/netbird/client/iface/device"
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
)
|
)
|
||||||
@@ -15,4 +17,5 @@ type WGTunDevice interface {
|
|||||||
DeviceName() string
|
DeviceName() string
|
||||||
Close() error
|
Close() error
|
||||||
FilteredDevice() *device.FilteredDevice
|
FilteredDevice() *device.FilteredDevice
|
||||||
|
Device() *wgdevice.Device
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -117,6 +117,11 @@ func (t *TunDevice) FilteredDevice() *FilteredDevice {
|
|||||||
return t.filteredDevice
|
return t.filteredDevice
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Device returns the wireguard device
|
||||||
|
func (t *TunDevice) Device() *device.Device {
|
||||||
|
return t.device
|
||||||
|
}
|
||||||
|
|
||||||
// assignAddr Adds IP address to the tunnel interface and network route based on the range provided
|
// assignAddr Adds IP address to the tunnel interface and network route based on the range provided
|
||||||
func (t *TunDevice) assignAddr() error {
|
func (t *TunDevice) assignAddr() error {
|
||||||
cmd := exec.Command("ifconfig", t.name, "inet", t.address.IP.String(), t.address.IP.String())
|
cmd := exec.Command("ifconfig", t.name, "inet", t.address.IP.String(), t.address.IP.String())
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ import (
|
|||||||
|
|
||||||
"github.com/pion/transport/v3"
|
"github.com/pion/transport/v3"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
"golang.zx2c4.com/wireguard/device"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface/bind"
|
"github.com/netbirdio/netbird/client/iface/bind"
|
||||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||||
@@ -151,6 +152,11 @@ func (t *TunKernelDevice) DeviceName() string {
|
|||||||
return t.name
|
return t.name
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Device returns the wireguard device, not applicable for kernel devices
|
||||||
|
func (t *TunKernelDevice) Device() *device.Device {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (t *TunKernelDevice) FilteredDevice() *FilteredDevice {
|
func (t *TunKernelDevice) FilteredDevice() *FilteredDevice {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -117,3 +117,8 @@ func (t *TunNetstackDevice) DeviceName() string {
|
|||||||
func (t *TunNetstackDevice) FilteredDevice() *FilteredDevice {
|
func (t *TunNetstackDevice) FilteredDevice() *FilteredDevice {
|
||||||
return t.filteredDevice
|
return t.filteredDevice
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Device returns the wireguard device
|
||||||
|
func (t *TunNetstackDevice) Device() *device.Device {
|
||||||
|
return t.device
|
||||||
|
}
|
||||||
|
|||||||
@@ -124,6 +124,11 @@ func (t *USPDevice) FilteredDevice() *FilteredDevice {
|
|||||||
return t.filteredDevice
|
return t.filteredDevice
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Device returns the wireguard device
|
||||||
|
func (t *USPDevice) Device() *device.Device {
|
||||||
|
return t.device
|
||||||
|
}
|
||||||
|
|
||||||
// assignAddr Adds IP address to the tunnel interface
|
// assignAddr Adds IP address to the tunnel interface
|
||||||
func (t *USPDevice) assignAddr() error {
|
func (t *USPDevice) assignAddr() error {
|
||||||
link := newWGLink(t.name)
|
link := newWGLink(t.name)
|
||||||
|
|||||||
@@ -150,6 +150,11 @@ func (t *TunDevice) FilteredDevice() *FilteredDevice {
|
|||||||
return t.filteredDevice
|
return t.filteredDevice
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Device returns the wireguard device
|
||||||
|
func (t *TunDevice) Device() *device.Device {
|
||||||
|
return t.device
|
||||||
|
}
|
||||||
|
|
||||||
func (t *TunDevice) GetInterfaceGUIDString() (string, error) {
|
func (t *TunDevice) GetInterfaceGUIDString() (string, error) {
|
||||||
if t.nativeTunDevice == nil {
|
if t.nativeTunDevice == nil {
|
||||||
return "", fmt.Errorf("interface has not been initialized yet")
|
return "", fmt.Errorf("interface has not been initialized yet")
|
||||||
|
|||||||
@@ -1,6 +1,8 @@
|
|||||||
package iface
|
package iface
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
wgdevice "golang.zx2c4.com/wireguard/device"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface/bind"
|
"github.com/netbirdio/netbird/client/iface/bind"
|
||||||
"github.com/netbirdio/netbird/client/iface/device"
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
)
|
)
|
||||||
@@ -13,4 +15,5 @@ type WGTunDevice interface {
|
|||||||
DeviceName() string
|
DeviceName() string
|
||||||
Close() error
|
Close() error
|
||||||
FilteredDevice() *device.FilteredDevice
|
FilteredDevice() *device.FilteredDevice
|
||||||
|
Device() *wgdevice.Device
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -11,6 +11,8 @@ import (
|
|||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
|
|
||||||
|
wgdevice "golang.zx2c4.com/wireguard/device"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/errors"
|
"github.com/netbirdio/netbird/client/errors"
|
||||||
"github.com/netbirdio/netbird/client/iface/bind"
|
"github.com/netbirdio/netbird/client/iface/bind"
|
||||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||||
@@ -203,6 +205,11 @@ func (w *WGIface) GetDevice() *device.FilteredDevice {
|
|||||||
return w.tun.FilteredDevice()
|
return w.tun.FilteredDevice()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetWGDevice returns the WireGuard device
|
||||||
|
func (w *WGIface) GetWGDevice() *wgdevice.Device {
|
||||||
|
return w.tun.Device()
|
||||||
|
}
|
||||||
|
|
||||||
// GetStats returns the last handshake time, rx and tx bytes for the given peer
|
// GetStats returns the last handshake time, rx and tx bytes for the given peer
|
||||||
func (w *WGIface) GetStats(peerKey string) (configurer.WGStats, error) {
|
func (w *WGIface) GetStats(peerKey string) (configurer.WGStats, error) {
|
||||||
return w.configurer.GetStats(peerKey)
|
return w.configurer.GetStats(peerKey)
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import (
|
|||||||
"net"
|
"net"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
wgdevice "golang.zx2c4.com/wireguard/device"
|
||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface/bind"
|
"github.com/netbirdio/netbird/client/iface/bind"
|
||||||
@@ -29,6 +30,7 @@ type MockWGIface struct {
|
|||||||
SetFilterFunc func(filter device.PacketFilter) error
|
SetFilterFunc func(filter device.PacketFilter) error
|
||||||
GetFilterFunc func() device.PacketFilter
|
GetFilterFunc func() device.PacketFilter
|
||||||
GetDeviceFunc func() *device.FilteredDevice
|
GetDeviceFunc func() *device.FilteredDevice
|
||||||
|
GetWGDeviceFunc func() *wgdevice.Device
|
||||||
GetStatsFunc func(peerKey string) (configurer.WGStats, error)
|
GetStatsFunc func(peerKey string) (configurer.WGStats, error)
|
||||||
GetInterfaceGUIDStringFunc func() (string, error)
|
GetInterfaceGUIDStringFunc func() (string, error)
|
||||||
GetProxyFunc func() wgproxy.Proxy
|
GetProxyFunc func() wgproxy.Proxy
|
||||||
@@ -102,11 +104,14 @@ func (m *MockWGIface) GetDevice() *device.FilteredDevice {
|
|||||||
return m.GetDeviceFunc()
|
return m.GetDeviceFunc()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *MockWGIface) GetWGDevice() *wgdevice.Device {
|
||||||
|
return m.GetWGDeviceFunc()
|
||||||
|
}
|
||||||
|
|
||||||
func (m *MockWGIface) GetStats(peerKey string) (configurer.WGStats, error) {
|
func (m *MockWGIface) GetStats(peerKey string) (configurer.WGStats, error) {
|
||||||
return m.GetStatsFunc(peerKey)
|
return m.GetStatsFunc(peerKey)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *MockWGIface) GetProxy() wgproxy.Proxy {
|
func (m *MockWGIface) GetProxy() wgproxy.Proxy {
|
||||||
//TODO implement me
|
return m.GetProxyFunc()
|
||||||
panic("implement me")
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ import (
|
|||||||
"net"
|
"net"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
wgdevice "golang.zx2c4.com/wireguard/device"
|
||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface/bind"
|
"github.com/netbirdio/netbird/client/iface/bind"
|
||||||
@@ -32,5 +33,6 @@ type IWGIface interface {
|
|||||||
SetFilter(filter device.PacketFilter) error
|
SetFilter(filter device.PacketFilter) error
|
||||||
GetFilter() device.PacketFilter
|
GetFilter() device.PacketFilter
|
||||||
GetDevice() *device.FilteredDevice
|
GetDevice() *device.FilteredDevice
|
||||||
|
GetWGDevice() *wgdevice.Device
|
||||||
GetStats(peerKey string) (configurer.WGStats, error)
|
GetStats(peerKey string) (configurer.WGStats, error)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import (
|
|||||||
"net"
|
"net"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
wgdevice "golang.zx2c4.com/wireguard/device"
|
||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface/bind"
|
"github.com/netbirdio/netbird/client/iface/bind"
|
||||||
@@ -30,6 +31,7 @@ type IWGIface interface {
|
|||||||
SetFilter(filter device.PacketFilter) error
|
SetFilter(filter device.PacketFilter) error
|
||||||
GetFilter() device.PacketFilter
|
GetFilter() device.PacketFilter
|
||||||
GetDevice() *device.FilteredDevice
|
GetDevice() *device.FilteredDevice
|
||||||
|
GetWGDevice() *wgdevice.Device
|
||||||
GetStats(peerKey string) (configurer.WGStats, error)
|
GetStats(peerKey string) (configurer.WGStats, error)
|
||||||
GetInterfaceGUIDString() (string, error)
|
GetInterfaceGUIDString() (string, error)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -49,9 +49,10 @@ func TestDefaultManager(t *testing.T) {
|
|||||||
IP: ip,
|
IP: ip,
|
||||||
Network: network,
|
Network: network,
|
||||||
}).AnyTimes()
|
}).AnyTimes()
|
||||||
|
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
|
||||||
|
|
||||||
// we receive one rule from the management so for testing purposes ignore it
|
// we receive one rule from the management so for testing purposes ignore it
|
||||||
fw, err := firewall.NewFirewall(ifaceMock, nil)
|
fw, err := firewall.NewFirewall(ifaceMock, nil, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("create firewall: %v", err)
|
t.Errorf("create firewall: %v", err)
|
||||||
return
|
return
|
||||||
@@ -342,9 +343,10 @@ func TestDefaultManagerEnableSSHRules(t *testing.T) {
|
|||||||
IP: ip,
|
IP: ip,
|
||||||
Network: network,
|
Network: network,
|
||||||
}).AnyTimes()
|
}).AnyTimes()
|
||||||
|
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
|
||||||
|
|
||||||
// we receive one rule from the management so for testing purposes ignore it
|
// we receive one rule from the management so for testing purposes ignore it
|
||||||
fw, err := firewall.NewFirewall(ifaceMock, nil)
|
fw, err := firewall.NewFirewall(ifaceMock, nil, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("create firewall: %v", err)
|
t.Errorf("create firewall: %v", err)
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -8,6 +8,8 @@ import (
|
|||||||
reflect "reflect"
|
reflect "reflect"
|
||||||
|
|
||||||
gomock "github.com/golang/mock/gomock"
|
gomock "github.com/golang/mock/gomock"
|
||||||
|
wgdevice "golang.zx2c4.com/wireguard/device"
|
||||||
|
|
||||||
iface "github.com/netbirdio/netbird/client/iface"
|
iface "github.com/netbirdio/netbird/client/iface"
|
||||||
"github.com/netbirdio/netbird/client/iface/device"
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
)
|
)
|
||||||
@@ -90,3 +92,31 @@ func (mr *MockIFaceMapperMockRecorder) SetFilter(arg0 interface{}) *gomock.Call
|
|||||||
mr.mock.ctrl.T.Helper()
|
mr.mock.ctrl.T.Helper()
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetFilter", reflect.TypeOf((*MockIFaceMapper)(nil).SetFilter), arg0)
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetFilter", reflect.TypeOf((*MockIFaceMapper)(nil).SetFilter), arg0)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetDevice mocks base method.
|
||||||
|
func (m *MockIFaceMapper) GetDevice() *device.FilteredDevice {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
ret := m.ctrl.Call(m, "GetDevice")
|
||||||
|
ret0, _ := ret[0].(*device.FilteredDevice)
|
||||||
|
return ret0
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetDevice indicates an expected call of GetDevice.
|
||||||
|
func (mr *MockIFaceMapperMockRecorder) GetDevice() *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetDevice", reflect.TypeOf((*MockIFaceMapper)(nil).GetDevice))
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetWGDevice mocks base method.
|
||||||
|
func (m *MockIFaceMapper) GetWGDevice() *wgdevice.Device {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
ret := m.ctrl.Call(m, "GetWGDevice")
|
||||||
|
ret0, _ := ret[0].(*wgdevice.Device)
|
||||||
|
return ret0
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetWGDevice indicates an expected call of GetWGDevice.
|
||||||
|
func (mr *MockIFaceMapperMockRecorder) GetWGDevice() *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetWGDevice", reflect.TypeOf((*MockIFaceMapper)(nil).GetWGDevice))
|
||||||
|
}
|
||||||
|
|||||||
@@ -31,6 +31,7 @@ import (
|
|||||||
relayClient "github.com/netbirdio/netbird/relay/client"
|
relayClient "github.com/netbirdio/netbird/relay/client"
|
||||||
signal "github.com/netbirdio/netbird/signal/client"
|
signal "github.com/netbirdio/netbird/signal/client"
|
||||||
"github.com/netbirdio/netbird/util"
|
"github.com/netbirdio/netbird/util"
|
||||||
|
nbnet "github.com/netbirdio/netbird/util/net"
|
||||||
"github.com/netbirdio/netbird/version"
|
"github.com/netbirdio/netbird/version"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -109,6 +110,8 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
|
|||||||
|
|
||||||
log.Infof("starting NetBird client version %s on %s/%s", version.NetbirdVersion(), runtime.GOOS, runtime.GOARCH)
|
log.Infof("starting NetBird client version %s on %s/%s", version.NetbirdVersion(), runtime.GOOS, runtime.GOARCH)
|
||||||
|
|
||||||
|
nbnet.Init()
|
||||||
|
|
||||||
backOff := &backoff.ExponentialBackOff{
|
backOff := &backoff.ExponentialBackOff{
|
||||||
InitialInterval: time.Second,
|
InitialInterval: time.Second,
|
||||||
RandomizationFactor: 1,
|
RandomizationFactor: 1,
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ import (
|
|||||||
const (
|
const (
|
||||||
PriorityDNSRoute = 100
|
PriorityDNSRoute = 100
|
||||||
PriorityMatchDomain = 50
|
PriorityMatchDomain = 50
|
||||||
PriorityDefault = 0
|
PriorityDefault = 1
|
||||||
)
|
)
|
||||||
|
|
||||||
type SubdomainMatcher interface {
|
type SubdomainMatcher interface {
|
||||||
@@ -26,7 +26,6 @@ type HandlerEntry struct {
|
|||||||
Pattern string
|
Pattern string
|
||||||
OrigPattern string
|
OrigPattern string
|
||||||
IsWildcard bool
|
IsWildcard bool
|
||||||
StopHandler handlerWithStop
|
|
||||||
MatchSubdomains bool
|
MatchSubdomains bool
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -64,7 +63,7 @@ func (w *ResponseWriterChain) GetOrigPattern() string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// AddHandler adds a new handler to the chain, replacing any existing handler with the same pattern and priority
|
// AddHandler adds a new handler to the chain, replacing any existing handler with the same pattern and priority
|
||||||
func (c *HandlerChain) AddHandler(pattern string, handler dns.Handler, priority int, stopHandler handlerWithStop) {
|
func (c *HandlerChain) AddHandler(pattern string, handler dns.Handler, priority int) {
|
||||||
c.mu.Lock()
|
c.mu.Lock()
|
||||||
defer c.mu.Unlock()
|
defer c.mu.Unlock()
|
||||||
|
|
||||||
@@ -78,9 +77,6 @@ func (c *HandlerChain) AddHandler(pattern string, handler dns.Handler, priority
|
|||||||
// First remove any existing handler with same pattern (case-insensitive) and priority
|
// First remove any existing handler with same pattern (case-insensitive) and priority
|
||||||
for i := len(c.handlers) - 1; i >= 0; i-- {
|
for i := len(c.handlers) - 1; i >= 0; i-- {
|
||||||
if strings.EqualFold(c.handlers[i].OrigPattern, origPattern) && c.handlers[i].Priority == priority {
|
if strings.EqualFold(c.handlers[i].OrigPattern, origPattern) && c.handlers[i].Priority == priority {
|
||||||
if c.handlers[i].StopHandler != nil {
|
|
||||||
c.handlers[i].StopHandler.stop()
|
|
||||||
}
|
|
||||||
c.handlers = append(c.handlers[:i], c.handlers[i+1:]...)
|
c.handlers = append(c.handlers[:i], c.handlers[i+1:]...)
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
@@ -101,7 +97,6 @@ func (c *HandlerChain) AddHandler(pattern string, handler dns.Handler, priority
|
|||||||
Pattern: pattern,
|
Pattern: pattern,
|
||||||
OrigPattern: origPattern,
|
OrigPattern: origPattern,
|
||||||
IsWildcard: isWildcard,
|
IsWildcard: isWildcard,
|
||||||
StopHandler: stopHandler,
|
|
||||||
MatchSubdomains: matchSubdomains,
|
MatchSubdomains: matchSubdomains,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -142,9 +137,6 @@ func (c *HandlerChain) RemoveHandler(pattern string, priority int) {
|
|||||||
for i := len(c.handlers) - 1; i >= 0; i-- {
|
for i := len(c.handlers) - 1; i >= 0; i-- {
|
||||||
entry := c.handlers[i]
|
entry := c.handlers[i]
|
||||||
if strings.EqualFold(entry.OrigPattern, pattern) && entry.Priority == priority {
|
if strings.EqualFold(entry.OrigPattern, pattern) && entry.Priority == priority {
|
||||||
if entry.StopHandler != nil {
|
|
||||||
entry.StopHandler.stop()
|
|
||||||
}
|
|
||||||
c.handlers = append(c.handlers[:i], c.handlers[i+1:]...)
|
c.handlers = append(c.handlers[:i], c.handlers[i+1:]...)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -180,8 +172,8 @@ func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
|||||||
if log.IsLevelEnabled(log.TraceLevel) {
|
if log.IsLevelEnabled(log.TraceLevel) {
|
||||||
log.Tracef("current handlers (%d):", len(handlers))
|
log.Tracef("current handlers (%d):", len(handlers))
|
||||||
for _, h := range handlers {
|
for _, h := range handlers {
|
||||||
log.Tracef(" - pattern: domain=%s original: domain=%s wildcard=%v priority=%d",
|
log.Tracef(" - pattern: domain=%s original: domain=%s wildcard=%v match_subdomain=%v priority=%d",
|
||||||
h.Pattern, h.OrigPattern, h.IsWildcard, h.Priority)
|
h.Pattern, h.OrigPattern, h.IsWildcard, h.MatchSubdomains, h.Priority)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -206,13 +198,13 @@ func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if !matched {
|
if !matched {
|
||||||
log.Tracef("trying domain match: request: domain=%s pattern: domain=%s wildcard=%v match_subdomain=%v matched=false",
|
log.Tracef("trying domain match: request: domain=%s pattern: domain=%s wildcard=%v match_subdomain=%v priority=%d matched=false",
|
||||||
qname, entry.OrigPattern, entry.MatchSubdomains, entry.IsWildcard)
|
qname, entry.OrigPattern, entry.MatchSubdomains, entry.IsWildcard, entry.Priority)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Tracef("handler matched: request: domain=%s pattern: domain=%s wildcard=%v match_subdomain=%v",
|
log.Tracef("handler matched: request: domain=%s pattern: domain=%s wildcard=%v match_subdomain=%v priority=%d",
|
||||||
qname, entry.OrigPattern, entry.IsWildcard, entry.MatchSubdomains)
|
qname, entry.OrigPattern, entry.IsWildcard, entry.MatchSubdomains, entry.Priority)
|
||||||
|
|
||||||
chainWriter := &ResponseWriterChain{
|
chainWriter := &ResponseWriterChain{
|
||||||
ResponseWriter: w,
|
ResponseWriter: w,
|
||||||
|
|||||||
@@ -21,9 +21,9 @@ func TestHandlerChain_ServeDNS_Priorities(t *testing.T) {
|
|||||||
dnsRouteHandler := &nbdns.MockHandler{}
|
dnsRouteHandler := &nbdns.MockHandler{}
|
||||||
|
|
||||||
// Setup handlers with different priorities
|
// Setup handlers with different priorities
|
||||||
chain.AddHandler("example.com.", defaultHandler, nbdns.PriorityDefault, nil)
|
chain.AddHandler("example.com.", defaultHandler, nbdns.PriorityDefault)
|
||||||
chain.AddHandler("example.com.", matchDomainHandler, nbdns.PriorityMatchDomain, nil)
|
chain.AddHandler("example.com.", matchDomainHandler, nbdns.PriorityMatchDomain)
|
||||||
chain.AddHandler("example.com.", dnsRouteHandler, nbdns.PriorityDNSRoute, nil)
|
chain.AddHandler("example.com.", dnsRouteHandler, nbdns.PriorityDNSRoute)
|
||||||
|
|
||||||
// Create test request
|
// Create test request
|
||||||
r := new(dns.Msg)
|
r := new(dns.Msg)
|
||||||
@@ -138,7 +138,7 @@ func TestHandlerChain_ServeDNS_DomainMatching(t *testing.T) {
|
|||||||
pattern = "*." + tt.handlerDomain[2:]
|
pattern = "*." + tt.handlerDomain[2:]
|
||||||
}
|
}
|
||||||
|
|
||||||
chain.AddHandler(pattern, handler, nbdns.PriorityDefault, nil)
|
chain.AddHandler(pattern, handler, nbdns.PriorityDefault)
|
||||||
|
|
||||||
r := new(dns.Msg)
|
r := new(dns.Msg)
|
||||||
r.SetQuestion(tt.queryDomain, dns.TypeA)
|
r.SetQuestion(tt.queryDomain, dns.TypeA)
|
||||||
@@ -253,7 +253,7 @@ func TestHandlerChain_ServeDNS_OverlappingDomains(t *testing.T) {
|
|||||||
handler.On("ServeDNS", mock.Anything, mock.Anything).Maybe()
|
handler.On("ServeDNS", mock.Anything, mock.Anything).Maybe()
|
||||||
}
|
}
|
||||||
|
|
||||||
chain.AddHandler(tt.handlers[i].pattern, handler, tt.handlers[i].priority, nil)
|
chain.AddHandler(tt.handlers[i].pattern, handler, tt.handlers[i].priority)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create and execute request
|
// Create and execute request
|
||||||
@@ -280,9 +280,9 @@ func TestHandlerChain_ServeDNS_ChainContinuation(t *testing.T) {
|
|||||||
handler3 := &nbdns.MockHandler{}
|
handler3 := &nbdns.MockHandler{}
|
||||||
|
|
||||||
// Add handlers in priority order
|
// Add handlers in priority order
|
||||||
chain.AddHandler("example.com.", handler1, nbdns.PriorityDNSRoute, nil)
|
chain.AddHandler("example.com.", handler1, nbdns.PriorityDNSRoute)
|
||||||
chain.AddHandler("example.com.", handler2, nbdns.PriorityMatchDomain, nil)
|
chain.AddHandler("example.com.", handler2, nbdns.PriorityMatchDomain)
|
||||||
chain.AddHandler("example.com.", handler3, nbdns.PriorityDefault, nil)
|
chain.AddHandler("example.com.", handler3, nbdns.PriorityDefault)
|
||||||
|
|
||||||
// Create test request
|
// Create test request
|
||||||
r := new(dns.Msg)
|
r := new(dns.Msg)
|
||||||
@@ -416,7 +416,7 @@ func TestHandlerChain_PriorityDeregistration(t *testing.T) {
|
|||||||
if op.action == "add" {
|
if op.action == "add" {
|
||||||
handler := &nbdns.MockHandler{}
|
handler := &nbdns.MockHandler{}
|
||||||
handlers[op.priority] = handler
|
handlers[op.priority] = handler
|
||||||
chain.AddHandler(op.pattern, handler, op.priority, nil)
|
chain.AddHandler(op.pattern, handler, op.priority)
|
||||||
} else {
|
} else {
|
||||||
chain.RemoveHandler(op.pattern, op.priority)
|
chain.RemoveHandler(op.pattern, op.priority)
|
||||||
}
|
}
|
||||||
@@ -471,9 +471,9 @@ func TestHandlerChain_MultiPriorityHandling(t *testing.T) {
|
|||||||
r.SetQuestion(testQuery, dns.TypeA)
|
r.SetQuestion(testQuery, dns.TypeA)
|
||||||
|
|
||||||
// Add handlers in mixed order
|
// Add handlers in mixed order
|
||||||
chain.AddHandler(testDomain, defaultHandler, nbdns.PriorityDefault, nil)
|
chain.AddHandler(testDomain, defaultHandler, nbdns.PriorityDefault)
|
||||||
chain.AddHandler(testDomain, routeHandler, nbdns.PriorityDNSRoute, nil)
|
chain.AddHandler(testDomain, routeHandler, nbdns.PriorityDNSRoute)
|
||||||
chain.AddHandler(testDomain, matchHandler, nbdns.PriorityMatchDomain, nil)
|
chain.AddHandler(testDomain, matchHandler, nbdns.PriorityMatchDomain)
|
||||||
|
|
||||||
// Test 1: Initial state with all three handlers
|
// Test 1: Initial state with all three handlers
|
||||||
w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
|
w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
|
||||||
@@ -653,7 +653,7 @@ func TestHandlerChain_CaseSensitivity(t *testing.T) {
|
|||||||
handler = mockHandler
|
handler = mockHandler
|
||||||
}
|
}
|
||||||
|
|
||||||
chain.AddHandler(pattern, handler, h.priority, nil)
|
chain.AddHandler(pattern, handler, h.priority)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Execute request
|
// Execute request
|
||||||
@@ -795,7 +795,7 @@ func TestHandlerChain_DomainSpecificityOrdering(t *testing.T) {
|
|||||||
if op.action == "add" {
|
if op.action == "add" {
|
||||||
handler := &nbdns.MockSubdomainHandler{Subdomains: op.subdomain}
|
handler := &nbdns.MockSubdomainHandler{Subdomains: op.subdomain}
|
||||||
handlers[op.pattern] = handler
|
handlers[op.pattern] = handler
|
||||||
chain.AddHandler(op.pattern, handler, op.priority, nil)
|
chain.AddHandler(op.pattern, handler, op.priority)
|
||||||
} else {
|
} else {
|
||||||
chain.RemoveHandler(op.pattern, op.priority)
|
chain.RemoveHandler(op.pattern, op.priority)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,35 +1,51 @@
|
|||||||
package dns
|
package dns
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"strings"
|
"strings"
|
||||||
|
"syscall"
|
||||||
|
|
||||||
|
"github.com/hashicorp/go-multierror"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"golang.org/x/sys/windows/registry"
|
"golang.org/x/sys/windows/registry"
|
||||||
|
|
||||||
|
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
userenv = syscall.NewLazyDLL("userenv.dll")
|
||||||
|
|
||||||
|
// https://learn.microsoft.com/en-us/windows/win32/api/userenv/nf-userenv-refreshpolicyex
|
||||||
|
refreshPolicyExFn = userenv.NewProc("RefreshPolicyEx")
|
||||||
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
dnsPolicyConfigMatchPath = `SYSTEM\CurrentControlSet\Services\Dnscache\Parameters\DnsPolicyConfig\NetBird-Match`
|
dnsPolicyConfigMatchPath = `SYSTEM\CurrentControlSet\Services\Dnscache\Parameters\DnsPolicyConfig\NetBird-Match`
|
||||||
|
gpoDnsPolicyRoot = `SOFTWARE\Policies\Microsoft\Windows NT\DNSClient`
|
||||||
|
gpoDnsPolicyConfigMatchPath = gpoDnsPolicyRoot + `\DnsPolicyConfig\NetBird-Match`
|
||||||
|
|
||||||
dnsPolicyConfigVersionKey = "Version"
|
dnsPolicyConfigVersionKey = "Version"
|
||||||
dnsPolicyConfigVersionValue = 2
|
dnsPolicyConfigVersionValue = 2
|
||||||
dnsPolicyConfigNameKey = "Name"
|
dnsPolicyConfigNameKey = "Name"
|
||||||
dnsPolicyConfigGenericDNSServersKey = "GenericDNSServers"
|
dnsPolicyConfigGenericDNSServersKey = "GenericDNSServers"
|
||||||
dnsPolicyConfigConfigOptionsKey = "ConfigOptions"
|
dnsPolicyConfigConfigOptionsKey = "ConfigOptions"
|
||||||
dnsPolicyConfigConfigOptionsValue = 0x8
|
dnsPolicyConfigConfigOptionsValue = 0x8
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
interfaceConfigPath = `SYSTEM\CurrentControlSet\Services\Tcpip\Parameters\Interfaces`
|
interfaceConfigPath = `SYSTEM\CurrentControlSet\Services\Tcpip\Parameters\Interfaces`
|
||||||
interfaceConfigNameServerKey = "NameServer"
|
interfaceConfigNameServerKey = "NameServer"
|
||||||
interfaceConfigSearchListKey = "SearchList"
|
interfaceConfigSearchListKey = "SearchList"
|
||||||
|
|
||||||
|
// RP_FORCE: Reapply all policies even if no policy change was detected
|
||||||
|
rpForce = 0x1
|
||||||
)
|
)
|
||||||
|
|
||||||
type registryConfigurator struct {
|
type registryConfigurator struct {
|
||||||
guid string
|
guid string
|
||||||
routingAll bool
|
routingAll bool
|
||||||
|
gpo bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func newHostManager(wgInterface WGIface) (*registryConfigurator, error) {
|
func newHostManager(wgInterface WGIface) (*registryConfigurator, error) {
|
||||||
@@ -37,12 +53,20 @@ func newHostManager(wgInterface WGIface) (*registryConfigurator, error) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return newHostManagerWithGuid(guid)
|
|
||||||
}
|
|
||||||
|
|
||||||
func newHostManagerWithGuid(guid string) (*registryConfigurator, error) {
|
var useGPO bool
|
||||||
|
k, err := registry.OpenKey(registry.LOCAL_MACHINE, gpoDnsPolicyRoot, registry.QUERY_VALUE)
|
||||||
|
if err != nil {
|
||||||
|
log.Debugf("failed to open GPO DNS policy root: %v", err)
|
||||||
|
} else {
|
||||||
|
closer(k)
|
||||||
|
useGPO = true
|
||||||
|
log.Infof("detected GPO DNS policy configuration, using policy store")
|
||||||
|
}
|
||||||
|
|
||||||
return ®istryConfigurator{
|
return ®istryConfigurator{
|
||||||
guid: guid,
|
guid: guid,
|
||||||
|
gpo: useGPO,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -51,30 +75,23 @@ func (r *registryConfigurator) supportCustomPort() bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error {
|
func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error {
|
||||||
var err error
|
|
||||||
if config.RouteAll {
|
if config.RouteAll {
|
||||||
err = r.addDNSSetupForAll(config.ServerIP)
|
if err := r.addDNSSetupForAll(config.ServerIP); err != nil {
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("add dns setup: %w", err)
|
return fmt.Errorf("add dns setup: %w", err)
|
||||||
}
|
}
|
||||||
} else if r.routingAll {
|
} else if r.routingAll {
|
||||||
err = r.deleteInterfaceRegistryKeyProperty(interfaceConfigNameServerKey)
|
if err := r.deleteInterfaceRegistryKeyProperty(interfaceConfigNameServerKey); err != nil {
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("delete interface registry key property: %w", err)
|
return fmt.Errorf("delete interface registry key property: %w", err)
|
||||||
}
|
}
|
||||||
r.routingAll = false
|
r.routingAll = false
|
||||||
log.Infof("removed %s as main DNS forwarder for this peer", config.ServerIP)
|
log.Infof("removed %s as main DNS forwarder for this peer", config.ServerIP)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := stateManager.UpdateState(&ShutdownState{Guid: r.guid}); err != nil {
|
if err := stateManager.UpdateState(&ShutdownState{Guid: r.guid, GPO: r.gpo}); err != nil {
|
||||||
log.Errorf("failed to update shutdown state: %s", err)
|
log.Errorf("failed to update shutdown state: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
var searchDomains, matchDomains []string
|
||||||
searchDomains []string
|
|
||||||
matchDomains []string
|
|
||||||
)
|
|
||||||
|
|
||||||
for _, dConf := range config.Domains {
|
for _, dConf := range config.Domains {
|
||||||
if dConf.Disabled {
|
if dConf.Disabled {
|
||||||
continue
|
continue
|
||||||
@@ -86,16 +103,16 @@ func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager
|
|||||||
}
|
}
|
||||||
|
|
||||||
if len(matchDomains) != 0 {
|
if len(matchDomains) != 0 {
|
||||||
err = r.addDNSMatchPolicy(matchDomains, config.ServerIP)
|
if err := r.addDNSMatchPolicy(matchDomains, config.ServerIP); err != nil {
|
||||||
|
return fmt.Errorf("add dns match policy: %w", err)
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
err = removeRegistryKeyFromDNSPolicyConfig(dnsPolicyConfigMatchPath)
|
if err := r.removeDNSMatchPolicies(); err != nil {
|
||||||
}
|
return fmt.Errorf("remove dns match policies: %w", err)
|
||||||
if err != nil {
|
}
|
||||||
return fmt.Errorf("add dns match policy: %w", err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
err = r.updateSearchDomains(searchDomains)
|
if err := r.updateSearchDomains(searchDomains); err != nil {
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("update search domains: %w", err)
|
return fmt.Errorf("update search domains: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -103,9 +120,8 @@ func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (r *registryConfigurator) addDNSSetupForAll(ip string) error {
|
func (r *registryConfigurator) addDNSSetupForAll(ip string) error {
|
||||||
err := r.setInterfaceRegistryKeyStringValue(interfaceConfigNameServerKey, ip)
|
if err := r.setInterfaceRegistryKeyStringValue(interfaceConfigNameServerKey, ip); err != nil {
|
||||||
if err != nil {
|
return fmt.Errorf("adding dns setup for all failed: %w", err)
|
||||||
return fmt.Errorf("adding dns setup for all failed with error: %w", err)
|
|
||||||
}
|
}
|
||||||
r.routingAll = true
|
r.routingAll = true
|
||||||
log.Infof("configured %s:53 as main DNS forwarder for this peer", ip)
|
log.Infof("configured %s:53 as main DNS forwarder for this peer", ip)
|
||||||
@@ -113,64 +129,54 @@ func (r *registryConfigurator) addDNSSetupForAll(ip string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (r *registryConfigurator) addDNSMatchPolicy(domains []string, ip string) error {
|
func (r *registryConfigurator) addDNSMatchPolicy(domains []string, ip string) error {
|
||||||
_, err := registry.OpenKey(registry.LOCAL_MACHINE, dnsPolicyConfigMatchPath, registry.QUERY_VALUE)
|
// if the gpo key is present, we need to put our DNS settings there, otherwise our config might be ignored
|
||||||
if err == nil {
|
// see https://learn.microsoft.com/en-us/openspecs/windows_protocols/ms-gpnrpt/8cc31cb9-20cb-4140-9e85-3e08703b4745
|
||||||
err = registry.DeleteKey(registry.LOCAL_MACHINE, dnsPolicyConfigMatchPath)
|
policyPath := dnsPolicyConfigMatchPath
|
||||||
if err != nil {
|
if r.gpo {
|
||||||
return fmt.Errorf("unable to remove existing key from registry, key: HKEY_LOCAL_MACHINE\\%s, error: %w", dnsPolicyConfigMatchPath, err)
|
policyPath = gpoDnsPolicyConfigMatchPath
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := removeRegistryKeyFromDNSPolicyConfig(policyPath); err != nil {
|
||||||
|
return fmt.Errorf("remove existing dns policy: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
regKey, _, err := registry.CreateKey(registry.LOCAL_MACHINE, policyPath, registry.SET_VALUE)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("create registry key HKEY_LOCAL_MACHINE\\%s: %w", policyPath, err)
|
||||||
|
}
|
||||||
|
defer closer(regKey)
|
||||||
|
|
||||||
|
if err := regKey.SetDWordValue(dnsPolicyConfigVersionKey, dnsPolicyConfigVersionValue); err != nil {
|
||||||
|
return fmt.Errorf("set %s: %w", dnsPolicyConfigVersionKey, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := regKey.SetStringsValue(dnsPolicyConfigNameKey, domains); err != nil {
|
||||||
|
return fmt.Errorf("set %s: %w", dnsPolicyConfigNameKey, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := regKey.SetStringValue(dnsPolicyConfigGenericDNSServersKey, ip); err != nil {
|
||||||
|
return fmt.Errorf("set %s: %w", dnsPolicyConfigGenericDNSServersKey, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := regKey.SetDWordValue(dnsPolicyConfigConfigOptionsKey, dnsPolicyConfigConfigOptionsValue); err != nil {
|
||||||
|
return fmt.Errorf("set %s: %w", dnsPolicyConfigConfigOptionsKey, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.gpo {
|
||||||
|
if err := refreshGroupPolicy(); err != nil {
|
||||||
|
log.Warnf("failed to refresh group policy: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
regKey, _, err := registry.CreateKey(registry.LOCAL_MACHINE, dnsPolicyConfigMatchPath, registry.SET_VALUE)
|
log.Infof("added %d match domains. Domain list: %s", len(domains), domains)
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("unable to create registry key, key: HKEY_LOCAL_MACHINE\\%s, error: %w", dnsPolicyConfigMatchPath, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
err = regKey.SetDWordValue(dnsPolicyConfigVersionKey, dnsPolicyConfigVersionValue)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("unable to set registry value for %s, error: %w", dnsPolicyConfigVersionKey, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
err = regKey.SetStringsValue(dnsPolicyConfigNameKey, domains)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("unable to set registry value for %s, error: %w", dnsPolicyConfigNameKey, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
err = regKey.SetStringValue(dnsPolicyConfigGenericDNSServersKey, ip)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("unable to set registry value for %s, error: %w", dnsPolicyConfigGenericDNSServersKey, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
err = regKey.SetDWordValue(dnsPolicyConfigConfigOptionsKey, dnsPolicyConfigConfigOptionsValue)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("unable to set registry value for %s, error: %w", dnsPolicyConfigConfigOptionsKey, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Infof("added %d match domains to the state. Domain list: %s", len(domains), domains)
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *registryConfigurator) restoreHostDNS() error {
|
|
||||||
if err := removeRegistryKeyFromDNSPolicyConfig(dnsPolicyConfigMatchPath); err != nil {
|
|
||||||
log.Errorf("remove registry key from dns policy config: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := r.deleteInterfaceRegistryKeyProperty(interfaceConfigSearchListKey); err != nil {
|
|
||||||
return fmt.Errorf("remove interface registry key: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *registryConfigurator) updateSearchDomains(domains []string) error {
|
func (r *registryConfigurator) updateSearchDomains(domains []string) error {
|
||||||
err := r.setInterfaceRegistryKeyStringValue(interfaceConfigSearchListKey, strings.Join(domains, ","))
|
if err := r.setInterfaceRegistryKeyStringValue(interfaceConfigSearchListKey, strings.Join(domains, ",")); err != nil {
|
||||||
if err != nil {
|
return fmt.Errorf("update search domains: %w", err)
|
||||||
return fmt.Errorf("adding search domain failed with error: %w", err)
|
|
||||||
}
|
}
|
||||||
|
log.Infof("updated search domains: %s", domains)
|
||||||
log.Infof("updated the search domains in the registry with %d domains. Domain list: %s", len(domains), domains)
|
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -181,11 +187,9 @@ func (r *registryConfigurator) setInterfaceRegistryKeyStringValue(key, value str
|
|||||||
}
|
}
|
||||||
defer closer(regKey)
|
defer closer(regKey)
|
||||||
|
|
||||||
err = regKey.SetStringValue(key, value)
|
if err := regKey.SetStringValue(key, value); err != nil {
|
||||||
if err != nil {
|
return fmt.Errorf("set key %s=%s: %w", key, value, err)
|
||||||
return fmt.Errorf("applying key %s with value \"%s\" for interface failed with error: %w", key, value, err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -196,43 +200,91 @@ func (r *registryConfigurator) deleteInterfaceRegistryKeyProperty(propertyKey st
|
|||||||
}
|
}
|
||||||
defer closer(regKey)
|
defer closer(regKey)
|
||||||
|
|
||||||
err = regKey.DeleteValue(propertyKey)
|
if err := regKey.DeleteValue(propertyKey); err != nil {
|
||||||
if err != nil {
|
return fmt.Errorf("delete registry key %s: %w", propertyKey, err)
|
||||||
return fmt.Errorf("deleting registry key %s for interface failed with error: %w", propertyKey, err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *registryConfigurator) getInterfaceRegistryKey() (registry.Key, error) {
|
func (r *registryConfigurator) getInterfaceRegistryKey() (registry.Key, error) {
|
||||||
var regKey registry.Key
|
|
||||||
|
|
||||||
regKeyPath := interfaceConfigPath + "\\" + r.guid
|
regKeyPath := interfaceConfigPath + "\\" + r.guid
|
||||||
|
|
||||||
regKey, err := registry.OpenKey(registry.LOCAL_MACHINE, regKeyPath, registry.SET_VALUE)
|
regKey, err := registry.OpenKey(registry.LOCAL_MACHINE, regKeyPath, registry.SET_VALUE)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return regKey, fmt.Errorf("unable to open the interface registry key, key: HKEY_LOCAL_MACHINE\\%s, error: %w", regKeyPath, err)
|
return regKey, fmt.Errorf("open HKEY_LOCAL_MACHINE\\%s: %w", regKeyPath, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return regKey, nil
|
return regKey, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *registryConfigurator) restoreUncleanShutdownDNS() error {
|
func (r *registryConfigurator) restoreHostDNS() error {
|
||||||
if err := r.restoreHostDNS(); err != nil {
|
if err := r.removeDNSMatchPolicies(); err != nil {
|
||||||
return fmt.Errorf("restoring dns via registry: %w", err)
|
log.Errorf("remove dns match policies: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err := r.deleteInterfaceRegistryKeyProperty(interfaceConfigSearchListKey); err != nil {
|
||||||
|
return fmt.Errorf("remove interface registry key: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r *registryConfigurator) removeDNSMatchPolicies() error {
|
||||||
|
var merr *multierror.Error
|
||||||
|
if err := removeRegistryKeyFromDNSPolicyConfig(dnsPolicyConfigMatchPath); err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("remove local registry key: %w", err))
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := removeRegistryKeyFromDNSPolicyConfig(gpoDnsPolicyConfigMatchPath); err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("remove GPO registry key: %w", err))
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := refreshGroupPolicy(); err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("refresh group policy: %w", err))
|
||||||
|
}
|
||||||
|
|
||||||
|
return nberrors.FormatErrorOrNil(merr)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *registryConfigurator) restoreUncleanShutdownDNS() error {
|
||||||
|
return r.restoreHostDNS()
|
||||||
|
}
|
||||||
|
|
||||||
func removeRegistryKeyFromDNSPolicyConfig(regKeyPath string) error {
|
func removeRegistryKeyFromDNSPolicyConfig(regKeyPath string) error {
|
||||||
k, err := registry.OpenKey(registry.LOCAL_MACHINE, regKeyPath, registry.QUERY_VALUE)
|
k, err := registry.OpenKey(registry.LOCAL_MACHINE, regKeyPath, registry.QUERY_VALUE)
|
||||||
if err == nil {
|
if err != nil {
|
||||||
defer closer(k)
|
log.Debugf("failed to open HKEY_LOCAL_MACHINE\\%s: %v", regKeyPath, err)
|
||||||
err = registry.DeleteKey(registry.LOCAL_MACHINE, regKeyPath)
|
return nil
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("unable to remove existing key from registry, key: HKEY_LOCAL_MACHINE\\%s, error: %w", regKeyPath, err)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
closer(k)
|
||||||
|
if err := registry.DeleteKey(registry.LOCAL_MACHINE, regKeyPath); err != nil {
|
||||||
|
return fmt.Errorf("delete HKEY_LOCAL_MACHINE\\%s: %w", regKeyPath, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func refreshGroupPolicy() error {
|
||||||
|
// refreshPolicyExFn.Call() panics if the func is not found
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
log.Errorf("Recovered from panic: %v", r)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
ret, _, err := refreshPolicyExFn.Call(
|
||||||
|
// bMachine = TRUE (computer policy)
|
||||||
|
uintptr(1),
|
||||||
|
// dwOptions = RP_FORCE
|
||||||
|
uintptr(rpForce),
|
||||||
|
)
|
||||||
|
|
||||||
|
if ret == 0 {
|
||||||
|
if err != nil && !errors.Is(err, syscall.Errno(0)) {
|
||||||
|
return fmt.Errorf("RefreshPolicyEx failed: %w", err)
|
||||||
|
}
|
||||||
|
return fmt.Errorf("RefreshPolicyEx failed")
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -29,10 +29,15 @@ func (d *localResolver) String() string {
|
|||||||
return fmt.Sprintf("local resolver [%d records]", len(d.registeredMap))
|
return fmt.Sprintf("local resolver [%d records]", len(d.registeredMap))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ID returns the unique handler ID
|
||||||
|
func (d *localResolver) id() handlerID {
|
||||||
|
return "local-resolver"
|
||||||
|
}
|
||||||
|
|
||||||
// ServeDNS handles a DNS request
|
// ServeDNS handles a DNS request
|
||||||
func (d *localResolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
func (d *localResolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||||
if len(r.Question) > 0 {
|
if len(r.Question) > 0 {
|
||||||
log.Tracef("received question: domain=%s type=%v class=%v", r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass)
|
log.Tracef("received local question: domain=%s type=%v class=%v", r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass)
|
||||||
}
|
}
|
||||||
|
|
||||||
replyMessage := &dns.Msg{}
|
replyMessage := &dns.Msg{}
|
||||||
|
|||||||
@@ -5,7 +5,6 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"runtime"
|
"runtime"
|
||||||
"strings"
|
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
@@ -42,7 +41,12 @@ type Server interface {
|
|||||||
ProbeAvailability()
|
ProbeAvailability()
|
||||||
}
|
}
|
||||||
|
|
||||||
type registeredHandlerMap map[string]handlerWithStop
|
type handlerID string
|
||||||
|
|
||||||
|
type nsGroupsByDomain struct {
|
||||||
|
domain string
|
||||||
|
groups []*nbdns.NameServerGroup
|
||||||
|
}
|
||||||
|
|
||||||
// DefaultServer dns server object
|
// DefaultServer dns server object
|
||||||
type DefaultServer struct {
|
type DefaultServer struct {
|
||||||
@@ -52,7 +56,6 @@ type DefaultServer struct {
|
|||||||
mux sync.Mutex
|
mux sync.Mutex
|
||||||
service service
|
service service
|
||||||
dnsMuxMap registeredHandlerMap
|
dnsMuxMap registeredHandlerMap
|
||||||
handlerPriorities map[string]int
|
|
||||||
localResolver *localResolver
|
localResolver *localResolver
|
||||||
wgInterface WGIface
|
wgInterface WGIface
|
||||||
hostManager hostManager
|
hostManager hostManager
|
||||||
@@ -77,14 +80,17 @@ type handlerWithStop interface {
|
|||||||
dns.Handler
|
dns.Handler
|
||||||
stop()
|
stop()
|
||||||
probeAvailability()
|
probeAvailability()
|
||||||
|
id() handlerID
|
||||||
}
|
}
|
||||||
|
|
||||||
type muxUpdate struct {
|
type handlerWrapper struct {
|
||||||
domain string
|
domain string
|
||||||
handler handlerWithStop
|
handler handlerWithStop
|
||||||
priority int
|
priority int
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type registeredHandlerMap map[handlerID]handlerWrapper
|
||||||
|
|
||||||
// NewDefaultServer returns a new dns server
|
// NewDefaultServer returns a new dns server
|
||||||
func NewDefaultServer(
|
func NewDefaultServer(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
@@ -158,13 +164,12 @@ func newDefaultServer(
|
|||||||
) *DefaultServer {
|
) *DefaultServer {
|
||||||
ctx, stop := context.WithCancel(ctx)
|
ctx, stop := context.WithCancel(ctx)
|
||||||
defaultServer := &DefaultServer{
|
defaultServer := &DefaultServer{
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
ctxCancel: stop,
|
ctxCancel: stop,
|
||||||
disableSys: disableSys,
|
disableSys: disableSys,
|
||||||
service: dnsService,
|
service: dnsService,
|
||||||
handlerChain: NewHandlerChain(),
|
handlerChain: NewHandlerChain(),
|
||||||
dnsMuxMap: make(registeredHandlerMap),
|
dnsMuxMap: make(registeredHandlerMap),
|
||||||
handlerPriorities: make(map[string]int),
|
|
||||||
localResolver: &localResolver{
|
localResolver: &localResolver{
|
||||||
registeredMap: make(registrationMap),
|
registeredMap: make(registrationMap),
|
||||||
},
|
},
|
||||||
@@ -192,8 +197,7 @@ func (s *DefaultServer) registerHandler(domains []string, handler dns.Handler, p
|
|||||||
log.Warn("skipping empty domain")
|
log.Warn("skipping empty domain")
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
s.handlerChain.AddHandler(domain, handler, priority, nil)
|
s.handlerChain.AddHandler(domain, handler, priority)
|
||||||
s.handlerPriorities[domain] = priority
|
|
||||||
s.service.RegisterMux(nbdns.NormalizeZone(domain), s.handlerChain)
|
s.service.RegisterMux(nbdns.NormalizeZone(domain), s.handlerChain)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -209,14 +213,15 @@ func (s *DefaultServer) deregisterHandler(domains []string, priority int) {
|
|||||||
log.Debugf("deregistering handler %v with priority %d", domains, priority)
|
log.Debugf("deregistering handler %v with priority %d", domains, priority)
|
||||||
|
|
||||||
for _, domain := range domains {
|
for _, domain := range domains {
|
||||||
|
if domain == "" {
|
||||||
|
log.Warn("skipping empty domain")
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
s.handlerChain.RemoveHandler(domain, priority)
|
s.handlerChain.RemoveHandler(domain, priority)
|
||||||
|
|
||||||
// Only deregister from service if no handlers remain
|
// Only deregister from service if no handlers remain
|
||||||
if !s.handlerChain.HasHandlers(domain) {
|
if !s.handlerChain.HasHandlers(domain) {
|
||||||
if domain == "" {
|
|
||||||
log.Warn("skipping empty domain")
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
s.service.DeregisterMux(nbdns.NormalizeZone(domain))
|
s.service.DeregisterMux(nbdns.NormalizeZone(domain))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -283,14 +288,24 @@ func (s *DefaultServer) Stop() {
|
|||||||
|
|
||||||
// OnUpdatedHostDNSServer update the DNS servers addresses for root zones
|
// OnUpdatedHostDNSServer update the DNS servers addresses for root zones
|
||||||
// It will be applied if the mgm server do not enforce DNS settings for root zone
|
// It will be applied if the mgm server do not enforce DNS settings for root zone
|
||||||
|
|
||||||
func (s *DefaultServer) OnUpdatedHostDNSServer(hostsDnsList []string) {
|
func (s *DefaultServer) OnUpdatedHostDNSServer(hostsDnsList []string) {
|
||||||
s.hostsDNSHolder.set(hostsDnsList)
|
s.hostsDNSHolder.set(hostsDnsList)
|
||||||
|
|
||||||
_, ok := s.dnsMuxMap[nbdns.RootZone]
|
// Check if there's any root handler
|
||||||
if ok {
|
var hasRootHandler bool
|
||||||
|
for _, handler := range s.dnsMuxMap {
|
||||||
|
if handler.domain == nbdns.RootZone {
|
||||||
|
hasRootHandler = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if hasRootHandler {
|
||||||
log.Debugf("on new host DNS config but skip to apply it")
|
log.Debugf("on new host DNS config but skip to apply it")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Debugf("update host DNS settings: %+v", hostsDnsList)
|
log.Debugf("update host DNS settings: %+v", hostsDnsList)
|
||||||
s.addHostRootZone()
|
s.addHostRootZone()
|
||||||
}
|
}
|
||||||
@@ -364,7 +379,7 @@ func (s *DefaultServer) ProbeAvailability() {
|
|||||||
go func(mux handlerWithStop) {
|
go func(mux handlerWithStop) {
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
mux.probeAvailability()
|
mux.probeAvailability()
|
||||||
}(mux)
|
}(mux.handler)
|
||||||
}
|
}
|
||||||
wg.Wait()
|
wg.Wait()
|
||||||
}
|
}
|
||||||
@@ -419,8 +434,8 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *DefaultServer) buildLocalHandlerUpdate(customZones []nbdns.CustomZone) ([]muxUpdate, map[string]nbdns.SimpleRecord, error) {
|
func (s *DefaultServer) buildLocalHandlerUpdate(customZones []nbdns.CustomZone) ([]handlerWrapper, map[string]nbdns.SimpleRecord, error) {
|
||||||
var muxUpdates []muxUpdate
|
var muxUpdates []handlerWrapper
|
||||||
localRecords := make(map[string]nbdns.SimpleRecord, 0)
|
localRecords := make(map[string]nbdns.SimpleRecord, 0)
|
||||||
|
|
||||||
for _, customZone := range customZones {
|
for _, customZone := range customZones {
|
||||||
@@ -428,7 +443,7 @@ func (s *DefaultServer) buildLocalHandlerUpdate(customZones []nbdns.CustomZone)
|
|||||||
return nil, nil, fmt.Errorf("received an empty list of records")
|
return nil, nil, fmt.Errorf("received an empty list of records")
|
||||||
}
|
}
|
||||||
|
|
||||||
muxUpdates = append(muxUpdates, muxUpdate{
|
muxUpdates = append(muxUpdates, handlerWrapper{
|
||||||
domain: customZone.Domain,
|
domain: customZone.Domain,
|
||||||
handler: s.localResolver,
|
handler: s.localResolver,
|
||||||
priority: PriorityMatchDomain,
|
priority: PriorityMatchDomain,
|
||||||
@@ -446,15 +461,59 @@ func (s *DefaultServer) buildLocalHandlerUpdate(customZones []nbdns.CustomZone)
|
|||||||
return muxUpdates, localRecords, nil
|
return muxUpdates, localRecords, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *DefaultServer) buildUpstreamHandlerUpdate(nameServerGroups []*nbdns.NameServerGroup) ([]muxUpdate, error) {
|
func (s *DefaultServer) buildUpstreamHandlerUpdate(nameServerGroups []*nbdns.NameServerGroup) ([]handlerWrapper, error) {
|
||||||
|
var muxUpdates []handlerWrapper
|
||||||
|
|
||||||
var muxUpdates []muxUpdate
|
|
||||||
for _, nsGroup := range nameServerGroups {
|
for _, nsGroup := range nameServerGroups {
|
||||||
if len(nsGroup.NameServers) == 0 {
|
if len(nsGroup.NameServers) == 0 {
|
||||||
log.Warn("received a nameserver group with empty nameserver list")
|
log.Warn("received a nameserver group with empty nameserver list")
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if !nsGroup.Primary && len(nsGroup.Domains) == 0 {
|
||||||
|
return nil, fmt.Errorf("received a non primary nameserver group with an empty domain list")
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, domain := range nsGroup.Domains {
|
||||||
|
if domain == "" {
|
||||||
|
return nil, fmt.Errorf("received a nameserver group with an empty domain element")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
groupedNS := groupNSGroupsByDomain(nameServerGroups)
|
||||||
|
|
||||||
|
for _, domainGroup := range groupedNS {
|
||||||
|
basePriority := PriorityMatchDomain
|
||||||
|
if domainGroup.domain == nbdns.RootZone {
|
||||||
|
basePriority = PriorityDefault
|
||||||
|
}
|
||||||
|
|
||||||
|
updates, err := s.createHandlersForDomainGroup(domainGroup, basePriority)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
muxUpdates = append(muxUpdates, updates...)
|
||||||
|
}
|
||||||
|
|
||||||
|
return muxUpdates, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *DefaultServer) createHandlersForDomainGroup(domainGroup nsGroupsByDomain, basePriority int) ([]handlerWrapper, error) {
|
||||||
|
var muxUpdates []handlerWrapper
|
||||||
|
|
||||||
|
for i, nsGroup := range domainGroup.groups {
|
||||||
|
// Decrement priority by handler index (0, 1, 2, ...) to avoid conflicts
|
||||||
|
priority := basePriority - i
|
||||||
|
|
||||||
|
// Check if we're about to overlap with the next priority tier
|
||||||
|
if basePriority == PriorityMatchDomain && priority <= PriorityDefault {
|
||||||
|
log.Warnf("too many handlers for domain=%s, would overlap with default priority tier (diff=%d). Skipping remaining handlers",
|
||||||
|
domainGroup.domain, PriorityMatchDomain-PriorityDefault)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debugf("creating handler for domain=%s with priority=%d", domainGroup.domain, priority)
|
||||||
handler, err := newUpstreamResolver(
|
handler, err := newUpstreamResolver(
|
||||||
s.ctx,
|
s.ctx,
|
||||||
s.wgInterface.Name(),
|
s.wgInterface.Name(),
|
||||||
@@ -462,10 +521,12 @@ func (s *DefaultServer) buildUpstreamHandlerUpdate(nameServerGroups []*nbdns.Nam
|
|||||||
s.wgInterface.Address().Network,
|
s.wgInterface.Address().Network,
|
||||||
s.statusRecorder,
|
s.statusRecorder,
|
||||||
s.hostsDNSHolder,
|
s.hostsDNSHolder,
|
||||||
|
domainGroup.domain,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("unable to create a new upstream resolver, error: %v", err)
|
return nil, fmt.Errorf("create upstream resolver: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, ns := range nsGroup.NameServers {
|
for _, ns := range nsGroup.NameServers {
|
||||||
if ns.NSType != nbdns.UDPNameServerType {
|
if ns.NSType != nbdns.UDPNameServerType {
|
||||||
log.Warnf("skipping nameserver %s with type %s, this peer supports only %s",
|
log.Warnf("skipping nameserver %s with type %s, this peer supports only %s",
|
||||||
@@ -489,78 +550,47 @@ func (s *DefaultServer) buildUpstreamHandlerUpdate(nameServerGroups []*nbdns.Nam
|
|||||||
// after some period defined by upstream it tries to reactivate self by calling this hook
|
// after some period defined by upstream it tries to reactivate self by calling this hook
|
||||||
// everything we need here is just to re-apply current configuration because it already
|
// everything we need here is just to re-apply current configuration because it already
|
||||||
// contains this upstream settings (temporal deactivation not removed it)
|
// contains this upstream settings (temporal deactivation not removed it)
|
||||||
handler.deactivate, handler.reactivate = s.upstreamCallbacks(nsGroup, handler)
|
handler.deactivate, handler.reactivate = s.upstreamCallbacks(nsGroup, handler, priority)
|
||||||
|
|
||||||
if nsGroup.Primary {
|
muxUpdates = append(muxUpdates, handlerWrapper{
|
||||||
muxUpdates = append(muxUpdates, muxUpdate{
|
domain: domainGroup.domain,
|
||||||
domain: nbdns.RootZone,
|
handler: handler,
|
||||||
handler: handler,
|
priority: priority,
|
||||||
priority: PriorityDefault,
|
})
|
||||||
})
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(nsGroup.Domains) == 0 {
|
|
||||||
handler.stop()
|
|
||||||
return nil, fmt.Errorf("received a non primary nameserver group with an empty domain list")
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, domain := range nsGroup.Domains {
|
|
||||||
if domain == "" {
|
|
||||||
handler.stop()
|
|
||||||
return nil, fmt.Errorf("received a nameserver group with an empty domain element")
|
|
||||||
}
|
|
||||||
muxUpdates = append(muxUpdates, muxUpdate{
|
|
||||||
domain: domain,
|
|
||||||
handler: handler,
|
|
||||||
priority: PriorityMatchDomain,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return muxUpdates, nil
|
return muxUpdates, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *DefaultServer) updateMux(muxUpdates []muxUpdate) {
|
func (s *DefaultServer) updateMux(muxUpdates []handlerWrapper) {
|
||||||
muxUpdateMap := make(registeredHandlerMap)
|
// this will introduce a short period of time when the server is not able to handle DNS requests
|
||||||
handlersByPriority := make(map[string]int)
|
for _, existing := range s.dnsMuxMap {
|
||||||
|
s.deregisterHandler([]string{existing.domain}, existing.priority)
|
||||||
var isContainRootUpdate bool
|
existing.handler.stop()
|
||||||
|
|
||||||
// First register new handlers
|
|
||||||
for _, update := range muxUpdates {
|
|
||||||
s.registerHandler([]string{update.domain}, update.handler, update.priority)
|
|
||||||
muxUpdateMap[update.domain] = update.handler
|
|
||||||
handlersByPriority[update.domain] = update.priority
|
|
||||||
|
|
||||||
if existingHandler, ok := s.dnsMuxMap[update.domain]; ok {
|
|
||||||
existingHandler.stop()
|
|
||||||
}
|
|
||||||
|
|
||||||
if update.domain == nbdns.RootZone {
|
|
||||||
isContainRootUpdate = true
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Then deregister old handlers not in the update
|
muxUpdateMap := make(registeredHandlerMap)
|
||||||
for key, existingHandler := range s.dnsMuxMap {
|
var containsRootUpdate bool
|
||||||
_, found := muxUpdateMap[key]
|
|
||||||
if !found {
|
for _, update := range muxUpdates {
|
||||||
if !isContainRootUpdate && key == nbdns.RootZone {
|
if update.domain == nbdns.RootZone {
|
||||||
|
containsRootUpdate = true
|
||||||
|
}
|
||||||
|
s.registerHandler([]string{update.domain}, update.handler, update.priority)
|
||||||
|
muxUpdateMap[update.handler.id()] = update
|
||||||
|
}
|
||||||
|
|
||||||
|
// If there's no root update and we had a root handler, restore it
|
||||||
|
if !containsRootUpdate {
|
||||||
|
for _, existing := range s.dnsMuxMap {
|
||||||
|
if existing.domain == nbdns.RootZone {
|
||||||
s.addHostRootZone()
|
s.addHostRootZone()
|
||||||
existingHandler.stop()
|
break
|
||||||
} else {
|
|
||||||
existingHandler.stop()
|
|
||||||
// Deregister with the priority that was used to register
|
|
||||||
if oldPriority, ok := s.handlerPriorities[key]; ok {
|
|
||||||
s.deregisterHandler([]string{key}, oldPriority)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
s.dnsMuxMap = muxUpdateMap
|
s.dnsMuxMap = muxUpdateMap
|
||||||
s.handlerPriorities = handlersByPriority
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *DefaultServer) updateLocalResolver(update map[string]nbdns.SimpleRecord) {
|
func (s *DefaultServer) updateLocalResolver(update map[string]nbdns.SimpleRecord) {
|
||||||
@@ -593,6 +623,7 @@ func getNSHostPort(ns nbdns.NameServer) string {
|
|||||||
func (s *DefaultServer) upstreamCallbacks(
|
func (s *DefaultServer) upstreamCallbacks(
|
||||||
nsGroup *nbdns.NameServerGroup,
|
nsGroup *nbdns.NameServerGroup,
|
||||||
handler dns.Handler,
|
handler dns.Handler,
|
||||||
|
priority int,
|
||||||
) (deactivate func(error), reactivate func()) {
|
) (deactivate func(error), reactivate func()) {
|
||||||
var removeIndex map[string]int
|
var removeIndex map[string]int
|
||||||
deactivate = func(err error) {
|
deactivate = func(err error) {
|
||||||
@@ -609,13 +640,13 @@ func (s *DefaultServer) upstreamCallbacks(
|
|||||||
if nsGroup.Primary {
|
if nsGroup.Primary {
|
||||||
removeIndex[nbdns.RootZone] = -1
|
removeIndex[nbdns.RootZone] = -1
|
||||||
s.currentConfig.RouteAll = false
|
s.currentConfig.RouteAll = false
|
||||||
s.deregisterHandler([]string{nbdns.RootZone}, PriorityDefault)
|
s.deregisterHandler([]string{nbdns.RootZone}, priority)
|
||||||
}
|
}
|
||||||
|
|
||||||
for i, item := range s.currentConfig.Domains {
|
for i, item := range s.currentConfig.Domains {
|
||||||
if _, found := removeIndex[item.Domain]; found {
|
if _, found := removeIndex[item.Domain]; found {
|
||||||
s.currentConfig.Domains[i].Disabled = true
|
s.currentConfig.Domains[i].Disabled = true
|
||||||
s.deregisterHandler([]string{item.Domain}, PriorityMatchDomain)
|
s.deregisterHandler([]string{item.Domain}, priority)
|
||||||
removeIndex[item.Domain] = i
|
removeIndex[item.Domain] = i
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -635,8 +666,8 @@ func (s *DefaultServer) upstreamCallbacks(
|
|||||||
}
|
}
|
||||||
|
|
||||||
s.updateNSState(nsGroup, err, false)
|
s.updateNSState(nsGroup, err, false)
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
reactivate = func() {
|
reactivate = func() {
|
||||||
s.mux.Lock()
|
s.mux.Lock()
|
||||||
defer s.mux.Unlock()
|
defer s.mux.Unlock()
|
||||||
@@ -646,7 +677,7 @@ func (s *DefaultServer) upstreamCallbacks(
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
s.currentConfig.Domains[i].Disabled = false
|
s.currentConfig.Domains[i].Disabled = false
|
||||||
s.registerHandler([]string{domain}, handler, PriorityMatchDomain)
|
s.registerHandler([]string{domain}, handler, priority)
|
||||||
}
|
}
|
||||||
|
|
||||||
l := log.WithField("nameservers", nsGroup.NameServers)
|
l := log.WithField("nameservers", nsGroup.NameServers)
|
||||||
@@ -654,7 +685,7 @@ func (s *DefaultServer) upstreamCallbacks(
|
|||||||
|
|
||||||
if nsGroup.Primary {
|
if nsGroup.Primary {
|
||||||
s.currentConfig.RouteAll = true
|
s.currentConfig.RouteAll = true
|
||||||
s.registerHandler([]string{nbdns.RootZone}, handler, PriorityDefault)
|
s.registerHandler([]string{nbdns.RootZone}, handler, priority)
|
||||||
}
|
}
|
||||||
|
|
||||||
if s.hostManager != nil {
|
if s.hostManager != nil {
|
||||||
@@ -676,6 +707,7 @@ func (s *DefaultServer) addHostRootZone() {
|
|||||||
s.wgInterface.Address().Network,
|
s.wgInterface.Address().Network,
|
||||||
s.statusRecorder,
|
s.statusRecorder,
|
||||||
s.hostsDNSHolder,
|
s.hostsDNSHolder,
|
||||||
|
nbdns.RootZone,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("unable to create a new upstream resolver, error: %v", err)
|
log.Errorf("unable to create a new upstream resolver, error: %v", err)
|
||||||
@@ -732,5 +764,34 @@ func generateGroupKey(nsGroup *nbdns.NameServerGroup) string {
|
|||||||
for _, ns := range nsGroup.NameServers {
|
for _, ns := range nsGroup.NameServers {
|
||||||
servers = append(servers, fmt.Sprintf("%s:%d", ns.IP, ns.Port))
|
servers = append(servers, fmt.Sprintf("%s:%d", ns.IP, ns.Port))
|
||||||
}
|
}
|
||||||
return fmt.Sprintf("%s_%s_%s", nsGroup.ID, nsGroup.Name, strings.Join(servers, ","))
|
return fmt.Sprintf("%v_%v", servers, nsGroup.Domains)
|
||||||
|
}
|
||||||
|
|
||||||
|
// groupNSGroupsByDomain groups nameserver groups by their match domains
|
||||||
|
func groupNSGroupsByDomain(nsGroups []*nbdns.NameServerGroup) []nsGroupsByDomain {
|
||||||
|
domainMap := make(map[string][]*nbdns.NameServerGroup)
|
||||||
|
|
||||||
|
for _, group := range nsGroups {
|
||||||
|
if group.Primary {
|
||||||
|
domainMap[nbdns.RootZone] = append(domainMap[nbdns.RootZone], group)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, domain := range group.Domains {
|
||||||
|
if domain == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
domainMap[domain] = append(domainMap[domain], group)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var result []nsGroupsByDomain
|
||||||
|
for domain, groups := range domainMap {
|
||||||
|
result = append(result, nsGroupsByDomain{
|
||||||
|
domain: domain,
|
||||||
|
groups: groups,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return result
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -13,6 +13,7 @@ import (
|
|||||||
"github.com/golang/mock/gomock"
|
"github.com/golang/mock/gomock"
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/mock"
|
"github.com/stretchr/testify/mock"
|
||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
|
|
||||||
@@ -88,6 +89,18 @@ func init() {
|
|||||||
formatter.SetTextFormatter(log.StandardLogger())
|
formatter.SetTextFormatter(log.StandardLogger())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func generateDummyHandler(domain string, servers []nbdns.NameServer) *upstreamResolverBase {
|
||||||
|
var srvs []string
|
||||||
|
for _, srv := range servers {
|
||||||
|
srvs = append(srvs, getNSHostPort(srv))
|
||||||
|
}
|
||||||
|
return &upstreamResolverBase{
|
||||||
|
domain: domain,
|
||||||
|
upstreamServers: srvs,
|
||||||
|
cancel: func() {},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestUpdateDNSServer(t *testing.T) {
|
func TestUpdateDNSServer(t *testing.T) {
|
||||||
nameServers := []nbdns.NameServer{
|
nameServers := []nbdns.NameServer{
|
||||||
{
|
{
|
||||||
@@ -140,15 +153,37 @@ func TestUpdateDNSServer(t *testing.T) {
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
expectedUpstreamMap: registeredHandlerMap{"netbird.io": dummyHandler, "netbird.cloud": dummyHandler, nbdns.RootZone: dummyHandler},
|
expectedUpstreamMap: registeredHandlerMap{
|
||||||
expectedLocalMap: registrationMap{buildRecordKey(zoneRecords[0].Name, 1, 1): struct{}{}},
|
generateDummyHandler("netbird.io", nameServers).id(): handlerWrapper{
|
||||||
|
domain: "netbird.io",
|
||||||
|
handler: dummyHandler,
|
||||||
|
priority: PriorityMatchDomain,
|
||||||
|
},
|
||||||
|
dummyHandler.id(): handlerWrapper{
|
||||||
|
domain: "netbird.cloud",
|
||||||
|
handler: dummyHandler,
|
||||||
|
priority: PriorityMatchDomain,
|
||||||
|
},
|
||||||
|
generateDummyHandler(".", nameServers).id(): handlerWrapper{
|
||||||
|
domain: nbdns.RootZone,
|
||||||
|
handler: dummyHandler,
|
||||||
|
priority: PriorityDefault,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expectedLocalMap: registrationMap{buildRecordKey(zoneRecords[0].Name, 1, 1): struct{}{}},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "New Config Should Succeed",
|
name: "New Config Should Succeed",
|
||||||
initLocalMap: registrationMap{"netbird.cloud": struct{}{}},
|
initLocalMap: registrationMap{"netbird.cloud": struct{}{}},
|
||||||
initUpstreamMap: registeredHandlerMap{buildRecordKey(zoneRecords[0].Name, 1, 1): dummyHandler},
|
initUpstreamMap: registeredHandlerMap{
|
||||||
initSerial: 0,
|
generateDummyHandler(zoneRecords[0].Name, nameServers).id(): handlerWrapper{
|
||||||
inputSerial: 1,
|
domain: buildRecordKey(zoneRecords[0].Name, 1, 1),
|
||||||
|
handler: dummyHandler,
|
||||||
|
priority: PriorityMatchDomain,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
initSerial: 0,
|
||||||
|
inputSerial: 1,
|
||||||
inputUpdate: nbdns.Config{
|
inputUpdate: nbdns.Config{
|
||||||
ServiceEnable: true,
|
ServiceEnable: true,
|
||||||
CustomZones: []nbdns.CustomZone{
|
CustomZones: []nbdns.CustomZone{
|
||||||
@@ -164,8 +199,19 @@ func TestUpdateDNSServer(t *testing.T) {
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
expectedUpstreamMap: registeredHandlerMap{"netbird.io": dummyHandler, "netbird.cloud": dummyHandler},
|
expectedUpstreamMap: registeredHandlerMap{
|
||||||
expectedLocalMap: registrationMap{buildRecordKey(zoneRecords[0].Name, 1, 1): struct{}{}},
|
generateDummyHandler("netbird.io", nameServers).id(): handlerWrapper{
|
||||||
|
domain: "netbird.io",
|
||||||
|
handler: dummyHandler,
|
||||||
|
priority: PriorityMatchDomain,
|
||||||
|
},
|
||||||
|
"local-resolver": handlerWrapper{
|
||||||
|
domain: "netbird.cloud",
|
||||||
|
handler: dummyHandler,
|
||||||
|
priority: PriorityMatchDomain,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expectedLocalMap: registrationMap{buildRecordKey(zoneRecords[0].Name, 1, 1): struct{}{}},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Smaller Config Serial Should Be Skipped",
|
name: "Smaller Config Serial Should Be Skipped",
|
||||||
@@ -242,9 +288,15 @@ func TestUpdateDNSServer(t *testing.T) {
|
|||||||
shouldFail: true,
|
shouldFail: true,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Empty Config Should Succeed and Clean Maps",
|
name: "Empty Config Should Succeed and Clean Maps",
|
||||||
initLocalMap: registrationMap{"netbird.cloud": struct{}{}},
|
initLocalMap: registrationMap{"netbird.cloud": struct{}{}},
|
||||||
initUpstreamMap: registeredHandlerMap{zoneRecords[0].Name: dummyHandler},
|
initUpstreamMap: registeredHandlerMap{
|
||||||
|
generateDummyHandler(zoneRecords[0].Name, nameServers).id(): handlerWrapper{
|
||||||
|
domain: zoneRecords[0].Name,
|
||||||
|
handler: dummyHandler,
|
||||||
|
priority: PriorityMatchDomain,
|
||||||
|
},
|
||||||
|
},
|
||||||
initSerial: 0,
|
initSerial: 0,
|
||||||
inputSerial: 1,
|
inputSerial: 1,
|
||||||
inputUpdate: nbdns.Config{ServiceEnable: true},
|
inputUpdate: nbdns.Config{ServiceEnable: true},
|
||||||
@@ -252,9 +304,15 @@ func TestUpdateDNSServer(t *testing.T) {
|
|||||||
expectedLocalMap: make(registrationMap),
|
expectedLocalMap: make(registrationMap),
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Disabled Service Should clean map",
|
name: "Disabled Service Should clean map",
|
||||||
initLocalMap: registrationMap{"netbird.cloud": struct{}{}},
|
initLocalMap: registrationMap{"netbird.cloud": struct{}{}},
|
||||||
initUpstreamMap: registeredHandlerMap{zoneRecords[0].Name: dummyHandler},
|
initUpstreamMap: registeredHandlerMap{
|
||||||
|
generateDummyHandler(zoneRecords[0].Name, nameServers).id(): handlerWrapper{
|
||||||
|
domain: zoneRecords[0].Name,
|
||||||
|
handler: dummyHandler,
|
||||||
|
priority: PriorityMatchDomain,
|
||||||
|
},
|
||||||
|
},
|
||||||
initSerial: 0,
|
initSerial: 0,
|
||||||
inputSerial: 1,
|
inputSerial: 1,
|
||||||
inputUpdate: nbdns.Config{ServiceEnable: false},
|
inputUpdate: nbdns.Config{ServiceEnable: false},
|
||||||
@@ -421,7 +479,13 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
dnsServer.dnsMuxMap = registeredHandlerMap{zoneRecords[0].Name: &localResolver{}}
|
dnsServer.dnsMuxMap = registeredHandlerMap{
|
||||||
|
"id1": handlerWrapper{
|
||||||
|
domain: zoneRecords[0].Name,
|
||||||
|
handler: &localResolver{},
|
||||||
|
priority: PriorityMatchDomain,
|
||||||
|
},
|
||||||
|
}
|
||||||
dnsServer.localResolver.registeredMap = registrationMap{"netbird.cloud": struct{}{}}
|
dnsServer.localResolver.registeredMap = registrationMap{"netbird.cloud": struct{}{}}
|
||||||
dnsServer.updateSerial = 0
|
dnsServer.updateSerial = 0
|
||||||
|
|
||||||
@@ -562,9 +626,8 @@ func TestDNSServerUpstreamDeactivateCallback(t *testing.T) {
|
|||||||
localResolver: &localResolver{
|
localResolver: &localResolver{
|
||||||
registeredMap: make(registrationMap),
|
registeredMap: make(registrationMap),
|
||||||
},
|
},
|
||||||
handlerChain: NewHandlerChain(),
|
handlerChain: NewHandlerChain(),
|
||||||
handlerPriorities: make(map[string]int),
|
hostManager: hostManager,
|
||||||
hostManager: hostManager,
|
|
||||||
currentConfig: HostDNSConfig{
|
currentConfig: HostDNSConfig{
|
||||||
Domains: []DomainConfig{
|
Domains: []DomainConfig{
|
||||||
{false, "domain0", false},
|
{false, "domain0", false},
|
||||||
@@ -593,7 +656,7 @@ func TestDNSServerUpstreamDeactivateCallback(t *testing.T) {
|
|||||||
NameServers: []nbdns.NameServer{
|
NameServers: []nbdns.NameServer{
|
||||||
{IP: netip.MustParseAddr("8.8.0.0"), NSType: nbdns.UDPNameServerType, Port: 53},
|
{IP: netip.MustParseAddr("8.8.0.0"), NSType: nbdns.UDPNameServerType, Port: 53},
|
||||||
},
|
},
|
||||||
}, nil)
|
}, nil, 0)
|
||||||
|
|
||||||
deactivate(nil)
|
deactivate(nil)
|
||||||
expected := "domain0,domain2"
|
expected := "domain0,domain2"
|
||||||
@@ -849,7 +912,7 @@ func createWgInterfaceWithBind(t *testing.T) (*iface.WGIface, error) {
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
pf, err := uspfilter.Create(wgIface)
|
pf, err := uspfilter.Create(wgIface, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to create uspfilter: %v", err)
|
t.Fatalf("failed to create uspfilter: %v", err)
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -903,8 +966,8 @@ func TestHandlerChain_DomainPriorities(t *testing.T) {
|
|||||||
Subdomains: true,
|
Subdomains: true,
|
||||||
}
|
}
|
||||||
|
|
||||||
chain.AddHandler("example.com.", dnsRouteHandler, PriorityDNSRoute, nil)
|
chain.AddHandler("example.com.", dnsRouteHandler, PriorityDNSRoute)
|
||||||
chain.AddHandler("example.com.", upstreamHandler, PriorityMatchDomain, nil)
|
chain.AddHandler("example.com.", upstreamHandler, PriorityMatchDomain)
|
||||||
|
|
||||||
testCases := []struct {
|
testCases := []struct {
|
||||||
name string
|
name string
|
||||||
@@ -959,3 +1022,421 @@ func TestHandlerChain_DomainPriorities(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type mockHandler struct {
|
||||||
|
Id string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockHandler) ServeDNS(dns.ResponseWriter, *dns.Msg) {}
|
||||||
|
func (m *mockHandler) stop() {}
|
||||||
|
func (m *mockHandler) probeAvailability() {}
|
||||||
|
func (m *mockHandler) id() handlerID { return handlerID(m.Id) }
|
||||||
|
|
||||||
|
type mockService struct{}
|
||||||
|
|
||||||
|
func (m *mockService) Listen() error { return nil }
|
||||||
|
func (m *mockService) Stop() {}
|
||||||
|
func (m *mockService) RuntimeIP() string { return "127.0.0.1" }
|
||||||
|
func (m *mockService) RuntimePort() int { return 53 }
|
||||||
|
func (m *mockService) RegisterMux(string, dns.Handler) {}
|
||||||
|
func (m *mockService) DeregisterMux(string) {}
|
||||||
|
|
||||||
|
func TestDefaultServer_UpdateMux(t *testing.T) {
|
||||||
|
baseMatchHandlers := registeredHandlerMap{
|
||||||
|
"upstream-group1": {
|
||||||
|
domain: "example.com",
|
||||||
|
handler: &mockHandler{
|
||||||
|
Id: "upstream-group1",
|
||||||
|
},
|
||||||
|
priority: PriorityMatchDomain,
|
||||||
|
},
|
||||||
|
"upstream-group2": {
|
||||||
|
domain: "example.com",
|
||||||
|
handler: &mockHandler{
|
||||||
|
Id: "upstream-group2",
|
||||||
|
},
|
||||||
|
priority: PriorityMatchDomain - 1,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
baseRootHandlers := registeredHandlerMap{
|
||||||
|
"upstream-root1": {
|
||||||
|
domain: ".",
|
||||||
|
handler: &mockHandler{
|
||||||
|
Id: "upstream-root1",
|
||||||
|
},
|
||||||
|
priority: PriorityDefault,
|
||||||
|
},
|
||||||
|
"upstream-root2": {
|
||||||
|
domain: ".",
|
||||||
|
handler: &mockHandler{
|
||||||
|
Id: "upstream-root2",
|
||||||
|
},
|
||||||
|
priority: PriorityDefault - 1,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
baseMixedHandlers := registeredHandlerMap{
|
||||||
|
"upstream-group1": {
|
||||||
|
domain: "example.com",
|
||||||
|
handler: &mockHandler{
|
||||||
|
Id: "upstream-group1",
|
||||||
|
},
|
||||||
|
priority: PriorityMatchDomain,
|
||||||
|
},
|
||||||
|
"upstream-group2": {
|
||||||
|
domain: "example.com",
|
||||||
|
handler: &mockHandler{
|
||||||
|
Id: "upstream-group2",
|
||||||
|
},
|
||||||
|
priority: PriorityMatchDomain - 1,
|
||||||
|
},
|
||||||
|
"upstream-other": {
|
||||||
|
domain: "other.com",
|
||||||
|
handler: &mockHandler{
|
||||||
|
Id: "upstream-other",
|
||||||
|
},
|
||||||
|
priority: PriorityMatchDomain,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
initialHandlers registeredHandlerMap
|
||||||
|
updates []handlerWrapper
|
||||||
|
expectedHandlers map[string]string // map[handlerID]domain
|
||||||
|
description string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Remove group1 from update",
|
||||||
|
initialHandlers: baseMatchHandlers,
|
||||||
|
updates: []handlerWrapper{
|
||||||
|
// Only group2 remains
|
||||||
|
{
|
||||||
|
domain: "example.com",
|
||||||
|
handler: &mockHandler{
|
||||||
|
Id: "upstream-group2",
|
||||||
|
},
|
||||||
|
priority: PriorityMatchDomain - 1,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expectedHandlers: map[string]string{
|
||||||
|
"upstream-group2": "example.com",
|
||||||
|
},
|
||||||
|
description: "When group1 is not included in the update, it should be removed while group2 remains",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Remove group2 from update",
|
||||||
|
initialHandlers: baseMatchHandlers,
|
||||||
|
updates: []handlerWrapper{
|
||||||
|
// Only group1 remains
|
||||||
|
{
|
||||||
|
domain: "example.com",
|
||||||
|
handler: &mockHandler{
|
||||||
|
Id: "upstream-group1",
|
||||||
|
},
|
||||||
|
priority: PriorityMatchDomain,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expectedHandlers: map[string]string{
|
||||||
|
"upstream-group1": "example.com",
|
||||||
|
},
|
||||||
|
description: "When group2 is not included in the update, it should be removed while group1 remains",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Add group3 in first position",
|
||||||
|
initialHandlers: baseMatchHandlers,
|
||||||
|
updates: []handlerWrapper{
|
||||||
|
// Add group3 with highest priority
|
||||||
|
{
|
||||||
|
domain: "example.com",
|
||||||
|
handler: &mockHandler{
|
||||||
|
Id: "upstream-group3",
|
||||||
|
},
|
||||||
|
priority: PriorityMatchDomain + 1,
|
||||||
|
},
|
||||||
|
// Keep existing groups with their original priorities
|
||||||
|
{
|
||||||
|
domain: "example.com",
|
||||||
|
handler: &mockHandler{
|
||||||
|
Id: "upstream-group1",
|
||||||
|
},
|
||||||
|
priority: PriorityMatchDomain,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
domain: "example.com",
|
||||||
|
handler: &mockHandler{
|
||||||
|
Id: "upstream-group2",
|
||||||
|
},
|
||||||
|
priority: PriorityMatchDomain - 1,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expectedHandlers: map[string]string{
|
||||||
|
"upstream-group1": "example.com",
|
||||||
|
"upstream-group2": "example.com",
|
||||||
|
"upstream-group3": "example.com",
|
||||||
|
},
|
||||||
|
description: "When adding group3 with highest priority, it should be first in chain while maintaining existing groups",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Add group3 in last position",
|
||||||
|
initialHandlers: baseMatchHandlers,
|
||||||
|
updates: []handlerWrapper{
|
||||||
|
// Keep existing groups with their original priorities
|
||||||
|
{
|
||||||
|
domain: "example.com",
|
||||||
|
handler: &mockHandler{
|
||||||
|
Id: "upstream-group1",
|
||||||
|
},
|
||||||
|
priority: PriorityMatchDomain,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
domain: "example.com",
|
||||||
|
handler: &mockHandler{
|
||||||
|
Id: "upstream-group2",
|
||||||
|
},
|
||||||
|
priority: PriorityMatchDomain - 1,
|
||||||
|
},
|
||||||
|
// Add group3 with lowest priority
|
||||||
|
{
|
||||||
|
domain: "example.com",
|
||||||
|
handler: &mockHandler{
|
||||||
|
Id: "upstream-group3",
|
||||||
|
},
|
||||||
|
priority: PriorityMatchDomain - 2,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expectedHandlers: map[string]string{
|
||||||
|
"upstream-group1": "example.com",
|
||||||
|
"upstream-group2": "example.com",
|
||||||
|
"upstream-group3": "example.com",
|
||||||
|
},
|
||||||
|
description: "When adding group3 with lowest priority, it should be last in chain while maintaining existing groups",
|
||||||
|
},
|
||||||
|
// Root zone tests
|
||||||
|
{
|
||||||
|
name: "Remove root1 from update",
|
||||||
|
initialHandlers: baseRootHandlers,
|
||||||
|
updates: []handlerWrapper{
|
||||||
|
{
|
||||||
|
domain: ".",
|
||||||
|
handler: &mockHandler{
|
||||||
|
Id: "upstream-root2",
|
||||||
|
},
|
||||||
|
priority: PriorityDefault - 1,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expectedHandlers: map[string]string{
|
||||||
|
"upstream-root2": ".",
|
||||||
|
},
|
||||||
|
description: "When root1 is not included in the update, it should be removed while root2 remains",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Remove root2 from update",
|
||||||
|
initialHandlers: baseRootHandlers,
|
||||||
|
updates: []handlerWrapper{
|
||||||
|
{
|
||||||
|
domain: ".",
|
||||||
|
handler: &mockHandler{
|
||||||
|
Id: "upstream-root1",
|
||||||
|
},
|
||||||
|
priority: PriorityDefault,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expectedHandlers: map[string]string{
|
||||||
|
"upstream-root1": ".",
|
||||||
|
},
|
||||||
|
description: "When root2 is not included in the update, it should be removed while root1 remains",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Add root3 in first position",
|
||||||
|
initialHandlers: baseRootHandlers,
|
||||||
|
updates: []handlerWrapper{
|
||||||
|
{
|
||||||
|
domain: ".",
|
||||||
|
handler: &mockHandler{
|
||||||
|
Id: "upstream-root3",
|
||||||
|
},
|
||||||
|
priority: PriorityDefault + 1,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
domain: ".",
|
||||||
|
handler: &mockHandler{
|
||||||
|
Id: "upstream-root1",
|
||||||
|
},
|
||||||
|
priority: PriorityDefault,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
domain: ".",
|
||||||
|
handler: &mockHandler{
|
||||||
|
Id: "upstream-root2",
|
||||||
|
},
|
||||||
|
priority: PriorityDefault - 1,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expectedHandlers: map[string]string{
|
||||||
|
"upstream-root1": ".",
|
||||||
|
"upstream-root2": ".",
|
||||||
|
"upstream-root3": ".",
|
||||||
|
},
|
||||||
|
description: "When adding root3 with highest priority, it should be first in chain while maintaining existing root handlers",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Add root3 in last position",
|
||||||
|
initialHandlers: baseRootHandlers,
|
||||||
|
updates: []handlerWrapper{
|
||||||
|
{
|
||||||
|
domain: ".",
|
||||||
|
handler: &mockHandler{
|
||||||
|
Id: "upstream-root1",
|
||||||
|
},
|
||||||
|
priority: PriorityDefault,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
domain: ".",
|
||||||
|
handler: &mockHandler{
|
||||||
|
Id: "upstream-root2",
|
||||||
|
},
|
||||||
|
priority: PriorityDefault - 1,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
domain: ".",
|
||||||
|
handler: &mockHandler{
|
||||||
|
Id: "upstream-root3",
|
||||||
|
},
|
||||||
|
priority: PriorityDefault - 2,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expectedHandlers: map[string]string{
|
||||||
|
"upstream-root1": ".",
|
||||||
|
"upstream-root2": ".",
|
||||||
|
"upstream-root3": ".",
|
||||||
|
},
|
||||||
|
description: "When adding root3 with lowest priority, it should be last in chain while maintaining existing root handlers",
|
||||||
|
},
|
||||||
|
// Mixed domain tests
|
||||||
|
{
|
||||||
|
name: "Update with mixed domains - remove one of duplicate domain",
|
||||||
|
initialHandlers: baseMixedHandlers,
|
||||||
|
updates: []handlerWrapper{
|
||||||
|
{
|
||||||
|
domain: "example.com",
|
||||||
|
handler: &mockHandler{
|
||||||
|
Id: "upstream-group1",
|
||||||
|
},
|
||||||
|
priority: PriorityMatchDomain,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
domain: "other.com",
|
||||||
|
handler: &mockHandler{
|
||||||
|
Id: "upstream-other",
|
||||||
|
},
|
||||||
|
priority: PriorityMatchDomain,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expectedHandlers: map[string]string{
|
||||||
|
"upstream-group1": "example.com",
|
||||||
|
"upstream-other": "other.com",
|
||||||
|
},
|
||||||
|
description: "When updating mixed domains, should correctly handle removal of one duplicate while maintaining other domains",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Update with mixed domains - add new domain",
|
||||||
|
initialHandlers: baseMixedHandlers,
|
||||||
|
updates: []handlerWrapper{
|
||||||
|
{
|
||||||
|
domain: "example.com",
|
||||||
|
handler: &mockHandler{
|
||||||
|
Id: "upstream-group1",
|
||||||
|
},
|
||||||
|
priority: PriorityMatchDomain,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
domain: "example.com",
|
||||||
|
handler: &mockHandler{
|
||||||
|
Id: "upstream-group2",
|
||||||
|
},
|
||||||
|
priority: PriorityMatchDomain - 1,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
domain: "other.com",
|
||||||
|
handler: &mockHandler{
|
||||||
|
Id: "upstream-other",
|
||||||
|
},
|
||||||
|
priority: PriorityMatchDomain,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
domain: "new.com",
|
||||||
|
handler: &mockHandler{
|
||||||
|
Id: "upstream-new",
|
||||||
|
},
|
||||||
|
priority: PriorityMatchDomain,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expectedHandlers: map[string]string{
|
||||||
|
"upstream-group1": "example.com",
|
||||||
|
"upstream-group2": "example.com",
|
||||||
|
"upstream-other": "other.com",
|
||||||
|
"upstream-new": "new.com",
|
||||||
|
},
|
||||||
|
description: "When updating mixed domains, should maintain existing duplicates and add new domain",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
server := &DefaultServer{
|
||||||
|
dnsMuxMap: tt.initialHandlers,
|
||||||
|
handlerChain: NewHandlerChain(),
|
||||||
|
service: &mockService{},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Perform the update
|
||||||
|
server.updateMux(tt.updates)
|
||||||
|
|
||||||
|
// Verify the results
|
||||||
|
assert.Equal(t, len(tt.expectedHandlers), len(server.dnsMuxMap),
|
||||||
|
"Number of handlers after update doesn't match expected")
|
||||||
|
|
||||||
|
// Check each expected handler
|
||||||
|
for id, expectedDomain := range tt.expectedHandlers {
|
||||||
|
handler, exists := server.dnsMuxMap[handlerID(id)]
|
||||||
|
assert.True(t, exists, "Expected handler %s not found", id)
|
||||||
|
if exists {
|
||||||
|
assert.Equal(t, expectedDomain, handler.domain,
|
||||||
|
"Domain mismatch for handler %s", id)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify no unexpected handlers exist
|
||||||
|
for handlerID := range server.dnsMuxMap {
|
||||||
|
_, expected := tt.expectedHandlers[string(handlerID)]
|
||||||
|
assert.True(t, expected, "Unexpected handler found: %s", handlerID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify the handlerChain state and order
|
||||||
|
previousPriority := 0
|
||||||
|
for _, chainEntry := range server.handlerChain.handlers {
|
||||||
|
// Verify priority order
|
||||||
|
if previousPriority > 0 {
|
||||||
|
assert.True(t, chainEntry.Priority <= previousPriority,
|
||||||
|
"Handlers in chain not properly ordered by priority")
|
||||||
|
}
|
||||||
|
previousPriority = chainEntry.Priority
|
||||||
|
|
||||||
|
// Verify handler exists in mux
|
||||||
|
foundInMux := false
|
||||||
|
for _, muxEntry := range server.dnsMuxMap {
|
||||||
|
if chainEntry.Handler == muxEntry.handler &&
|
||||||
|
chainEntry.Priority == muxEntry.priority &&
|
||||||
|
chainEntry.Pattern == dns.Fqdn(muxEntry.domain) {
|
||||||
|
foundInMux = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
assert.True(t, foundInMux,
|
||||||
|
"Handler in chain not found in dnsMuxMap")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ import (
|
|||||||
|
|
||||||
type ShutdownState struct {
|
type ShutdownState struct {
|
||||||
Guid string
|
Guid string
|
||||||
|
GPO bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *ShutdownState) Name() string {
|
func (s *ShutdownState) Name() string {
|
||||||
@@ -13,9 +14,9 @@ func (s *ShutdownState) Name() string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *ShutdownState) Cleanup() error {
|
func (s *ShutdownState) Cleanup() error {
|
||||||
manager, err := newHostManagerWithGuid(s.Guid)
|
manager := ®istryConfigurator{
|
||||||
if err != nil {
|
guid: s.Guid,
|
||||||
return fmt.Errorf("create host manager: %w", err)
|
gpo: s.GPO,
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := manager.restoreUncleanShutdownDNS(); err != nil {
|
if err := manager.restoreUncleanShutdownDNS(); err != nil {
|
||||||
|
|||||||
@@ -2,9 +2,13 @@ package dns
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"crypto/sha256"
|
||||||
|
"encoding/hex"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
|
"slices"
|
||||||
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
@@ -40,6 +44,7 @@ type upstreamResolverBase struct {
|
|||||||
cancel context.CancelFunc
|
cancel context.CancelFunc
|
||||||
upstreamClient upstreamClient
|
upstreamClient upstreamClient
|
||||||
upstreamServers []string
|
upstreamServers []string
|
||||||
|
domain string
|
||||||
disabled bool
|
disabled bool
|
||||||
failsCount atomic.Int32
|
failsCount atomic.Int32
|
||||||
successCount atomic.Int32
|
successCount atomic.Int32
|
||||||
@@ -53,12 +58,13 @@ type upstreamResolverBase struct {
|
|||||||
statusRecorder *peer.Status
|
statusRecorder *peer.Status
|
||||||
}
|
}
|
||||||
|
|
||||||
func newUpstreamResolverBase(ctx context.Context, statusRecorder *peer.Status) *upstreamResolverBase {
|
func newUpstreamResolverBase(ctx context.Context, statusRecorder *peer.Status, domain string) *upstreamResolverBase {
|
||||||
ctx, cancel := context.WithCancel(ctx)
|
ctx, cancel := context.WithCancel(ctx)
|
||||||
|
|
||||||
return &upstreamResolverBase{
|
return &upstreamResolverBase{
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
cancel: cancel,
|
cancel: cancel,
|
||||||
|
domain: domain,
|
||||||
upstreamTimeout: upstreamTimeout,
|
upstreamTimeout: upstreamTimeout,
|
||||||
reactivatePeriod: reactivatePeriod,
|
reactivatePeriod: reactivatePeriod,
|
||||||
failsTillDeact: failsTillDeact,
|
failsTillDeact: failsTillDeact,
|
||||||
@@ -71,6 +77,17 @@ func (u *upstreamResolverBase) String() string {
|
|||||||
return fmt.Sprintf("upstream %v", u.upstreamServers)
|
return fmt.Sprintf("upstream %v", u.upstreamServers)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ID returns the unique handler ID
|
||||||
|
func (u *upstreamResolverBase) id() handlerID {
|
||||||
|
servers := slices.Clone(u.upstreamServers)
|
||||||
|
slices.Sort(servers)
|
||||||
|
|
||||||
|
hash := sha256.New()
|
||||||
|
hash.Write([]byte(u.domain + ":"))
|
||||||
|
hash.Write([]byte(strings.Join(servers, ",")))
|
||||||
|
return handlerID("upstream-" + hex.EncodeToString(hash.Sum(nil)[:8]))
|
||||||
|
}
|
||||||
|
|
||||||
func (u *upstreamResolverBase) MatchSubdomains() bool {
|
func (u *upstreamResolverBase) MatchSubdomains() bool {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
@@ -87,7 +104,7 @@ func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
|||||||
u.checkUpstreamFails(err)
|
u.checkUpstreamFails(err)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
log.WithField("question", r.Question[0]).Trace("received an upstream question")
|
log.Tracef("received upstream question: domain=%s type=%v class=%v", r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass)
|
||||||
// set the AuthenticatedData flag and the EDNS0 buffer size to 4096 bytes to support larger dns records
|
// set the AuthenticatedData flag and the EDNS0 buffer size to 4096 bytes to support larger dns records
|
||||||
if r.Extra == nil {
|
if r.Extra == nil {
|
||||||
r.SetEdns0(4096, false)
|
r.SetEdns0(4096, false)
|
||||||
@@ -96,6 +113,7 @@ func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
|||||||
|
|
||||||
select {
|
select {
|
||||||
case <-u.ctx.Done():
|
case <-u.ctx.Done():
|
||||||
|
log.Tracef("%s has been stopped", u)
|
||||||
return
|
return
|
||||||
default:
|
default:
|
||||||
}
|
}
|
||||||
@@ -112,41 +130,36 @@ func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
|||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, context.DeadlineExceeded) || isTimeout(err) {
|
if errors.Is(err, context.DeadlineExceeded) || isTimeout(err) {
|
||||||
log.WithError(err).WithField("upstream", upstream).
|
log.Warnf("upstream %s timed out for question domain=%s", upstream, r.Question[0].Name)
|
||||||
Warn("got an error while connecting to upstream")
|
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
u.failsCount.Add(1)
|
log.Warnf("failed to query upstream %s for question domain=%s: %s", upstream, r.Question[0].Name, err)
|
||||||
log.WithError(err).WithField("upstream", upstream).
|
continue
|
||||||
Error("got other error while querying the upstream")
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if rm == nil {
|
if rm == nil || !rm.Response {
|
||||||
log.WithError(err).WithField("upstream", upstream).
|
log.Warnf("no response from upstream %s for question domain=%s", upstream, r.Question[0].Name)
|
||||||
Warn("no response from upstream")
|
continue
|
||||||
return
|
|
||||||
}
|
|
||||||
// those checks need to be independent of each other due to memory address issues
|
|
||||||
if !rm.Response {
|
|
||||||
log.WithError(err).WithField("upstream", upstream).
|
|
||||||
Warn("no response from upstream")
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
u.successCount.Add(1)
|
u.successCount.Add(1)
|
||||||
log.Tracef("took %s to query the upstream %s", t, upstream)
|
log.Tracef("took %s to query the upstream %s for question domain=%s", t, upstream, r.Question[0].Name)
|
||||||
|
|
||||||
err = w.WriteMsg(rm)
|
if err = w.WriteMsg(rm); err != nil {
|
||||||
if err != nil {
|
log.Errorf("failed to write DNS response for question domain=%s: %s", r.Question[0].Name, err)
|
||||||
log.WithError(err).Error("got an error while writing the upstream resolver response")
|
|
||||||
}
|
}
|
||||||
// count the fails only if they happen sequentially
|
// count the fails only if they happen sequentially
|
||||||
u.failsCount.Store(0)
|
u.failsCount.Store(0)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
u.failsCount.Add(1)
|
u.failsCount.Add(1)
|
||||||
log.Error("all queries to the upstream nameservers failed with timeout")
|
log.Errorf("all queries to the %s failed for question domain=%s", u, r.Question[0].Name)
|
||||||
|
|
||||||
|
m := new(dns.Msg)
|
||||||
|
m.SetRcode(r, dns.RcodeServerFailure)
|
||||||
|
if err := w.WriteMsg(m); err != nil {
|
||||||
|
log.Errorf("failed to write error response for %s for question domain=%s: %s", u, r.Question[0].Name, err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// checkUpstreamFails counts fails and disables or enables upstream resolving
|
// checkUpstreamFails counts fails and disables or enables upstream resolving
|
||||||
|
|||||||
@@ -27,8 +27,9 @@ func newUpstreamResolver(
|
|||||||
_ *net.IPNet,
|
_ *net.IPNet,
|
||||||
statusRecorder *peer.Status,
|
statusRecorder *peer.Status,
|
||||||
hostsDNSHolder *hostsDNSHolder,
|
hostsDNSHolder *hostsDNSHolder,
|
||||||
|
domain string,
|
||||||
) (*upstreamResolver, error) {
|
) (*upstreamResolver, error) {
|
||||||
upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder)
|
upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder, domain)
|
||||||
c := &upstreamResolver{
|
c := &upstreamResolver{
|
||||||
upstreamResolverBase: upstreamResolverBase,
|
upstreamResolverBase: upstreamResolverBase,
|
||||||
hostsDNSHolder: hostsDNSHolder,
|
hostsDNSHolder: hostsDNSHolder,
|
||||||
|
|||||||
@@ -23,8 +23,9 @@ func newUpstreamResolver(
|
|||||||
_ *net.IPNet,
|
_ *net.IPNet,
|
||||||
statusRecorder *peer.Status,
|
statusRecorder *peer.Status,
|
||||||
_ *hostsDNSHolder,
|
_ *hostsDNSHolder,
|
||||||
|
domain string,
|
||||||
) (*upstreamResolver, error) {
|
) (*upstreamResolver, error) {
|
||||||
upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder)
|
upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder, domain)
|
||||||
nonIOS := &upstreamResolver{
|
nonIOS := &upstreamResolver{
|
||||||
upstreamResolverBase: upstreamResolverBase,
|
upstreamResolverBase: upstreamResolverBase,
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -30,8 +30,9 @@ func newUpstreamResolver(
|
|||||||
net *net.IPNet,
|
net *net.IPNet,
|
||||||
statusRecorder *peer.Status,
|
statusRecorder *peer.Status,
|
||||||
_ *hostsDNSHolder,
|
_ *hostsDNSHolder,
|
||||||
|
domain string,
|
||||||
) (*upstreamResolverIOS, error) {
|
) (*upstreamResolverIOS, error) {
|
||||||
upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder)
|
upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder, domain)
|
||||||
|
|
||||||
ios := &upstreamResolverIOS{
|
ios := &upstreamResolverIOS{
|
||||||
upstreamResolverBase: upstreamResolverBase,
|
upstreamResolverBase: upstreamResolverBase,
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ func TestUpstreamResolver_ServeDNS(t *testing.T) {
|
|||||||
timeout time.Duration
|
timeout time.Duration
|
||||||
cancelCTX bool
|
cancelCTX bool
|
||||||
expectedAnswer string
|
expectedAnswer string
|
||||||
|
acceptNXDomain bool
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "Should Resolve A Record",
|
name: "Should Resolve A Record",
|
||||||
@@ -36,11 +37,11 @@ func TestUpstreamResolver_ServeDNS(t *testing.T) {
|
|||||||
expectedAnswer: "1.1.1.1",
|
expectedAnswer: "1.1.1.1",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Should Not Resolve If Can't Connect To Both Servers",
|
name: "Should Not Resolve If Can't Connect To Both Servers",
|
||||||
inputMSG: new(dns.Msg).SetQuestion("one.one.one.one.", dns.TypeA),
|
inputMSG: new(dns.Msg).SetQuestion("one.one.one.one.", dns.TypeA),
|
||||||
InputServers: []string{"8.0.0.0:53", "8.0.0.1:53"},
|
InputServers: []string{"8.0.0.0:53", "8.0.0.1:53"},
|
||||||
timeout: 200 * time.Millisecond,
|
timeout: 200 * time.Millisecond,
|
||||||
responseShouldBeNil: true,
|
acceptNXDomain: true,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Should Not Resolve If Parent Context Is Canceled",
|
name: "Should Not Resolve If Parent Context Is Canceled",
|
||||||
@@ -51,14 +52,11 @@ func TestUpstreamResolver_ServeDNS(t *testing.T) {
|
|||||||
responseShouldBeNil: true,
|
responseShouldBeNil: true,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
// should resolve if first upstream times out
|
|
||||||
// should not write when both fails
|
|
||||||
// should not resolve if parent context is canceled
|
|
||||||
|
|
||||||
for _, testCase := range testCases {
|
for _, testCase := range testCases {
|
||||||
t.Run(testCase.name, func(t *testing.T) {
|
t.Run(testCase.name, func(t *testing.T) {
|
||||||
ctx, cancel := context.WithCancel(context.TODO())
|
ctx, cancel := context.WithCancel(context.TODO())
|
||||||
resolver, _ := newUpstreamResolver(ctx, "", net.IP{}, &net.IPNet{}, nil, nil)
|
resolver, _ := newUpstreamResolver(ctx, "", net.IP{}, &net.IPNet{}, nil, nil, ".")
|
||||||
resolver.upstreamServers = testCase.InputServers
|
resolver.upstreamServers = testCase.InputServers
|
||||||
resolver.upstreamTimeout = testCase.timeout
|
resolver.upstreamTimeout = testCase.timeout
|
||||||
if testCase.cancelCTX {
|
if testCase.cancelCTX {
|
||||||
@@ -84,16 +82,22 @@ func TestUpstreamResolver_ServeDNS(t *testing.T) {
|
|||||||
t.Fatalf("should write a response message")
|
t.Fatalf("should write a response message")
|
||||||
}
|
}
|
||||||
|
|
||||||
foundAnswer := false
|
if testCase.acceptNXDomain && responseMSG.Rcode == dns.RcodeNameError {
|
||||||
for _, answer := range responseMSG.Answer {
|
return
|
||||||
if strings.Contains(answer.String(), testCase.expectedAnswer) {
|
|
||||||
foundAnswer = true
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if !foundAnswer {
|
if testCase.expectedAnswer != "" {
|
||||||
t.Errorf("couldn't find the required answer, %s, in the dns response", testCase.expectedAnswer)
|
foundAnswer := false
|
||||||
|
for _, answer := range responseMSG.Answer {
|
||||||
|
if strings.Contains(answer.String(), testCase.expectedAnswer) {
|
||||||
|
foundAnswer = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !foundAnswer {
|
||||||
|
t.Errorf("couldn't find the required answer, %s, in the dns response", testCase.expectedAnswer)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -42,13 +42,13 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/internal/routemanager"
|
"github.com/netbirdio/netbird/client/internal/routemanager"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
|
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
|
||||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
|
"github.com/netbirdio/netbird/management/domain"
|
||||||
semaphoregroup "github.com/netbirdio/netbird/util/semaphore-group"
|
semaphoregroup "github.com/netbirdio/netbird/util/semaphore-group"
|
||||||
|
|
||||||
nbssh "github.com/netbirdio/netbird/client/ssh"
|
nbssh "github.com/netbirdio/netbird/client/ssh"
|
||||||
"github.com/netbirdio/netbird/client/system"
|
"github.com/netbirdio/netbird/client/system"
|
||||||
nbdns "github.com/netbirdio/netbird/dns"
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
mgm "github.com/netbirdio/netbird/management/client"
|
mgm "github.com/netbirdio/netbird/management/client"
|
||||||
"github.com/netbirdio/netbird/management/domain"
|
|
||||||
mgmProto "github.com/netbirdio/netbird/management/proto"
|
mgmProto "github.com/netbirdio/netbird/management/proto"
|
||||||
auth "github.com/netbirdio/netbird/relay/auth/hmac"
|
auth "github.com/netbirdio/netbird/relay/auth/hmac"
|
||||||
relayClient "github.com/netbirdio/netbird/relay/client"
|
relayClient "github.com/netbirdio/netbird/relay/client"
|
||||||
@@ -193,6 +193,10 @@ type Peer struct {
|
|||||||
WgAllowedIps string
|
WgAllowedIps string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type localIpUpdater interface {
|
||||||
|
UpdateLocalIPs() error
|
||||||
|
}
|
||||||
|
|
||||||
// NewEngine creates a new Connection Engine with probes attached
|
// NewEngine creates a new Connection Engine with probes attached
|
||||||
func NewEngine(
|
func NewEngine(
|
||||||
clientCtx context.Context,
|
clientCtx context.Context,
|
||||||
@@ -433,7 +437,7 @@ func (e *Engine) createFirewall() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
var err error
|
var err error
|
||||||
e.firewall, err = firewall.NewFirewall(e.wgInterface, e.stateManager)
|
e.firewall, err = firewall.NewFirewall(e.wgInterface, e.stateManager, e.config.DisableServerRoutes)
|
||||||
if err != nil || e.firewall == nil {
|
if err != nil || e.firewall == nil {
|
||||||
log.Errorf("failed creating firewall manager: %s", err)
|
log.Errorf("failed creating firewall manager: %s", err)
|
||||||
return nil
|
return nil
|
||||||
@@ -883,6 +887,14 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
|
|||||||
e.acl.ApplyFiltering(networkMap)
|
e.acl.ApplyFiltering(networkMap)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if e.firewall != nil {
|
||||||
|
if localipfw, ok := e.firewall.(localIpUpdater); ok {
|
||||||
|
if err := localipfw.UpdateLocalIPs(); err != nil {
|
||||||
|
log.Errorf("failed to update local IPs: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// DNS forwarder
|
// DNS forwarder
|
||||||
dnsRouteFeatureFlag := toDNSFeatureFlag(networkMap)
|
dnsRouteFeatureFlag := toDNSFeatureFlag(networkMap)
|
||||||
dnsRouteDomains := toRouteDomains(e.config.WgPrivateKey.PublicKey().String(), networkMap.GetRoutes())
|
dnsRouteDomains := toRouteDomains(e.config.WgPrivateKey.PublicKey().String(), networkMap.GetRoutes())
|
||||||
@@ -1446,6 +1458,11 @@ func (e *Engine) GetRouteManager() routemanager.Manager {
|
|||||||
return e.routeManager
|
return e.routeManager
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetFirewallManager returns the firewall manager
|
||||||
|
func (e *Engine) GetFirewallManager() manager.Manager {
|
||||||
|
return e.firewall
|
||||||
|
}
|
||||||
|
|
||||||
func findIPFromInterfaceName(ifaceName string) (net.IP, error) {
|
func findIPFromInterfaceName(ifaceName string) (net.IP, error) {
|
||||||
iface, err := net.InterfaceByName(ifaceName)
|
iface, err := net.InterfaceByName(ifaceName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -1657,6 +1674,14 @@ func (e *Engine) GetLatestNetworkMap() (*mgmProto.NetworkMap, error) {
|
|||||||
return nm, nil
|
return nm, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetWgAddr returns the wireguard address
|
||||||
|
func (e *Engine) GetWgAddr() net.IP {
|
||||||
|
if e.wgInterface == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return e.wgInterface.Address().IP
|
||||||
|
}
|
||||||
|
|
||||||
// updateDNSForwarder start or stop the DNS forwarder based on the domains and the feature flag
|
// updateDNSForwarder start or stop the DNS forwarder based on the domains and the feature flag
|
||||||
func (e *Engine) updateDNSForwarder(enabled bool, domains []string) {
|
func (e *Engine) updateDNSForwarder(enabled bool, domains []string) {
|
||||||
if !enabled {
|
if !enabled {
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package peer
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"fmt"
|
||||||
"math/rand"
|
"math/rand"
|
||||||
"net"
|
"net"
|
||||||
"os"
|
"os"
|
||||||
@@ -28,12 +29,28 @@ import (
|
|||||||
|
|
||||||
type ConnPriority int
|
type ConnPriority int
|
||||||
|
|
||||||
|
func (cp ConnPriority) String() string {
|
||||||
|
switch cp {
|
||||||
|
case connPriorityNone:
|
||||||
|
return "None"
|
||||||
|
case connPriorityRelay:
|
||||||
|
return "PriorityRelay"
|
||||||
|
case connPriorityICETurn:
|
||||||
|
return "PriorityICETurn"
|
||||||
|
case connPriorityICEP2P:
|
||||||
|
return "PriorityICEP2P"
|
||||||
|
default:
|
||||||
|
return fmt.Sprintf("ConnPriority(%d)", cp)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
const (
|
const (
|
||||||
defaultWgKeepAlive = 25 * time.Second
|
defaultWgKeepAlive = 25 * time.Second
|
||||||
|
|
||||||
|
connPriorityNone ConnPriority = 0
|
||||||
connPriorityRelay ConnPriority = 1
|
connPriorityRelay ConnPriority = 1
|
||||||
connPriorityICETurn ConnPriority = 1
|
connPriorityICETurn ConnPriority = 2
|
||||||
connPriorityICEP2P ConnPriority = 2
|
connPriorityICEP2P ConnPriority = 3
|
||||||
)
|
)
|
||||||
|
|
||||||
type WgConfig struct {
|
type WgConfig struct {
|
||||||
@@ -66,14 +83,6 @@ type ConnConfig struct {
|
|||||||
ICEConfig icemaker.Config
|
ICEConfig icemaker.Config
|
||||||
}
|
}
|
||||||
|
|
||||||
type WorkerCallbacks struct {
|
|
||||||
OnRelayReadyCallback func(info RelayConnInfo)
|
|
||||||
OnRelayStatusChanged func(ConnStatus)
|
|
||||||
|
|
||||||
OnICEConnReadyCallback func(ConnPriority, ICEConnInfo)
|
|
||||||
OnICEStatusChanged func(ConnStatus)
|
|
||||||
}
|
|
||||||
|
|
||||||
type Conn struct {
|
type Conn struct {
|
||||||
log *log.Entry
|
log *log.Entry
|
||||||
mu sync.Mutex
|
mu sync.Mutex
|
||||||
@@ -135,21 +144,11 @@ func NewConn(engineCtx context.Context, config ConnConfig, statusRecorder *Statu
|
|||||||
semaphore: semaphore,
|
semaphore: semaphore,
|
||||||
}
|
}
|
||||||
|
|
||||||
rFns := WorkerRelayCallbacks{
|
|
||||||
OnConnReady: conn.relayConnectionIsReady,
|
|
||||||
OnDisconnected: conn.onWorkerRelayStateDisconnected,
|
|
||||||
}
|
|
||||||
|
|
||||||
wFns := WorkerICECallbacks{
|
|
||||||
OnConnReady: conn.iCEConnectionIsReady,
|
|
||||||
OnStatusChanged: conn.onWorkerICEStateDisconnected,
|
|
||||||
}
|
|
||||||
|
|
||||||
ctrl := isController(config)
|
ctrl := isController(config)
|
||||||
conn.workerRelay = NewWorkerRelay(connLog, ctrl, config, relayManager, rFns)
|
conn.workerRelay = NewWorkerRelay(connLog, ctrl, config, conn, relayManager)
|
||||||
|
|
||||||
relayIsSupportedLocally := conn.workerRelay.RelayIsSupportedLocally()
|
relayIsSupportedLocally := conn.workerRelay.RelayIsSupportedLocally()
|
||||||
conn.workerICE, err = NewWorkerICE(ctx, connLog, config, signaler, iFaceDiscover, statusRecorder, relayIsSupportedLocally, wFns)
|
conn.workerICE, err = NewWorkerICE(ctx, connLog, config, conn, signaler, iFaceDiscover, statusRecorder, relayIsSupportedLocally)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -304,7 +303,7 @@ func (conn *Conn) GetKey() string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// configureConnection starts proxying traffic from/to local Wireguard and sets connection status to StatusConnected
|
// configureConnection starts proxying traffic from/to local Wireguard and sets connection status to StatusConnected
|
||||||
func (conn *Conn) iCEConnectionIsReady(priority ConnPriority, iceConnInfo ICEConnInfo) {
|
func (conn *Conn) onICEConnectionIsReady(priority ConnPriority, iceConnInfo ICEConnInfo) {
|
||||||
conn.mu.Lock()
|
conn.mu.Lock()
|
||||||
defer conn.mu.Unlock()
|
defer conn.mu.Unlock()
|
||||||
|
|
||||||
@@ -317,9 +316,10 @@ func (conn *Conn) iCEConnectionIsReady(priority ConnPriority, iceConnInfo ICECon
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
conn.log.Debugf("ICE connection is ready")
|
// this never should happen, because Relay is the lower priority and ICE always close the deprecated connection before upgrade
|
||||||
|
// todo consider to remove this check
|
||||||
if conn.currentConnPriority > priority {
|
if conn.currentConnPriority > priority {
|
||||||
|
conn.log.Infof("current connection priority (%s) is higher than the new one (%s), do not upgrade connection", conn.currentConnPriority, priority)
|
||||||
conn.statusICE.Set(StatusConnected)
|
conn.statusICE.Set(StatusConnected)
|
||||||
conn.updateIceState(iceConnInfo)
|
conn.updateIceState(iceConnInfo)
|
||||||
return
|
return
|
||||||
@@ -375,8 +375,7 @@ func (conn *Conn) iCEConnectionIsReady(priority ConnPriority, iceConnInfo ICECon
|
|||||||
conn.doOnConnected(iceConnInfo.RosenpassPubKey, iceConnInfo.RosenpassAddr)
|
conn.doOnConnected(iceConnInfo.RosenpassPubKey, iceConnInfo.RosenpassAddr)
|
||||||
}
|
}
|
||||||
|
|
||||||
// todo review to make sense to handle connecting and disconnected status also?
|
func (conn *Conn) onICEStateDisconnected() {
|
||||||
func (conn *Conn) onWorkerICEStateDisconnected(newState ConnStatus) {
|
|
||||||
conn.mu.Lock()
|
conn.mu.Lock()
|
||||||
defer conn.mu.Unlock()
|
defer conn.mu.Unlock()
|
||||||
|
|
||||||
@@ -384,7 +383,7 @@ func (conn *Conn) onWorkerICEStateDisconnected(newState ConnStatus) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
conn.log.Tracef("ICE connection state changed to %s", newState)
|
conn.log.Tracef("ICE connection state changed to disconnected")
|
||||||
|
|
||||||
if conn.wgProxyICE != nil {
|
if conn.wgProxyICE != nil {
|
||||||
if err := conn.wgProxyICE.CloseConn(); err != nil {
|
if err := conn.wgProxyICE.CloseConn(); err != nil {
|
||||||
@@ -394,7 +393,7 @@ func (conn *Conn) onWorkerICEStateDisconnected(newState ConnStatus) {
|
|||||||
|
|
||||||
// switch back to relay connection
|
// switch back to relay connection
|
||||||
if conn.isReadyToUpgrade() {
|
if conn.isReadyToUpgrade() {
|
||||||
conn.log.Debugf("ICE disconnected, set Relay to active connection")
|
conn.log.Infof("ICE disconnected, set Relay to active connection")
|
||||||
conn.wgProxyRelay.Work()
|
conn.wgProxyRelay.Work()
|
||||||
|
|
||||||
if err := conn.configureWGEndpoint(conn.wgProxyRelay.EndpointAddr()); err != nil {
|
if err := conn.configureWGEndpoint(conn.wgProxyRelay.EndpointAddr()); err != nil {
|
||||||
@@ -402,12 +401,16 @@ func (conn *Conn) onWorkerICEStateDisconnected(newState ConnStatus) {
|
|||||||
}
|
}
|
||||||
conn.workerRelay.EnableWgWatcher(conn.ctx)
|
conn.workerRelay.EnableWgWatcher(conn.ctx)
|
||||||
conn.currentConnPriority = connPriorityRelay
|
conn.currentConnPriority = connPriorityRelay
|
||||||
|
} else {
|
||||||
|
conn.log.Infof("ICE disconnected, do not switch to Relay. Reset priority to: %s", connPriorityNone.String())
|
||||||
|
conn.currentConnPriority = connPriorityNone
|
||||||
}
|
}
|
||||||
|
|
||||||
changed := conn.statusICE.Get() != newState && newState != StatusConnecting
|
changed := conn.statusICE.Get() != StatusDisconnected
|
||||||
conn.statusICE.Set(newState)
|
if changed {
|
||||||
|
conn.guard.SetICEConnDisconnected()
|
||||||
conn.guard.SetICEConnDisconnected(changed)
|
}
|
||||||
|
conn.statusICE.Set(StatusDisconnected)
|
||||||
|
|
||||||
peerState := State{
|
peerState := State{
|
||||||
PubKey: conn.config.Key,
|
PubKey: conn.config.Key,
|
||||||
@@ -422,7 +425,7 @@ func (conn *Conn) onWorkerICEStateDisconnected(newState ConnStatus) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (conn *Conn) relayConnectionIsReady(rci RelayConnInfo) {
|
func (conn *Conn) onRelayConnectionIsReady(rci RelayConnInfo) {
|
||||||
conn.mu.Lock()
|
conn.mu.Lock()
|
||||||
defer conn.mu.Unlock()
|
defer conn.mu.Unlock()
|
||||||
|
|
||||||
@@ -444,7 +447,7 @@ func (conn *Conn) relayConnectionIsReady(rci RelayConnInfo) {
|
|||||||
conn.log.Infof("created new wgProxy for relay connection: %s", wgProxy.EndpointAddr().String())
|
conn.log.Infof("created new wgProxy for relay connection: %s", wgProxy.EndpointAddr().String())
|
||||||
|
|
||||||
if conn.iceP2PIsActive() {
|
if conn.iceP2PIsActive() {
|
||||||
conn.log.Debugf("do not switch to relay because current priority is: %v", conn.currentConnPriority)
|
conn.log.Debugf("do not switch to relay because current priority is: %s", conn.currentConnPriority.String())
|
||||||
conn.setRelayedProxy(wgProxy)
|
conn.setRelayedProxy(wgProxy)
|
||||||
conn.statusRelay.Set(StatusConnected)
|
conn.statusRelay.Set(StatusConnected)
|
||||||
conn.updateRelayStatus(rci.relayedConn.RemoteAddr().String(), rci.rosenpassPubKey)
|
conn.updateRelayStatus(rci.relayedConn.RemoteAddr().String(), rci.rosenpassPubKey)
|
||||||
@@ -474,7 +477,7 @@ func (conn *Conn) relayConnectionIsReady(rci RelayConnInfo) {
|
|||||||
conn.doOnConnected(rci.rosenpassPubKey, rci.rosenpassAddr)
|
conn.doOnConnected(rci.rosenpassPubKey, rci.rosenpassAddr)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (conn *Conn) onWorkerRelayStateDisconnected() {
|
func (conn *Conn) onRelayDisconnected() {
|
||||||
conn.mu.Lock()
|
conn.mu.Lock()
|
||||||
defer conn.mu.Unlock()
|
defer conn.mu.Unlock()
|
||||||
|
|
||||||
@@ -497,8 +500,10 @@ func (conn *Conn) onWorkerRelayStateDisconnected() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
changed := conn.statusRelay.Get() != StatusDisconnected
|
changed := conn.statusRelay.Get() != StatusDisconnected
|
||||||
|
if changed {
|
||||||
|
conn.guard.SetRelayedConnDisconnected()
|
||||||
|
}
|
||||||
conn.statusRelay.Set(StatusDisconnected)
|
conn.statusRelay.Set(StatusDisconnected)
|
||||||
conn.guard.SetRelayedConnDisconnected(changed)
|
|
||||||
|
|
||||||
peerState := State{
|
peerState := State{
|
||||||
PubKey: conn.config.Key,
|
PubKey: conn.config.Key,
|
||||||
|
|||||||
@@ -29,8 +29,8 @@ type Guard struct {
|
|||||||
isConnectedOnAllWay isConnectedFunc
|
isConnectedOnAllWay isConnectedFunc
|
||||||
timeout time.Duration
|
timeout time.Duration
|
||||||
srWatcher *SRWatcher
|
srWatcher *SRWatcher
|
||||||
relayedConnDisconnected chan bool
|
relayedConnDisconnected chan struct{}
|
||||||
iCEConnDisconnected chan bool
|
iCEConnDisconnected chan struct{}
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewGuard(log *log.Entry, isController bool, isConnectedFn isConnectedFunc, timeout time.Duration, srWatcher *SRWatcher) *Guard {
|
func NewGuard(log *log.Entry, isController bool, isConnectedFn isConnectedFunc, timeout time.Duration, srWatcher *SRWatcher) *Guard {
|
||||||
@@ -41,8 +41,8 @@ func NewGuard(log *log.Entry, isController bool, isConnectedFn isConnectedFunc,
|
|||||||
isConnectedOnAllWay: isConnectedFn,
|
isConnectedOnAllWay: isConnectedFn,
|
||||||
timeout: timeout,
|
timeout: timeout,
|
||||||
srWatcher: srWatcher,
|
srWatcher: srWatcher,
|
||||||
relayedConnDisconnected: make(chan bool, 1),
|
relayedConnDisconnected: make(chan struct{}, 1),
|
||||||
iCEConnDisconnected: make(chan bool, 1),
|
iCEConnDisconnected: make(chan struct{}, 1),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -54,16 +54,16 @@ func (g *Guard) Start(ctx context.Context) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (g *Guard) SetRelayedConnDisconnected(changed bool) {
|
func (g *Guard) SetRelayedConnDisconnected() {
|
||||||
select {
|
select {
|
||||||
case g.relayedConnDisconnected <- changed:
|
case g.relayedConnDisconnected <- struct{}{}:
|
||||||
default:
|
default:
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (g *Guard) SetICEConnDisconnected(changed bool) {
|
func (g *Guard) SetICEConnDisconnected() {
|
||||||
select {
|
select {
|
||||||
case g.iCEConnDisconnected <- changed:
|
case g.iCEConnDisconnected <- struct{}{}:
|
||||||
default:
|
default:
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -96,19 +96,13 @@ func (g *Guard) reconnectLoopWithRetry(ctx context.Context) {
|
|||||||
g.triggerOfferSending()
|
g.triggerOfferSending()
|
||||||
}
|
}
|
||||||
|
|
||||||
case changed := <-g.relayedConnDisconnected:
|
case <-g.relayedConnDisconnected:
|
||||||
if !changed {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
g.log.Debugf("Relay connection changed, reset reconnection ticker")
|
g.log.Debugf("Relay connection changed, reset reconnection ticker")
|
||||||
ticker.Stop()
|
ticker.Stop()
|
||||||
ticker = g.prepareExponentTicker(ctx)
|
ticker = g.prepareExponentTicker(ctx)
|
||||||
tickerChannel = ticker.C
|
tickerChannel = ticker.C
|
||||||
|
|
||||||
case changed := <-g.iCEConnDisconnected:
|
case <-g.iCEConnDisconnected:
|
||||||
if !changed {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
g.log.Debugf("ICE connection changed, reset reconnection ticker")
|
g.log.Debugf("ICE connection changed, reset reconnection ticker")
|
||||||
ticker.Stop()
|
ticker.Stop()
|
||||||
ticker = g.prepareExponentTicker(ctx)
|
ticker = g.prepareExponentTicker(ctx)
|
||||||
@@ -138,16 +132,10 @@ func (g *Guard) listenForDisconnectEvents(ctx context.Context) {
|
|||||||
g.log.Infof("start listen for reconnect events...")
|
g.log.Infof("start listen for reconnect events...")
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case changed := <-g.relayedConnDisconnected:
|
case <-g.relayedConnDisconnected:
|
||||||
if !changed {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
g.log.Debugf("Relay connection changed, triggering reconnect")
|
g.log.Debugf("Relay connection changed, triggering reconnect")
|
||||||
g.triggerOfferSending()
|
g.triggerOfferSending()
|
||||||
case changed := <-g.iCEConnDisconnected:
|
case <-g.iCEConnDisconnected:
|
||||||
if !changed {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
g.log.Debugf("ICE state changed, try to send new offer")
|
g.log.Debugf("ICE state changed, try to send new offer")
|
||||||
g.triggerOfferSending()
|
g.triggerOfferSending()
|
||||||
case <-srReconnectedChan:
|
case <-srReconnectedChan:
|
||||||
|
|||||||
@@ -721,7 +721,9 @@ func (d *Status) GetRelayStates() []relay.ProbeResult {
|
|||||||
func (d *Status) GetDNSStates() []NSGroupState {
|
func (d *Status) GetDNSStates() []NSGroupState {
|
||||||
d.mux.Lock()
|
d.mux.Lock()
|
||||||
defer d.mux.Unlock()
|
defer d.mux.Unlock()
|
||||||
return d.nsGroupStates
|
|
||||||
|
// shallow copy is good enough, as slices fields are currently not updated
|
||||||
|
return slices.Clone(d.nsGroupStates)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Status) GetResolvedDomainsStates() map[domain.Domain]ResolvedDomainInfo {
|
func (d *Status) GetResolvedDomainsStates() map[domain.Domain]ResolvedDomainInfo {
|
||||||
|
|||||||
154
client/internal/peer/wg_watcher.go
Normal file
154
client/internal/peer/wg_watcher.go
Normal file
@@ -0,0 +1,154 @@
|
|||||||
|
package peer
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
wgHandshakePeriod = 3 * time.Minute
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
wgHandshakeOvertime = 30 * time.Second // allowed delay in network
|
||||||
|
checkPeriod = wgHandshakePeriod + wgHandshakeOvertime
|
||||||
|
)
|
||||||
|
|
||||||
|
type WGInterfaceStater interface {
|
||||||
|
GetStats(key string) (configurer.WGStats, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
type WGWatcher struct {
|
||||||
|
log *log.Entry
|
||||||
|
wgIfaceStater WGInterfaceStater
|
||||||
|
peerKey string
|
||||||
|
|
||||||
|
ctx context.Context
|
||||||
|
ctxCancel context.CancelFunc
|
||||||
|
ctxLock sync.Mutex
|
||||||
|
waitGroup sync.WaitGroup
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewWGWatcher(log *log.Entry, wgIfaceStater WGInterfaceStater, peerKey string) *WGWatcher {
|
||||||
|
return &WGWatcher{
|
||||||
|
log: log,
|
||||||
|
wgIfaceStater: wgIfaceStater,
|
||||||
|
peerKey: peerKey,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// EnableWgWatcher starts the WireGuard watcher. If it is already enabled, it will return immediately and do nothing.
|
||||||
|
func (w *WGWatcher) EnableWgWatcher(parentCtx context.Context, onDisconnectedFn func()) {
|
||||||
|
w.log.Debugf("enable WireGuard watcher")
|
||||||
|
w.ctxLock.Lock()
|
||||||
|
defer w.ctxLock.Unlock()
|
||||||
|
|
||||||
|
if w.ctx != nil && w.ctx.Err() == nil {
|
||||||
|
w.log.Errorf("WireGuard watcher already enabled")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, ctxCancel := context.WithCancel(parentCtx)
|
||||||
|
w.ctx = ctx
|
||||||
|
w.ctxCancel = ctxCancel
|
||||||
|
|
||||||
|
initialHandshake, err := w.wgState()
|
||||||
|
if err != nil {
|
||||||
|
w.log.Warnf("failed to read initial wg stats: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
w.waitGroup.Add(1)
|
||||||
|
go w.periodicHandshakeCheck(ctx, ctxCancel, onDisconnectedFn, initialHandshake)
|
||||||
|
}
|
||||||
|
|
||||||
|
// DisableWgWatcher stops the WireGuard watcher and wait for the watcher to exit
|
||||||
|
func (w *WGWatcher) DisableWgWatcher() {
|
||||||
|
w.ctxLock.Lock()
|
||||||
|
defer w.ctxLock.Unlock()
|
||||||
|
|
||||||
|
if w.ctxCancel == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
w.log.Debugf("disable WireGuard watcher")
|
||||||
|
|
||||||
|
w.ctxCancel()
|
||||||
|
w.ctxCancel = nil
|
||||||
|
w.waitGroup.Wait()
|
||||||
|
}
|
||||||
|
|
||||||
|
// wgStateCheck help to check the state of the WireGuard handshake and relay connection
|
||||||
|
func (w *WGWatcher) periodicHandshakeCheck(ctx context.Context, ctxCancel context.CancelFunc, onDisconnectedFn func(), initialHandshake time.Time) {
|
||||||
|
w.log.Infof("WireGuard watcher started")
|
||||||
|
defer w.waitGroup.Done()
|
||||||
|
|
||||||
|
timer := time.NewTimer(wgHandshakeOvertime)
|
||||||
|
defer timer.Stop()
|
||||||
|
defer ctxCancel()
|
||||||
|
|
||||||
|
lastHandshake := initialHandshake
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-timer.C:
|
||||||
|
handshake, ok := w.handshakeCheck(lastHandshake)
|
||||||
|
if !ok {
|
||||||
|
onDisconnectedFn()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
lastHandshake = *handshake
|
||||||
|
|
||||||
|
resetTime := time.Until(handshake.Add(checkPeriod))
|
||||||
|
timer.Reset(resetTime)
|
||||||
|
|
||||||
|
w.log.Debugf("WireGuard watcher reset timer: %v", resetTime)
|
||||||
|
case <-ctx.Done():
|
||||||
|
w.log.Infof("WireGuard watcher stopped")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// handshakeCheck checks the WireGuard handshake and return the new handshake time if it is different from the previous one
|
||||||
|
func (w *WGWatcher) handshakeCheck(lastHandshake time.Time) (*time.Time, bool) {
|
||||||
|
handshake, err := w.wgState()
|
||||||
|
if err != nil {
|
||||||
|
w.log.Errorf("failed to read wg stats: %v", err)
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
w.log.Tracef("previous handshake, handshake: %v, %v", lastHandshake, handshake)
|
||||||
|
|
||||||
|
// the current know handshake did not change
|
||||||
|
if handshake.Equal(lastHandshake) {
|
||||||
|
w.log.Warnf("WireGuard handshake timed out, closing relay connection: %v", handshake)
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
// in case if the machine is suspended, the handshake time will be in the past
|
||||||
|
if handshake.Add(checkPeriod).Before(time.Now()) {
|
||||||
|
w.log.Warnf("WireGuard handshake timed out, closing relay connection: %v", handshake)
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
// error handling for handshake time in the future
|
||||||
|
if handshake.After(time.Now()) {
|
||||||
|
w.log.Warnf("WireGuard handshake is in the future, closing relay connection: %v", handshake)
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
return &handshake, true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *WGWatcher) wgState() (time.Time, error) {
|
||||||
|
wgState, err := w.wgIfaceStater.GetStats(w.peerKey)
|
||||||
|
if err != nil {
|
||||||
|
return time.Time{}, err
|
||||||
|
}
|
||||||
|
return wgState.LastHandshake, nil
|
||||||
|
}
|
||||||
98
client/internal/peer/wg_watcher_test.go
Normal file
98
client/internal/peer/wg_watcher_test.go
Normal file
@@ -0,0 +1,98 @@
|
|||||||
|
package peer
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||||
|
)
|
||||||
|
|
||||||
|
type MocWgIface struct {
|
||||||
|
initial bool
|
||||||
|
lastHandshake time.Time
|
||||||
|
stop bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MocWgIface) GetStats(key string) (configurer.WGStats, error) {
|
||||||
|
if !m.initial {
|
||||||
|
m.initial = true
|
||||||
|
return configurer.WGStats{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if !m.stop {
|
||||||
|
m.lastHandshake = time.Now()
|
||||||
|
}
|
||||||
|
|
||||||
|
stats := configurer.WGStats{
|
||||||
|
LastHandshake: m.lastHandshake,
|
||||||
|
}
|
||||||
|
|
||||||
|
return stats, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MocWgIface) disconnect() {
|
||||||
|
m.stop = true
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWGWatcher_EnableWgWatcher(t *testing.T) {
|
||||||
|
checkPeriod = 5 * time.Second
|
||||||
|
wgHandshakeOvertime = 1 * time.Second
|
||||||
|
|
||||||
|
mlog := log.WithField("peer", "tet")
|
||||||
|
mocWgIface := &MocWgIface{}
|
||||||
|
watcher := NewWGWatcher(mlog, mocWgIface, "")
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
onDisconnected := make(chan struct{}, 1)
|
||||||
|
watcher.EnableWgWatcher(ctx, func() {
|
||||||
|
mlog.Infof("onDisconnectedFn")
|
||||||
|
onDisconnected <- struct{}{}
|
||||||
|
})
|
||||||
|
|
||||||
|
// wait for initial reading
|
||||||
|
time.Sleep(2 * time.Second)
|
||||||
|
mocWgIface.disconnect()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-onDisconnected:
|
||||||
|
case <-time.After(10 * time.Second):
|
||||||
|
t.Errorf("timeout")
|
||||||
|
}
|
||||||
|
watcher.DisableWgWatcher()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWGWatcher_ReEnable(t *testing.T) {
|
||||||
|
checkPeriod = 5 * time.Second
|
||||||
|
wgHandshakeOvertime = 1 * time.Second
|
||||||
|
|
||||||
|
mlog := log.WithField("peer", "tet")
|
||||||
|
mocWgIface := &MocWgIface{}
|
||||||
|
watcher := NewWGWatcher(mlog, mocWgIface, "")
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
onDisconnected := make(chan struct{}, 1)
|
||||||
|
|
||||||
|
watcher.EnableWgWatcher(ctx, func() {})
|
||||||
|
watcher.DisableWgWatcher()
|
||||||
|
|
||||||
|
watcher.EnableWgWatcher(ctx, func() {
|
||||||
|
onDisconnected <- struct{}{}
|
||||||
|
})
|
||||||
|
|
||||||
|
time.Sleep(2 * time.Second)
|
||||||
|
mocWgIface.disconnect()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-onDisconnected:
|
||||||
|
case <-time.After(10 * time.Second):
|
||||||
|
t.Errorf("timeout")
|
||||||
|
}
|
||||||
|
watcher.DisableWgWatcher()
|
||||||
|
}
|
||||||
@@ -31,20 +31,15 @@ type ICEConnInfo struct {
|
|||||||
RelayedOnLocal bool
|
RelayedOnLocal bool
|
||||||
}
|
}
|
||||||
|
|
||||||
type WorkerICECallbacks struct {
|
|
||||||
OnConnReady func(ConnPriority, ICEConnInfo)
|
|
||||||
OnStatusChanged func(ConnStatus)
|
|
||||||
}
|
|
||||||
|
|
||||||
type WorkerICE struct {
|
type WorkerICE struct {
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
log *log.Entry
|
log *log.Entry
|
||||||
config ConnConfig
|
config ConnConfig
|
||||||
|
conn *Conn
|
||||||
signaler *Signaler
|
signaler *Signaler
|
||||||
iFaceDiscover stdnet.ExternalIFaceDiscover
|
iFaceDiscover stdnet.ExternalIFaceDiscover
|
||||||
statusRecorder *Status
|
statusRecorder *Status
|
||||||
hasRelayOnLocally bool
|
hasRelayOnLocally bool
|
||||||
conn WorkerICECallbacks
|
|
||||||
|
|
||||||
agent *ice.Agent
|
agent *ice.Agent
|
||||||
muxAgent sync.Mutex
|
muxAgent sync.Mutex
|
||||||
@@ -60,16 +55,16 @@ type WorkerICE struct {
|
|||||||
lastKnownState ice.ConnectionState
|
lastKnownState ice.ConnectionState
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewWorkerICE(ctx context.Context, log *log.Entry, config ConnConfig, signaler *Signaler, ifaceDiscover stdnet.ExternalIFaceDiscover, statusRecorder *Status, hasRelayOnLocally bool, callBacks WorkerICECallbacks) (*WorkerICE, error) {
|
func NewWorkerICE(ctx context.Context, log *log.Entry, config ConnConfig, conn *Conn, signaler *Signaler, ifaceDiscover stdnet.ExternalIFaceDiscover, statusRecorder *Status, hasRelayOnLocally bool) (*WorkerICE, error) {
|
||||||
w := &WorkerICE{
|
w := &WorkerICE{
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
log: log,
|
log: log,
|
||||||
config: config,
|
config: config,
|
||||||
|
conn: conn,
|
||||||
signaler: signaler,
|
signaler: signaler,
|
||||||
iFaceDiscover: ifaceDiscover,
|
iFaceDiscover: ifaceDiscover,
|
||||||
statusRecorder: statusRecorder,
|
statusRecorder: statusRecorder,
|
||||||
hasRelayOnLocally: hasRelayOnLocally,
|
hasRelayOnLocally: hasRelayOnLocally,
|
||||||
conn: callBacks,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
localUfrag, localPwd, err := icemaker.GenerateICECredentials()
|
localUfrag, localPwd, err := icemaker.GenerateICECredentials()
|
||||||
@@ -154,8 +149,8 @@ func (w *WorkerICE) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
|
|||||||
Relayed: isRelayed(pair),
|
Relayed: isRelayed(pair),
|
||||||
RelayedOnLocal: isRelayCandidate(pair.Local),
|
RelayedOnLocal: isRelayCandidate(pair.Local),
|
||||||
}
|
}
|
||||||
w.log.Debugf("on ICE conn read to use ready")
|
w.log.Debugf("on ICE conn is ready to use")
|
||||||
go w.conn.OnConnReady(selectedPriority(pair), ci)
|
go w.conn.onICEConnectionIsReady(selectedPriority(pair), ci)
|
||||||
}
|
}
|
||||||
|
|
||||||
// OnRemoteCandidate Handles ICE connection Candidate provided by the remote peer.
|
// OnRemoteCandidate Handles ICE connection Candidate provided by the remote peer.
|
||||||
@@ -220,7 +215,7 @@ func (w *WorkerICE) reCreateAgent(agentCancel context.CancelFunc, candidates []i
|
|||||||
case ice.ConnectionStateFailed, ice.ConnectionStateDisconnected:
|
case ice.ConnectionStateFailed, ice.ConnectionStateDisconnected:
|
||||||
if w.lastKnownState != ice.ConnectionStateDisconnected {
|
if w.lastKnownState != ice.ConnectionStateDisconnected {
|
||||||
w.lastKnownState = ice.ConnectionStateDisconnected
|
w.lastKnownState = ice.ConnectionStateDisconnected
|
||||||
w.conn.OnStatusChanged(StatusDisconnected)
|
w.conn.onICEStateDisconnected()
|
||||||
}
|
}
|
||||||
w.closeAgent(agentCancel)
|
w.closeAgent(agentCancel)
|
||||||
default:
|
default:
|
||||||
|
|||||||
@@ -6,52 +6,41 @@ import (
|
|||||||
"net"
|
"net"
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
relayClient "github.com/netbirdio/netbird/relay/client"
|
relayClient "github.com/netbirdio/netbird/relay/client"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
|
||||||
wgHandshakePeriod = 3 * time.Minute
|
|
||||||
wgHandshakeOvertime = 30 * time.Second
|
|
||||||
)
|
|
||||||
|
|
||||||
type RelayConnInfo struct {
|
type RelayConnInfo struct {
|
||||||
relayedConn net.Conn
|
relayedConn net.Conn
|
||||||
rosenpassPubKey []byte
|
rosenpassPubKey []byte
|
||||||
rosenpassAddr string
|
rosenpassAddr string
|
||||||
}
|
}
|
||||||
|
|
||||||
type WorkerRelayCallbacks struct {
|
|
||||||
OnConnReady func(RelayConnInfo)
|
|
||||||
OnDisconnected func()
|
|
||||||
}
|
|
||||||
|
|
||||||
type WorkerRelay struct {
|
type WorkerRelay struct {
|
||||||
log *log.Entry
|
log *log.Entry
|
||||||
isController bool
|
isController bool
|
||||||
config ConnConfig
|
config ConnConfig
|
||||||
|
conn *Conn
|
||||||
relayManager relayClient.ManagerService
|
relayManager relayClient.ManagerService
|
||||||
callBacks WorkerRelayCallbacks
|
|
||||||
|
|
||||||
relayedConn net.Conn
|
relayedConn net.Conn
|
||||||
relayLock sync.Mutex
|
relayLock sync.Mutex
|
||||||
ctxWgWatch context.Context
|
|
||||||
ctxCancelWgWatch context.CancelFunc
|
|
||||||
ctxLock sync.Mutex
|
|
||||||
|
|
||||||
relaySupportedOnRemotePeer atomic.Bool
|
relaySupportedOnRemotePeer atomic.Bool
|
||||||
|
|
||||||
|
wgWatcher *WGWatcher
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewWorkerRelay(log *log.Entry, ctrl bool, config ConnConfig, relayManager relayClient.ManagerService, callbacks WorkerRelayCallbacks) *WorkerRelay {
|
func NewWorkerRelay(log *log.Entry, ctrl bool, config ConnConfig, conn *Conn, relayManager relayClient.ManagerService) *WorkerRelay {
|
||||||
r := &WorkerRelay{
|
r := &WorkerRelay{
|
||||||
log: log,
|
log: log,
|
||||||
isController: ctrl,
|
isController: ctrl,
|
||||||
config: config,
|
config: config,
|
||||||
|
conn: conn,
|
||||||
relayManager: relayManager,
|
relayManager: relayManager,
|
||||||
callBacks: callbacks,
|
wgWatcher: NewWGWatcher(log, config.WgConfig.WgInterface, config.Key),
|
||||||
}
|
}
|
||||||
return r
|
return r
|
||||||
}
|
}
|
||||||
@@ -87,7 +76,7 @@ func (w *WorkerRelay) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
|
|||||||
w.relayedConn = relayedConn
|
w.relayedConn = relayedConn
|
||||||
w.relayLock.Unlock()
|
w.relayLock.Unlock()
|
||||||
|
|
||||||
err = w.relayManager.AddCloseListener(srv, w.onRelayMGDisconnected)
|
err = w.relayManager.AddCloseListener(srv, w.onRelayClientDisconnected)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("failed to add close listener: %s", err)
|
log.Errorf("failed to add close listener: %s", err)
|
||||||
_ = relayedConn.Close()
|
_ = relayedConn.Close()
|
||||||
@@ -95,7 +84,7 @@ func (w *WorkerRelay) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
w.log.Debugf("peer conn opened via Relay: %s", srv)
|
w.log.Debugf("peer conn opened via Relay: %s", srv)
|
||||||
go w.callBacks.OnConnReady(RelayConnInfo{
|
go w.conn.onRelayConnectionIsReady(RelayConnInfo{
|
||||||
relayedConn: relayedConn,
|
relayedConn: relayedConn,
|
||||||
rosenpassPubKey: remoteOfferAnswer.RosenpassPubKey,
|
rosenpassPubKey: remoteOfferAnswer.RosenpassPubKey,
|
||||||
rosenpassAddr: remoteOfferAnswer.RosenpassAddr,
|
rosenpassAddr: remoteOfferAnswer.RosenpassAddr,
|
||||||
@@ -103,32 +92,11 @@ func (w *WorkerRelay) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (w *WorkerRelay) EnableWgWatcher(ctx context.Context) {
|
func (w *WorkerRelay) EnableWgWatcher(ctx context.Context) {
|
||||||
w.log.Debugf("enable WireGuard watcher")
|
w.wgWatcher.EnableWgWatcher(ctx, w.onWGDisconnected)
|
||||||
w.ctxLock.Lock()
|
|
||||||
defer w.ctxLock.Unlock()
|
|
||||||
|
|
||||||
if w.ctxWgWatch != nil && w.ctxWgWatch.Err() == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
ctx, ctxCancel := context.WithCancel(ctx)
|
|
||||||
w.ctxWgWatch = ctx
|
|
||||||
w.ctxCancelWgWatch = ctxCancel
|
|
||||||
|
|
||||||
w.wgStateCheck(ctx, ctxCancel)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *WorkerRelay) DisableWgWatcher() {
|
func (w *WorkerRelay) DisableWgWatcher() {
|
||||||
w.ctxLock.Lock()
|
w.wgWatcher.DisableWgWatcher()
|
||||||
defer w.ctxLock.Unlock()
|
|
||||||
|
|
||||||
if w.ctxCancelWgWatch == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
w.log.Debugf("disable WireGuard watcher")
|
|
||||||
|
|
||||||
w.ctxCancelWgWatch()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *WorkerRelay) RelayInstanceAddress() (string, error) {
|
func (w *WorkerRelay) RelayInstanceAddress() (string, error) {
|
||||||
@@ -150,57 +118,17 @@ func (w *WorkerRelay) CloseConn() {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
err := w.relayedConn.Close()
|
if err := w.relayedConn.Close(); err != nil {
|
||||||
if err != nil {
|
|
||||||
w.log.Warnf("failed to close relay connection: %v", err)
|
w.log.Warnf("failed to close relay connection: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// wgStateCheck help to check the state of the WireGuard handshake and relay connection
|
func (w *WorkerRelay) onWGDisconnected() {
|
||||||
func (w *WorkerRelay) wgStateCheck(ctx context.Context, ctxCancel context.CancelFunc) {
|
w.relayLock.Lock()
|
||||||
w.log.Debugf("WireGuard watcher started")
|
_ = w.relayedConn.Close()
|
||||||
lastHandshake, err := w.wgState()
|
w.relayLock.Unlock()
|
||||||
if err != nil {
|
|
||||||
w.log.Warnf("failed to read wg stats: %v", err)
|
|
||||||
lastHandshake = time.Time{}
|
|
||||||
}
|
|
||||||
|
|
||||||
go func(lastHandshake time.Time) {
|
|
||||||
timer := time.NewTimer(wgHandshakeOvertime)
|
|
||||||
defer timer.Stop()
|
|
||||||
defer ctxCancel()
|
|
||||||
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case <-timer.C:
|
|
||||||
handshake, err := w.wgState()
|
|
||||||
if err != nil {
|
|
||||||
w.log.Errorf("failed to read wg stats: %v", err)
|
|
||||||
timer.Reset(wgHandshakeOvertime)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
w.log.Tracef("previous handshake, handshake: %v, %v", lastHandshake, handshake)
|
|
||||||
|
|
||||||
if handshake.Equal(lastHandshake) {
|
|
||||||
w.log.Infof("WireGuard handshake timed out, closing relay connection: %v", handshake)
|
|
||||||
w.relayLock.Lock()
|
|
||||||
_ = w.relayedConn.Close()
|
|
||||||
w.relayLock.Unlock()
|
|
||||||
w.callBacks.OnDisconnected()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
resetTime := time.Until(handshake.Add(wgHandshakePeriod + wgHandshakeOvertime))
|
|
||||||
lastHandshake = handshake
|
|
||||||
timer.Reset(resetTime)
|
|
||||||
case <-ctx.Done():
|
|
||||||
w.log.Debugf("WireGuard watcher stopped")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}(lastHandshake)
|
|
||||||
|
|
||||||
|
w.conn.onRelayDisconnected()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *WorkerRelay) isRelaySupported(answer *OfferAnswer) bool {
|
func (w *WorkerRelay) isRelaySupported(answer *OfferAnswer) bool {
|
||||||
@@ -217,20 +145,7 @@ func (w *WorkerRelay) preferredRelayServer(myRelayAddress, remoteRelayAddress st
|
|||||||
return remoteRelayAddress
|
return remoteRelayAddress
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *WorkerRelay) wgState() (time.Time, error) {
|
func (w *WorkerRelay) onRelayClientDisconnected() {
|
||||||
wgState, err := w.config.WgConfig.WgInterface.GetStats(w.config.Key)
|
w.wgWatcher.DisableWgWatcher()
|
||||||
if err != nil {
|
go w.conn.onRelayDisconnected()
|
||||||
return time.Time{}, err
|
|
||||||
}
|
|
||||||
return wgState.LastHandshake, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (w *WorkerRelay) onRelayMGDisconnected() {
|
|
||||||
w.ctxLock.Lock()
|
|
||||||
defer w.ctxLock.Unlock()
|
|
||||||
|
|
||||||
if w.ctxCancelWgWatch != nil {
|
|
||||||
w.ctxCancelWgWatch()
|
|
||||||
}
|
|
||||||
go w.callBacks.OnDisconnected()
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -113,13 +113,14 @@ func NewManager(config ManagerConfig) *DefaultManager {
|
|||||||
disableServerRoutes: config.DisableServerRoutes,
|
disableServerRoutes: config.DisableServerRoutes,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
useNoop := netstack.IsEnabled() || config.DisableClientRoutes
|
||||||
|
dm.setupRefCounters(useNoop)
|
||||||
|
|
||||||
// don't proceed with client routes if it is disabled
|
// don't proceed with client routes if it is disabled
|
||||||
if config.DisableClientRoutes {
|
if config.DisableClientRoutes {
|
||||||
return dm
|
return dm
|
||||||
}
|
}
|
||||||
|
|
||||||
dm.setupRefCounters()
|
|
||||||
|
|
||||||
if runtime.GOOS == "android" {
|
if runtime.GOOS == "android" {
|
||||||
cr := dm.initialClientRoutes(config.InitialRoutes)
|
cr := dm.initialClientRoutes(config.InitialRoutes)
|
||||||
dm.notifier.SetInitialClientRoutes(cr)
|
dm.notifier.SetInitialClientRoutes(cr)
|
||||||
@@ -127,7 +128,7 @@ func NewManager(config ManagerConfig) *DefaultManager {
|
|||||||
return dm
|
return dm
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *DefaultManager) setupRefCounters() {
|
func (m *DefaultManager) setupRefCounters(useNoop bool) {
|
||||||
m.routeRefCounter = refcounter.New(
|
m.routeRefCounter = refcounter.New(
|
||||||
func(prefix netip.Prefix, _ struct{}) (struct{}, error) {
|
func(prefix netip.Prefix, _ struct{}) (struct{}, error) {
|
||||||
return struct{}{}, m.sysOps.AddVPNRoute(prefix, m.wgInterface.ToInterface())
|
return struct{}{}, m.sysOps.AddVPNRoute(prefix, m.wgInterface.ToInterface())
|
||||||
@@ -137,7 +138,7 @@ func (m *DefaultManager) setupRefCounters() {
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
if netstack.IsEnabled() {
|
if useNoop {
|
||||||
m.routeRefCounter = refcounter.New(
|
m.routeRefCounter = refcounter.New(
|
||||||
func(netip.Prefix, struct{}) (struct{}, error) {
|
func(netip.Prefix, struct{}) (struct{}, error) {
|
||||||
return struct{}{}, refcounter.ErrIgnore
|
return struct{}{}, refcounter.ErrIgnore
|
||||||
@@ -285,15 +286,15 @@ func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Ro
|
|||||||
m.updateClientNetworks(updateSerial, filteredClientRoutes)
|
m.updateClientNetworks(updateSerial, filteredClientRoutes)
|
||||||
m.notifier.OnNewRoutes(filteredClientRoutes)
|
m.notifier.OnNewRoutes(filteredClientRoutes)
|
||||||
}
|
}
|
||||||
|
m.clientRoutes = newClientRoutesIDMap
|
||||||
|
|
||||||
if m.serverRouter != nil {
|
if m.serverRouter == nil {
|
||||||
err := m.serverRouter.updateRoutes(newServerRoutesMap)
|
return nil
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
m.clientRoutes = newClientRoutesIDMap
|
if err := m.serverRouter.updateRoutes(newServerRoutesMap); err != nil {
|
||||||
|
return fmt.Errorf("update routes: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -422,11 +423,6 @@ func (m *DefaultManager) classifyRoutes(newRoutes []*route.Route) (map[route.ID]
|
|||||||
haID := newRoute.GetHAUniqueID()
|
haID := newRoute.GetHAUniqueID()
|
||||||
if newRoute.Peer == m.pubKey {
|
if newRoute.Peer == m.pubKey {
|
||||||
ownNetworkIDs[haID] = true
|
ownNetworkIDs[haID] = true
|
||||||
// only linux is supported for now
|
|
||||||
if runtime.GOOS != "linux" {
|
|
||||||
log.Warnf("received a route to manage, but agent doesn't support router mode on %s OS", runtime.GOOS)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
newServerRoutesMap[newRoute.ID] = newRoute
|
newServerRoutesMap[newRoute.ID] = newRoute
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -454,7 +450,7 @@ func (m *DefaultManager) initialClientRoutes(initialRoutes []*route.Route) []*ro
|
|||||||
}
|
}
|
||||||
|
|
||||||
func isRouteSupported(route *route.Route) bool {
|
func isRouteSupported(route *route.Route) bool {
|
||||||
if !nbnet.CustomRoutingDisabled() || route.IsDynamic() {
|
if netstack.IsEnabled() || !nbnet.CustomRoutingDisabled() || route.IsDynamic() {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -71,9 +71,15 @@ func (m *serverRouter) updateRoutes(routesMap map[route.ID]*route.Route) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if len(m.routes) > 0 {
|
if len(m.routes) > 0 {
|
||||||
err := systemops.EnableIPForwarding()
|
if err := systemops.EnableIPForwarding(); err != nil {
|
||||||
if err != nil {
|
return fmt.Errorf("enable ip forwarding: %w", err)
|
||||||
return err
|
}
|
||||||
|
if err := m.firewall.EnableRouting(); err != nil {
|
||||||
|
return fmt.Errorf("enable routing: %w", err)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if err := m.firewall.DisableRouting(); err != nil {
|
||||||
|
return fmt.Errorf("disable routing: %w", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -53,20 +53,6 @@ type ruleParams struct {
|
|||||||
description string
|
description string
|
||||||
}
|
}
|
||||||
|
|
||||||
// isLegacy determines whether to use the legacy routing setup
|
|
||||||
func isLegacy() bool {
|
|
||||||
return os.Getenv("NB_USE_LEGACY_ROUTING") == "true" || nbnet.CustomRoutingDisabled() || nbnet.SkipSocketMark()
|
|
||||||
}
|
|
||||||
|
|
||||||
// setIsLegacy sets the legacy routing setup
|
|
||||||
func setIsLegacy(b bool) {
|
|
||||||
if b {
|
|
||||||
os.Setenv("NB_USE_LEGACY_ROUTING", "true")
|
|
||||||
} else {
|
|
||||||
os.Unsetenv("NB_USE_LEGACY_ROUTING")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func getSetupRules() []ruleParams {
|
func getSetupRules() []ruleParams {
|
||||||
return []ruleParams{
|
return []ruleParams{
|
||||||
{100, -1, syscall.RT_TABLE_MAIN, netlink.FAMILY_V4, false, 0, "rule with suppress prefixlen v4"},
|
{100, -1, syscall.RT_TABLE_MAIN, netlink.FAMILY_V4, false, 0, "rule with suppress prefixlen v4"},
|
||||||
@@ -87,7 +73,7 @@ func getSetupRules() []ruleParams {
|
|||||||
// This table is where a default route or other specific routes received from the management server are configured,
|
// This table is where a default route or other specific routes received from the management server are configured,
|
||||||
// enabling VPN connectivity.
|
// enabling VPN connectivity.
|
||||||
func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager) (_ nbnet.AddHookFunc, _ nbnet.RemoveHookFunc, err error) {
|
func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager) (_ nbnet.AddHookFunc, _ nbnet.RemoveHookFunc, err error) {
|
||||||
if isLegacy() {
|
if !nbnet.AdvancedRouting() {
|
||||||
log.Infof("Using legacy routing setup")
|
log.Infof("Using legacy routing setup")
|
||||||
return r.setupRefCounter(initAddresses, stateManager)
|
return r.setupRefCounter(initAddresses, stateManager)
|
||||||
}
|
}
|
||||||
@@ -103,11 +89,6 @@ func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager
|
|||||||
rules := getSetupRules()
|
rules := getSetupRules()
|
||||||
for _, rule := range rules {
|
for _, rule := range rules {
|
||||||
if err := addRule(rule); err != nil {
|
if err := addRule(rule); err != nil {
|
||||||
if errors.Is(err, syscall.EOPNOTSUPP) {
|
|
||||||
log.Warnf("Rule operations are not supported, falling back to the legacy routing setup")
|
|
||||||
setIsLegacy(true)
|
|
||||||
return r.setupRefCounter(initAddresses, stateManager)
|
|
||||||
}
|
|
||||||
return nil, nil, fmt.Errorf("%s: %w", rule.description, err)
|
return nil, nil, fmt.Errorf("%s: %w", rule.description, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -130,7 +111,7 @@ func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager
|
|||||||
// It systematically removes the three rules and any associated routing table entries to ensure a clean state.
|
// It systematically removes the three rules and any associated routing table entries to ensure a clean state.
|
||||||
// The function uses error aggregation to report any errors encountered during the cleanup process.
|
// The function uses error aggregation to report any errors encountered during the cleanup process.
|
||||||
func (r *SysOps) CleanupRouting(stateManager *statemanager.Manager) error {
|
func (r *SysOps) CleanupRouting(stateManager *statemanager.Manager) error {
|
||||||
if isLegacy() {
|
if !nbnet.AdvancedRouting() {
|
||||||
return r.cleanupRefCounter(stateManager)
|
return r.cleanupRefCounter(stateManager)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -168,7 +149,7 @@ func (r *SysOps) removeFromRouteTable(prefix netip.Prefix, nexthop Nexthop) erro
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (r *SysOps) AddVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
|
func (r *SysOps) AddVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
|
||||||
if isLegacy() {
|
if !nbnet.AdvancedRouting() {
|
||||||
return r.genericAddVPNRoute(prefix, intf)
|
return r.genericAddVPNRoute(prefix, intf)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -191,7 +172,7 @@ func (r *SysOps) AddVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (r *SysOps) RemoveVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
|
func (r *SysOps) RemoveVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
|
||||||
if isLegacy() {
|
if !nbnet.AdvancedRouting() {
|
||||||
return r.genericRemoveVPNRoute(prefix, intf)
|
return r.genericRemoveVPNRoute(prefix, intf)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -504,7 +485,7 @@ func getAddressFamily(prefix netip.Prefix) int {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func hasSeparateRouting() ([]netip.Prefix, error) {
|
func hasSeparateRouting() ([]netip.Prefix, error) {
|
||||||
if isLegacy() {
|
if !nbnet.AdvancedRouting() {
|
||||||
return GetRoutesFromTable()
|
return GetRoutesFromTable()
|
||||||
}
|
}
|
||||||
return nil, ErrRoutingIsSeparate
|
return nil, ErrRoutingIsSeparate
|
||||||
|
|||||||
@@ -85,6 +85,7 @@ var testCases = []testCase{
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestRouting(t *testing.T) {
|
func TestRouting(t *testing.T) {
|
||||||
|
nbnet.Init()
|
||||||
for _, tc := range testCases {
|
for _, tc := range testCases {
|
||||||
// todo resolve test execution on freebsd
|
// todo resolve test execution on freebsd
|
||||||
if runtime.GOOS == "freebsd" {
|
if runtime.GOOS == "freebsd" {
|
||||||
|
|||||||
@@ -2571,6 +2571,330 @@ func (*SetNetworkMapPersistenceResponse) Descriptor() ([]byte, []int) {
|
|||||||
return file_daemon_proto_rawDescGZIP(), []int{39}
|
return file_daemon_proto_rawDescGZIP(), []int{39}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type TCPFlags struct {
|
||||||
|
state protoimpl.MessageState
|
||||||
|
sizeCache protoimpl.SizeCache
|
||||||
|
unknownFields protoimpl.UnknownFields
|
||||||
|
|
||||||
|
Syn bool `protobuf:"varint,1,opt,name=syn,proto3" json:"syn,omitempty"`
|
||||||
|
Ack bool `protobuf:"varint,2,opt,name=ack,proto3" json:"ack,omitempty"`
|
||||||
|
Fin bool `protobuf:"varint,3,opt,name=fin,proto3" json:"fin,omitempty"`
|
||||||
|
Rst bool `protobuf:"varint,4,opt,name=rst,proto3" json:"rst,omitempty"`
|
||||||
|
Psh bool `protobuf:"varint,5,opt,name=psh,proto3" json:"psh,omitempty"`
|
||||||
|
Urg bool `protobuf:"varint,6,opt,name=urg,proto3" json:"urg,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *TCPFlags) Reset() {
|
||||||
|
*x = TCPFlags{}
|
||||||
|
if protoimpl.UnsafeEnabled {
|
||||||
|
mi := &file_daemon_proto_msgTypes[40]
|
||||||
|
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||||
|
ms.StoreMessageInfo(mi)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *TCPFlags) String() string {
|
||||||
|
return protoimpl.X.MessageStringOf(x)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (*TCPFlags) ProtoMessage() {}
|
||||||
|
|
||||||
|
func (x *TCPFlags) ProtoReflect() protoreflect.Message {
|
||||||
|
mi := &file_daemon_proto_msgTypes[40]
|
||||||
|
if protoimpl.UnsafeEnabled && x != nil {
|
||||||
|
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||||
|
if ms.LoadMessageInfo() == nil {
|
||||||
|
ms.StoreMessageInfo(mi)
|
||||||
|
}
|
||||||
|
return ms
|
||||||
|
}
|
||||||
|
return mi.MessageOf(x)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Deprecated: Use TCPFlags.ProtoReflect.Descriptor instead.
|
||||||
|
func (*TCPFlags) Descriptor() ([]byte, []int) {
|
||||||
|
return file_daemon_proto_rawDescGZIP(), []int{40}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *TCPFlags) GetSyn() bool {
|
||||||
|
if x != nil {
|
||||||
|
return x.Syn
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *TCPFlags) GetAck() bool {
|
||||||
|
if x != nil {
|
||||||
|
return x.Ack
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *TCPFlags) GetFin() bool {
|
||||||
|
if x != nil {
|
||||||
|
return x.Fin
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *TCPFlags) GetRst() bool {
|
||||||
|
if x != nil {
|
||||||
|
return x.Rst
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *TCPFlags) GetPsh() bool {
|
||||||
|
if x != nil {
|
||||||
|
return x.Psh
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *TCPFlags) GetUrg() bool {
|
||||||
|
if x != nil {
|
||||||
|
return x.Urg
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
type TracePacketRequest struct {
|
||||||
|
state protoimpl.MessageState
|
||||||
|
sizeCache protoimpl.SizeCache
|
||||||
|
unknownFields protoimpl.UnknownFields
|
||||||
|
|
||||||
|
SourceIp string `protobuf:"bytes,1,opt,name=source_ip,json=sourceIp,proto3" json:"source_ip,omitempty"`
|
||||||
|
DestinationIp string `protobuf:"bytes,2,opt,name=destination_ip,json=destinationIp,proto3" json:"destination_ip,omitempty"`
|
||||||
|
Protocol string `protobuf:"bytes,3,opt,name=protocol,proto3" json:"protocol,omitempty"`
|
||||||
|
SourcePort uint32 `protobuf:"varint,4,opt,name=source_port,json=sourcePort,proto3" json:"source_port,omitempty"`
|
||||||
|
DestinationPort uint32 `protobuf:"varint,5,opt,name=destination_port,json=destinationPort,proto3" json:"destination_port,omitempty"`
|
||||||
|
Direction string `protobuf:"bytes,6,opt,name=direction,proto3" json:"direction,omitempty"`
|
||||||
|
TcpFlags *TCPFlags `protobuf:"bytes,7,opt,name=tcp_flags,json=tcpFlags,proto3,oneof" json:"tcp_flags,omitempty"`
|
||||||
|
IcmpType *uint32 `protobuf:"varint,8,opt,name=icmp_type,json=icmpType,proto3,oneof" json:"icmp_type,omitempty"`
|
||||||
|
IcmpCode *uint32 `protobuf:"varint,9,opt,name=icmp_code,json=icmpCode,proto3,oneof" json:"icmp_code,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *TracePacketRequest) Reset() {
|
||||||
|
*x = TracePacketRequest{}
|
||||||
|
if protoimpl.UnsafeEnabled {
|
||||||
|
mi := &file_daemon_proto_msgTypes[41]
|
||||||
|
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||||
|
ms.StoreMessageInfo(mi)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *TracePacketRequest) String() string {
|
||||||
|
return protoimpl.X.MessageStringOf(x)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (*TracePacketRequest) ProtoMessage() {}
|
||||||
|
|
||||||
|
func (x *TracePacketRequest) ProtoReflect() protoreflect.Message {
|
||||||
|
mi := &file_daemon_proto_msgTypes[41]
|
||||||
|
if protoimpl.UnsafeEnabled && x != nil {
|
||||||
|
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||||
|
if ms.LoadMessageInfo() == nil {
|
||||||
|
ms.StoreMessageInfo(mi)
|
||||||
|
}
|
||||||
|
return ms
|
||||||
|
}
|
||||||
|
return mi.MessageOf(x)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Deprecated: Use TracePacketRequest.ProtoReflect.Descriptor instead.
|
||||||
|
func (*TracePacketRequest) Descriptor() ([]byte, []int) {
|
||||||
|
return file_daemon_proto_rawDescGZIP(), []int{41}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *TracePacketRequest) GetSourceIp() string {
|
||||||
|
if x != nil {
|
||||||
|
return x.SourceIp
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *TracePacketRequest) GetDestinationIp() string {
|
||||||
|
if x != nil {
|
||||||
|
return x.DestinationIp
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *TracePacketRequest) GetProtocol() string {
|
||||||
|
if x != nil {
|
||||||
|
return x.Protocol
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *TracePacketRequest) GetSourcePort() uint32 {
|
||||||
|
if x != nil {
|
||||||
|
return x.SourcePort
|
||||||
|
}
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *TracePacketRequest) GetDestinationPort() uint32 {
|
||||||
|
if x != nil {
|
||||||
|
return x.DestinationPort
|
||||||
|
}
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *TracePacketRequest) GetDirection() string {
|
||||||
|
if x != nil {
|
||||||
|
return x.Direction
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *TracePacketRequest) GetTcpFlags() *TCPFlags {
|
||||||
|
if x != nil {
|
||||||
|
return x.TcpFlags
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *TracePacketRequest) GetIcmpType() uint32 {
|
||||||
|
if x != nil && x.IcmpType != nil {
|
||||||
|
return *x.IcmpType
|
||||||
|
}
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *TracePacketRequest) GetIcmpCode() uint32 {
|
||||||
|
if x != nil && x.IcmpCode != nil {
|
||||||
|
return *x.IcmpCode
|
||||||
|
}
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
type TraceStage struct {
|
||||||
|
state protoimpl.MessageState
|
||||||
|
sizeCache protoimpl.SizeCache
|
||||||
|
unknownFields protoimpl.UnknownFields
|
||||||
|
|
||||||
|
Name string `protobuf:"bytes,1,opt,name=name,proto3" json:"name,omitempty"`
|
||||||
|
Message string `protobuf:"bytes,2,opt,name=message,proto3" json:"message,omitempty"`
|
||||||
|
Allowed bool `protobuf:"varint,3,opt,name=allowed,proto3" json:"allowed,omitempty"`
|
||||||
|
ForwardingDetails *string `protobuf:"bytes,4,opt,name=forwarding_details,json=forwardingDetails,proto3,oneof" json:"forwarding_details,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *TraceStage) Reset() {
|
||||||
|
*x = TraceStage{}
|
||||||
|
if protoimpl.UnsafeEnabled {
|
||||||
|
mi := &file_daemon_proto_msgTypes[42]
|
||||||
|
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||||
|
ms.StoreMessageInfo(mi)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *TraceStage) String() string {
|
||||||
|
return protoimpl.X.MessageStringOf(x)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (*TraceStage) ProtoMessage() {}
|
||||||
|
|
||||||
|
func (x *TraceStage) ProtoReflect() protoreflect.Message {
|
||||||
|
mi := &file_daemon_proto_msgTypes[42]
|
||||||
|
if protoimpl.UnsafeEnabled && x != nil {
|
||||||
|
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||||
|
if ms.LoadMessageInfo() == nil {
|
||||||
|
ms.StoreMessageInfo(mi)
|
||||||
|
}
|
||||||
|
return ms
|
||||||
|
}
|
||||||
|
return mi.MessageOf(x)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Deprecated: Use TraceStage.ProtoReflect.Descriptor instead.
|
||||||
|
func (*TraceStage) Descriptor() ([]byte, []int) {
|
||||||
|
return file_daemon_proto_rawDescGZIP(), []int{42}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *TraceStage) GetName() string {
|
||||||
|
if x != nil {
|
||||||
|
return x.Name
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *TraceStage) GetMessage() string {
|
||||||
|
if x != nil {
|
||||||
|
return x.Message
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *TraceStage) GetAllowed() bool {
|
||||||
|
if x != nil {
|
||||||
|
return x.Allowed
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *TraceStage) GetForwardingDetails() string {
|
||||||
|
if x != nil && x.ForwardingDetails != nil {
|
||||||
|
return *x.ForwardingDetails
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
type TracePacketResponse struct {
|
||||||
|
state protoimpl.MessageState
|
||||||
|
sizeCache protoimpl.SizeCache
|
||||||
|
unknownFields protoimpl.UnknownFields
|
||||||
|
|
||||||
|
Stages []*TraceStage `protobuf:"bytes,1,rep,name=stages,proto3" json:"stages,omitempty"`
|
||||||
|
FinalDisposition bool `protobuf:"varint,2,opt,name=final_disposition,json=finalDisposition,proto3" json:"final_disposition,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *TracePacketResponse) Reset() {
|
||||||
|
*x = TracePacketResponse{}
|
||||||
|
if protoimpl.UnsafeEnabled {
|
||||||
|
mi := &file_daemon_proto_msgTypes[43]
|
||||||
|
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||||
|
ms.StoreMessageInfo(mi)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *TracePacketResponse) String() string {
|
||||||
|
return protoimpl.X.MessageStringOf(x)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (*TracePacketResponse) ProtoMessage() {}
|
||||||
|
|
||||||
|
func (x *TracePacketResponse) ProtoReflect() protoreflect.Message {
|
||||||
|
mi := &file_daemon_proto_msgTypes[43]
|
||||||
|
if protoimpl.UnsafeEnabled && x != nil {
|
||||||
|
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||||
|
if ms.LoadMessageInfo() == nil {
|
||||||
|
ms.StoreMessageInfo(mi)
|
||||||
|
}
|
||||||
|
return ms
|
||||||
|
}
|
||||||
|
return mi.MessageOf(x)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Deprecated: Use TracePacketResponse.ProtoReflect.Descriptor instead.
|
||||||
|
func (*TracePacketResponse) Descriptor() ([]byte, []int) {
|
||||||
|
return file_daemon_proto_rawDescGZIP(), []int{43}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *TracePacketResponse) GetStages() []*TraceStage {
|
||||||
|
if x != nil {
|
||||||
|
return x.Stages
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *TracePacketResponse) GetFinalDisposition() bool {
|
||||||
|
if x != nil {
|
||||||
|
return x.FinalDisposition
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
var File_daemon_proto protoreflect.FileDescriptor
|
var File_daemon_proto protoreflect.FileDescriptor
|
||||||
|
|
||||||
var file_daemon_proto_rawDesc = []byte{
|
var file_daemon_proto_rawDesc = []byte{
|
||||||
@@ -2920,87 +3244,141 @@ var file_daemon_proto_rawDesc = []byte{
|
|||||||
0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, 0x07, 0x65, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x22,
|
0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, 0x07, 0x65, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x22,
|
||||||
0x22, 0x0a, 0x20, 0x53, 0x65, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x4d, 0x61, 0x70,
|
0x22, 0x0a, 0x20, 0x53, 0x65, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x4d, 0x61, 0x70,
|
||||||
0x50, 0x65, 0x72, 0x73, 0x69, 0x73, 0x74, 0x65, 0x6e, 0x63, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f,
|
0x50, 0x65, 0x72, 0x73, 0x69, 0x73, 0x74, 0x65, 0x6e, 0x63, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f,
|
||||||
0x6e, 0x73, 0x65, 0x2a, 0x62, 0x0a, 0x08, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x12,
|
0x6e, 0x73, 0x65, 0x22, 0x76, 0x0a, 0x08, 0x54, 0x43, 0x50, 0x46, 0x6c, 0x61, 0x67, 0x73, 0x12,
|
||||||
0x0b, 0x0a, 0x07, 0x55, 0x4e, 0x4b, 0x4e, 0x4f, 0x57, 0x4e, 0x10, 0x00, 0x12, 0x09, 0x0a, 0x05,
|
0x10, 0x0a, 0x03, 0x73, 0x79, 0x6e, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, 0x03, 0x73, 0x79,
|
||||||
0x50, 0x41, 0x4e, 0x49, 0x43, 0x10, 0x01, 0x12, 0x09, 0x0a, 0x05, 0x46, 0x41, 0x54, 0x41, 0x4c,
|
0x6e, 0x12, 0x10, 0x0a, 0x03, 0x61, 0x63, 0x6b, 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, 0x03,
|
||||||
0x10, 0x02, 0x12, 0x09, 0x0a, 0x05, 0x45, 0x52, 0x52, 0x4f, 0x52, 0x10, 0x03, 0x12, 0x08, 0x0a,
|
0x61, 0x63, 0x6b, 0x12, 0x10, 0x0a, 0x03, 0x66, 0x69, 0x6e, 0x18, 0x03, 0x20, 0x01, 0x28, 0x08,
|
||||||
0x04, 0x57, 0x41, 0x52, 0x4e, 0x10, 0x04, 0x12, 0x08, 0x0a, 0x04, 0x49, 0x4e, 0x46, 0x4f, 0x10,
|
0x52, 0x03, 0x66, 0x69, 0x6e, 0x12, 0x10, 0x0a, 0x03, 0x72, 0x73, 0x74, 0x18, 0x04, 0x20, 0x01,
|
||||||
0x05, 0x12, 0x09, 0x0a, 0x05, 0x44, 0x45, 0x42, 0x55, 0x47, 0x10, 0x06, 0x12, 0x09, 0x0a, 0x05,
|
0x28, 0x08, 0x52, 0x03, 0x72, 0x73, 0x74, 0x12, 0x10, 0x0a, 0x03, 0x70, 0x73, 0x68, 0x18, 0x05,
|
||||||
0x54, 0x52, 0x41, 0x43, 0x45, 0x10, 0x07, 0x32, 0x93, 0x09, 0x0a, 0x0d, 0x44, 0x61, 0x65, 0x6d,
|
0x20, 0x01, 0x28, 0x08, 0x52, 0x03, 0x70, 0x73, 0x68, 0x12, 0x10, 0x0a, 0x03, 0x75, 0x72, 0x67,
|
||||||
0x6f, 0x6e, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x36, 0x0a, 0x05, 0x4c, 0x6f, 0x67,
|
0x18, 0x06, 0x20, 0x01, 0x28, 0x08, 0x52, 0x03, 0x75, 0x72, 0x67, 0x22, 0x80, 0x03, 0x0a, 0x12,
|
||||||
0x69, 0x6e, 0x12, 0x14, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x6f, 0x67, 0x69,
|
0x54, 0x72, 0x61, 0x63, 0x65, 0x50, 0x61, 0x63, 0x6b, 0x65, 0x74, 0x52, 0x65, 0x71, 0x75, 0x65,
|
||||||
0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x15, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f,
|
0x73, 0x74, 0x12, 0x1b, 0x0a, 0x09, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x5f, 0x69, 0x70, 0x18,
|
||||||
0x6e, 0x2e, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22,
|
0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x49, 0x70, 0x12,
|
||||||
0x00, 0x12, 0x4b, 0x0a, 0x0c, 0x57, 0x61, 0x69, 0x74, 0x53, 0x53, 0x4f, 0x4c, 0x6f, 0x67, 0x69,
|
0x25, 0x0a, 0x0e, 0x64, 0x65, 0x73, 0x74, 0x69, 0x6e, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x5f, 0x69,
|
||||||
0x6e, 0x12, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x57, 0x61, 0x69, 0x74, 0x53,
|
0x70, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0d, 0x64, 0x65, 0x73, 0x74, 0x69, 0x6e, 0x61,
|
||||||
0x53, 0x4f, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1c,
|
0x74, 0x69, 0x6f, 0x6e, 0x49, 0x70, 0x12, 0x1a, 0x0a, 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63,
|
||||||
|
0x6f, 0x6c, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63,
|
||||||
|
0x6f, 0x6c, 0x12, 0x1f, 0x0a, 0x0b, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x5f, 0x70, 0x6f, 0x72,
|
||||||
|
0x74, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x0a, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x50,
|
||||||
|
0x6f, 0x72, 0x74, 0x12, 0x29, 0x0a, 0x10, 0x64, 0x65, 0x73, 0x74, 0x69, 0x6e, 0x61, 0x74, 0x69,
|
||||||
|
0x6f, 0x6e, 0x5f, 0x70, 0x6f, 0x72, 0x74, 0x18, 0x05, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x0f, 0x64,
|
||||||
|
0x65, 0x73, 0x74, 0x69, 0x6e, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x50, 0x6f, 0x72, 0x74, 0x12, 0x1c,
|
||||||
|
0x0a, 0x09, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x06, 0x20, 0x01, 0x28,
|
||||||
|
0x09, 0x52, 0x09, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x32, 0x0a, 0x09,
|
||||||
|
0x74, 0x63, 0x70, 0x5f, 0x66, 0x6c, 0x61, 0x67, 0x73, 0x18, 0x07, 0x20, 0x01, 0x28, 0x0b, 0x32,
|
||||||
|
0x10, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x54, 0x43, 0x50, 0x46, 0x6c, 0x61, 0x67,
|
||||||
|
0x73, 0x48, 0x00, 0x52, 0x08, 0x74, 0x63, 0x70, 0x46, 0x6c, 0x61, 0x67, 0x73, 0x88, 0x01, 0x01,
|
||||||
|
0x12, 0x20, 0x0a, 0x09, 0x69, 0x63, 0x6d, 0x70, 0x5f, 0x74, 0x79, 0x70, 0x65, 0x18, 0x08, 0x20,
|
||||||
|
0x01, 0x28, 0x0d, 0x48, 0x01, 0x52, 0x08, 0x69, 0x63, 0x6d, 0x70, 0x54, 0x79, 0x70, 0x65, 0x88,
|
||||||
|
0x01, 0x01, 0x12, 0x20, 0x0a, 0x09, 0x69, 0x63, 0x6d, 0x70, 0x5f, 0x63, 0x6f, 0x64, 0x65, 0x18,
|
||||||
|
0x09, 0x20, 0x01, 0x28, 0x0d, 0x48, 0x02, 0x52, 0x08, 0x69, 0x63, 0x6d, 0x70, 0x43, 0x6f, 0x64,
|
||||||
|
0x65, 0x88, 0x01, 0x01, 0x42, 0x0c, 0x0a, 0x0a, 0x5f, 0x74, 0x63, 0x70, 0x5f, 0x66, 0x6c, 0x61,
|
||||||
|
0x67, 0x73, 0x42, 0x0c, 0x0a, 0x0a, 0x5f, 0x69, 0x63, 0x6d, 0x70, 0x5f, 0x74, 0x79, 0x70, 0x65,
|
||||||
|
0x42, 0x0c, 0x0a, 0x0a, 0x5f, 0x69, 0x63, 0x6d, 0x70, 0x5f, 0x63, 0x6f, 0x64, 0x65, 0x22, 0x9f,
|
||||||
|
0x01, 0x0a, 0x0a, 0x54, 0x72, 0x61, 0x63, 0x65, 0x53, 0x74, 0x61, 0x67, 0x65, 0x12, 0x12, 0x0a,
|
||||||
|
0x04, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x6e, 0x61, 0x6d,
|
||||||
|
0x65, 0x12, 0x18, 0x0a, 0x07, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x18, 0x02, 0x20, 0x01,
|
||||||
|
0x28, 0x09, 0x52, 0x07, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x12, 0x18, 0x0a, 0x07, 0x61,
|
||||||
|
0x6c, 0x6c, 0x6f, 0x77, 0x65, 0x64, 0x18, 0x03, 0x20, 0x01, 0x28, 0x08, 0x52, 0x07, 0x61, 0x6c,
|
||||||
|
0x6c, 0x6f, 0x77, 0x65, 0x64, 0x12, 0x32, 0x0a, 0x12, 0x66, 0x6f, 0x72, 0x77, 0x61, 0x72, 0x64,
|
||||||
|
0x69, 0x6e, 0x67, 0x5f, 0x64, 0x65, 0x74, 0x61, 0x69, 0x6c, 0x73, 0x18, 0x04, 0x20, 0x01, 0x28,
|
||||||
|
0x09, 0x48, 0x00, 0x52, 0x11, 0x66, 0x6f, 0x72, 0x77, 0x61, 0x72, 0x64, 0x69, 0x6e, 0x67, 0x44,
|
||||||
|
0x65, 0x74, 0x61, 0x69, 0x6c, 0x73, 0x88, 0x01, 0x01, 0x42, 0x15, 0x0a, 0x13, 0x5f, 0x66, 0x6f,
|
||||||
|
0x72, 0x77, 0x61, 0x72, 0x64, 0x69, 0x6e, 0x67, 0x5f, 0x64, 0x65, 0x74, 0x61, 0x69, 0x6c, 0x73,
|
||||||
|
0x22, 0x6e, 0x0a, 0x13, 0x54, 0x72, 0x61, 0x63, 0x65, 0x50, 0x61, 0x63, 0x6b, 0x65, 0x74, 0x52,
|
||||||
|
0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x2a, 0x0a, 0x06, 0x73, 0x74, 0x61, 0x67, 0x65,
|
||||||
|
0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x12, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e,
|
||||||
|
0x2e, 0x54, 0x72, 0x61, 0x63, 0x65, 0x53, 0x74, 0x61, 0x67, 0x65, 0x52, 0x06, 0x73, 0x74, 0x61,
|
||||||
|
0x67, 0x65, 0x73, 0x12, 0x2b, 0x0a, 0x11, 0x66, 0x69, 0x6e, 0x61, 0x6c, 0x5f, 0x64, 0x69, 0x73,
|
||||||
|
0x70, 0x6f, 0x73, 0x69, 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, 0x10,
|
||||||
|
0x66, 0x69, 0x6e, 0x61, 0x6c, 0x44, 0x69, 0x73, 0x70, 0x6f, 0x73, 0x69, 0x74, 0x69, 0x6f, 0x6e,
|
||||||
|
0x2a, 0x62, 0x0a, 0x08, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x12, 0x0b, 0x0a, 0x07,
|
||||||
|
0x55, 0x4e, 0x4b, 0x4e, 0x4f, 0x57, 0x4e, 0x10, 0x00, 0x12, 0x09, 0x0a, 0x05, 0x50, 0x41, 0x4e,
|
||||||
|
0x49, 0x43, 0x10, 0x01, 0x12, 0x09, 0x0a, 0x05, 0x46, 0x41, 0x54, 0x41, 0x4c, 0x10, 0x02, 0x12,
|
||||||
|
0x09, 0x0a, 0x05, 0x45, 0x52, 0x52, 0x4f, 0x52, 0x10, 0x03, 0x12, 0x08, 0x0a, 0x04, 0x57, 0x41,
|
||||||
|
0x52, 0x4e, 0x10, 0x04, 0x12, 0x08, 0x0a, 0x04, 0x49, 0x4e, 0x46, 0x4f, 0x10, 0x05, 0x12, 0x09,
|
||||||
|
0x0a, 0x05, 0x44, 0x45, 0x42, 0x55, 0x47, 0x10, 0x06, 0x12, 0x09, 0x0a, 0x05, 0x54, 0x52, 0x41,
|
||||||
|
0x43, 0x45, 0x10, 0x07, 0x32, 0xdd, 0x09, 0x0a, 0x0d, 0x44, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x53,
|
||||||
|
0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x36, 0x0a, 0x05, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x12,
|
||||||
|
0x14, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65,
|
||||||
|
0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x15, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c,
|
||||||
|
0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x4b,
|
||||||
|
0x0a, 0x0c, 0x57, 0x61, 0x69, 0x74, 0x53, 0x53, 0x4f, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x12, 0x1b,
|
||||||
0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x57, 0x61, 0x69, 0x74, 0x53, 0x53, 0x4f, 0x4c,
|
0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x57, 0x61, 0x69, 0x74, 0x53, 0x53, 0x4f, 0x4c,
|
||||||
0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x2d,
|
0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1c, 0x2e, 0x64, 0x61,
|
||||||
0x0a, 0x02, 0x55, 0x70, 0x12, 0x11, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x55, 0x70,
|
0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x57, 0x61, 0x69, 0x74, 0x53, 0x53, 0x4f, 0x4c, 0x6f, 0x67, 0x69,
|
||||||
0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x12, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e,
|
0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x2d, 0x0a, 0x02, 0x55,
|
||||||
0x2e, 0x55, 0x70, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x39, 0x0a,
|
0x70, 0x12, 0x11, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x55, 0x70, 0x52, 0x65, 0x71,
|
||||||
0x06, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x12, 0x15, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e,
|
0x75, 0x65, 0x73, 0x74, 0x1a, 0x12, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x55, 0x70,
|
||||||
0x2e, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x16,
|
0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x39, 0x0a, 0x06, 0x53, 0x74,
|
||||||
0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x52, 0x65,
|
0x61, 0x74, 0x75, 0x73, 0x12, 0x15, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x74,
|
||||||
0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x33, 0x0a, 0x04, 0x44, 0x6f, 0x77, 0x6e,
|
0x61, 0x74, 0x75, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x16, 0x2e, 0x64, 0x61,
|
||||||
0x12, 0x13, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x6f, 0x77, 0x6e, 0x52, 0x65,
|
0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f,
|
||||||
0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x14, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44,
|
0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x33, 0x0a, 0x04, 0x44, 0x6f, 0x77, 0x6e, 0x12, 0x13, 0x2e,
|
||||||
0x6f, 0x77, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x42, 0x0a,
|
0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x6f, 0x77, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65,
|
||||||
0x09, 0x47, 0x65, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x18, 0x2e, 0x64, 0x61, 0x65,
|
0x73, 0x74, 0x1a, 0x14, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x6f, 0x77, 0x6e,
|
||||||
0x6d, 0x6f, 0x6e, 0x2e, 0x47, 0x65, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x65, 0x71,
|
0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x42, 0x0a, 0x09, 0x47, 0x65,
|
||||||
0x75, 0x65, 0x73, 0x74, 0x1a, 0x19, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x47, 0x65,
|
0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x18, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e,
|
||||||
0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22,
|
0x2e, 0x47, 0x65, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73,
|
||||||
0x00, 0x12, 0x4b, 0x0a, 0x0c, 0x4c, 0x69, 0x73, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b,
|
0x74, 0x1a, 0x19, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x47, 0x65, 0x74, 0x43, 0x6f,
|
||||||
0x73, 0x12, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x69, 0x73, 0x74, 0x4e,
|
0x6e, 0x66, 0x69, 0x67, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x4b,
|
||||||
0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1c,
|
0x0a, 0x0c, 0x4c, 0x69, 0x73, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x73, 0x12, 0x1b,
|
||||||
0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x69, 0x73, 0x74, 0x4e, 0x65, 0x74, 0x77,
|
0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x69, 0x73, 0x74, 0x4e, 0x65, 0x74, 0x77,
|
||||||
0x6f, 0x72, 0x6b, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x51,
|
0x6f, 0x72, 0x6b, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1c, 0x2e, 0x64, 0x61,
|
||||||
0x0a, 0x0e, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x73,
|
0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x69, 0x73, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b,
|
||||||
0x12, 0x1d, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74,
|
0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x51, 0x0a, 0x0e, 0x53,
|
||||||
0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a,
|
0x65, 0x6c, 0x65, 0x63, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x73, 0x12, 0x1d, 0x2e,
|
||||||
0x1e, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x4e,
|
0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x4e, 0x65, 0x74,
|
||||||
0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22,
|
0x77, 0x6f, 0x72, 0x6b, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1e, 0x2e, 0x64,
|
||||||
0x00, 0x12, 0x53, 0x0a, 0x10, 0x44, 0x65, 0x73, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x4e, 0x65, 0x74,
|
0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x4e, 0x65, 0x74, 0x77,
|
||||||
0x77, 0x6f, 0x72, 0x6b, 0x73, 0x12, 0x1d, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53,
|
0x6f, 0x72, 0x6b, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x53,
|
||||||
0x65, 0x6c, 0x65, 0x63, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x73, 0x52, 0x65, 0x71,
|
0x0a, 0x10, 0x44, 0x65, 0x73, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72,
|
||||||
0x75, 0x65, 0x73, 0x74, 0x1a, 0x1e, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65,
|
0x6b, 0x73, 0x12, 0x1d, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x6c, 0x65,
|
||||||
0x6c, 0x65, 0x63, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x73, 0x52, 0x65, 0x73, 0x70,
|
0x63, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73,
|
||||||
0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x48, 0x0a, 0x0b, 0x44, 0x65, 0x62, 0x75, 0x67, 0x42,
|
0x74, 0x1a, 0x1e, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x6c, 0x65, 0x63,
|
||||||
0x75, 0x6e, 0x64, 0x6c, 0x65, 0x12, 0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44,
|
0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73,
|
||||||
0x65, 0x62, 0x75, 0x67, 0x42, 0x75, 0x6e, 0x64, 0x6c, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73,
|
0x65, 0x22, 0x00, 0x12, 0x48, 0x0a, 0x0b, 0x44, 0x65, 0x62, 0x75, 0x67, 0x42, 0x75, 0x6e, 0x64,
|
||||||
0x74, 0x1a, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x65, 0x62, 0x75, 0x67,
|
0x6c, 0x65, 0x12, 0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x65, 0x62, 0x75,
|
||||||
0x42, 0x75, 0x6e, 0x64, 0x6c, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00,
|
0x67, 0x42, 0x75, 0x6e, 0x64, 0x6c, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1b,
|
||||||
0x12, 0x48, 0x0a, 0x0b, 0x47, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x12,
|
0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x65, 0x62, 0x75, 0x67, 0x42, 0x75, 0x6e,
|
||||||
0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x47, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c,
|
0x64, 0x6c, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x48, 0x0a,
|
||||||
0x65, 0x76, 0x65, 0x6c, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1b, 0x2e, 0x64, 0x61,
|
0x0b, 0x47, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x12, 0x1a, 0x2e, 0x64,
|
||||||
0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x47, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c,
|
0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x47, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65,
|
||||||
0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x48, 0x0a, 0x0b, 0x53, 0x65,
|
0x6c, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f,
|
||||||
0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x12, 0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d,
|
0x6e, 0x2e, 0x47, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65, 0x73,
|
||||||
0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65,
|
0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x48, 0x0a, 0x0b, 0x53, 0x65, 0x74, 0x4c, 0x6f,
|
||||||
0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53,
|
0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x12, 0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e,
|
||||||
0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e,
|
0x53, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65, 0x71, 0x75, 0x65,
|
||||||
0x73, 0x65, 0x22, 0x00, 0x12, 0x45, 0x0a, 0x0a, 0x4c, 0x69, 0x73, 0x74, 0x53, 0x74, 0x61, 0x74,
|
0x73, 0x74, 0x1a, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x74, 0x4c,
|
||||||
0x65, 0x73, 0x12, 0x19, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x69, 0x73, 0x74,
|
0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22,
|
||||||
0x53, 0x74, 0x61, 0x74, 0x65, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1a, 0x2e,
|
0x00, 0x12, 0x45, 0x0a, 0x0a, 0x4c, 0x69, 0x73, 0x74, 0x53, 0x74, 0x61, 0x74, 0x65, 0x73, 0x12,
|
||||||
0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x69, 0x73, 0x74, 0x53, 0x74, 0x61, 0x74, 0x65,
|
0x19, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x69, 0x73, 0x74, 0x53, 0x74, 0x61,
|
||||||
0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x45, 0x0a, 0x0a, 0x43,
|
0x74, 0x65, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1a, 0x2e, 0x64, 0x61, 0x65,
|
||||||
0x6c, 0x65, 0x61, 0x6e, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x19, 0x2e, 0x64, 0x61, 0x65, 0x6d,
|
0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x69, 0x73, 0x74, 0x53, 0x74, 0x61, 0x74, 0x65, 0x73, 0x52, 0x65,
|
||||||
0x6f, 0x6e, 0x2e, 0x43, 0x6c, 0x65, 0x61, 0x6e, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x65, 0x71,
|
0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x45, 0x0a, 0x0a, 0x43, 0x6c, 0x65, 0x61,
|
||||||
0x75, 0x65, 0x73, 0x74, 0x1a, 0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x43, 0x6c,
|
0x6e, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x19, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e,
|
||||||
0x65, 0x61, 0x6e, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65,
|
0x43, 0x6c, 0x65, 0x61, 0x6e, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73,
|
||||||
0x22, 0x00, 0x12, 0x48, 0x0a, 0x0b, 0x44, 0x65, 0x6c, 0x65, 0x74, 0x65, 0x53, 0x74, 0x61, 0x74,
|
0x74, 0x1a, 0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x43, 0x6c, 0x65, 0x61, 0x6e,
|
||||||
0x65, 0x12, 0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x65, 0x6c, 0x65, 0x74,
|
0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12,
|
||||||
0x65, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1b, 0x2e,
|
0x48, 0x0a, 0x0b, 0x44, 0x65, 0x6c, 0x65, 0x74, 0x65, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x1a,
|
||||||
0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x65, 0x6c, 0x65, 0x74, 0x65, 0x53, 0x74, 0x61,
|
0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x65, 0x6c, 0x65, 0x74, 0x65, 0x53, 0x74,
|
||||||
0x74, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x6f, 0x0a, 0x18,
|
0x61, 0x74, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1b, 0x2e, 0x64, 0x61, 0x65,
|
||||||
0x53, 0x65, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x4d, 0x61, 0x70, 0x50, 0x65, 0x72,
|
0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x65, 0x6c, 0x65, 0x74, 0x65, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52,
|
||||||
0x73, 0x69, 0x73, 0x74, 0x65, 0x6e, 0x63, 0x65, 0x12, 0x27, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f,
|
0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x6f, 0x0a, 0x18, 0x53, 0x65, 0x74,
|
||||||
0x6e, 0x2e, 0x53, 0x65, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x4d, 0x61, 0x70, 0x50,
|
0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x4d, 0x61, 0x70, 0x50, 0x65, 0x72, 0x73, 0x69, 0x73,
|
||||||
0x65, 0x72, 0x73, 0x69, 0x73, 0x74, 0x65, 0x6e, 0x63, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73,
|
0x74, 0x65, 0x6e, 0x63, 0x65, 0x12, 0x27, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53,
|
||||||
0x74, 0x1a, 0x28, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x74, 0x4e, 0x65,
|
0x65, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x4d, 0x61, 0x70, 0x50, 0x65, 0x72, 0x73,
|
||||||
0x74, 0x77, 0x6f, 0x72, 0x6b, 0x4d, 0x61, 0x70, 0x50, 0x65, 0x72, 0x73, 0x69, 0x73, 0x74, 0x65,
|
0x69, 0x73, 0x74, 0x65, 0x6e, 0x63, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x28,
|
||||||
0x6e, 0x63, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x42, 0x08, 0x5a,
|
0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f,
|
||||||
0x06, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
|
0x72, 0x6b, 0x4d, 0x61, 0x70, 0x50, 0x65, 0x72, 0x73, 0x69, 0x73, 0x74, 0x65, 0x6e, 0x63, 0x65,
|
||||||
|
0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x48, 0x0a, 0x0b, 0x54, 0x72,
|
||||||
|
0x61, 0x63, 0x65, 0x50, 0x61, 0x63, 0x6b, 0x65, 0x74, 0x12, 0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d,
|
||||||
|
0x6f, 0x6e, 0x2e, 0x54, 0x72, 0x61, 0x63, 0x65, 0x50, 0x61, 0x63, 0x6b, 0x65, 0x74, 0x52, 0x65,
|
||||||
|
0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x54,
|
||||||
|
0x72, 0x61, 0x63, 0x65, 0x50, 0x61, 0x63, 0x6b, 0x65, 0x74, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e,
|
||||||
|
0x73, 0x65, 0x22, 0x00, 0x42, 0x08, 0x5a, 0x06, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x06,
|
||||||
|
0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@@ -3016,7 +3394,7 @@ func file_daemon_proto_rawDescGZIP() []byte {
|
|||||||
}
|
}
|
||||||
|
|
||||||
var file_daemon_proto_enumTypes = make([]protoimpl.EnumInfo, 1)
|
var file_daemon_proto_enumTypes = make([]protoimpl.EnumInfo, 1)
|
||||||
var file_daemon_proto_msgTypes = make([]protoimpl.MessageInfo, 41)
|
var file_daemon_proto_msgTypes = make([]protoimpl.MessageInfo, 45)
|
||||||
var file_daemon_proto_goTypes = []interface{}{
|
var file_daemon_proto_goTypes = []interface{}{
|
||||||
(LogLevel)(0), // 0: daemon.LogLevel
|
(LogLevel)(0), // 0: daemon.LogLevel
|
||||||
(*LoginRequest)(nil), // 1: daemon.LoginRequest
|
(*LoginRequest)(nil), // 1: daemon.LoginRequest
|
||||||
@@ -3059,16 +3437,20 @@ var file_daemon_proto_goTypes = []interface{}{
|
|||||||
(*DeleteStateResponse)(nil), // 38: daemon.DeleteStateResponse
|
(*DeleteStateResponse)(nil), // 38: daemon.DeleteStateResponse
|
||||||
(*SetNetworkMapPersistenceRequest)(nil), // 39: daemon.SetNetworkMapPersistenceRequest
|
(*SetNetworkMapPersistenceRequest)(nil), // 39: daemon.SetNetworkMapPersistenceRequest
|
||||||
(*SetNetworkMapPersistenceResponse)(nil), // 40: daemon.SetNetworkMapPersistenceResponse
|
(*SetNetworkMapPersistenceResponse)(nil), // 40: daemon.SetNetworkMapPersistenceResponse
|
||||||
nil, // 41: daemon.Network.ResolvedIPsEntry
|
(*TCPFlags)(nil), // 41: daemon.TCPFlags
|
||||||
(*durationpb.Duration)(nil), // 42: google.protobuf.Duration
|
(*TracePacketRequest)(nil), // 42: daemon.TracePacketRequest
|
||||||
(*timestamppb.Timestamp)(nil), // 43: google.protobuf.Timestamp
|
(*TraceStage)(nil), // 43: daemon.TraceStage
|
||||||
|
(*TracePacketResponse)(nil), // 44: daemon.TracePacketResponse
|
||||||
|
nil, // 45: daemon.Network.ResolvedIPsEntry
|
||||||
|
(*durationpb.Duration)(nil), // 46: google.protobuf.Duration
|
||||||
|
(*timestamppb.Timestamp)(nil), // 47: google.protobuf.Timestamp
|
||||||
}
|
}
|
||||||
var file_daemon_proto_depIdxs = []int32{
|
var file_daemon_proto_depIdxs = []int32{
|
||||||
42, // 0: daemon.LoginRequest.dnsRouteInterval:type_name -> google.protobuf.Duration
|
46, // 0: daemon.LoginRequest.dnsRouteInterval:type_name -> google.protobuf.Duration
|
||||||
19, // 1: daemon.StatusResponse.fullStatus:type_name -> daemon.FullStatus
|
19, // 1: daemon.StatusResponse.fullStatus:type_name -> daemon.FullStatus
|
||||||
43, // 2: daemon.PeerState.connStatusUpdate:type_name -> google.protobuf.Timestamp
|
47, // 2: daemon.PeerState.connStatusUpdate:type_name -> google.protobuf.Timestamp
|
||||||
43, // 3: daemon.PeerState.lastWireguardHandshake:type_name -> google.protobuf.Timestamp
|
47, // 3: daemon.PeerState.lastWireguardHandshake:type_name -> google.protobuf.Timestamp
|
||||||
42, // 4: daemon.PeerState.latency:type_name -> google.protobuf.Duration
|
46, // 4: daemon.PeerState.latency:type_name -> google.protobuf.Duration
|
||||||
16, // 5: daemon.FullStatus.managementState:type_name -> daemon.ManagementState
|
16, // 5: daemon.FullStatus.managementState:type_name -> daemon.ManagementState
|
||||||
15, // 6: daemon.FullStatus.signalState:type_name -> daemon.SignalState
|
15, // 6: daemon.FullStatus.signalState:type_name -> daemon.SignalState
|
||||||
14, // 7: daemon.FullStatus.localPeerState:type_name -> daemon.LocalPeerState
|
14, // 7: daemon.FullStatus.localPeerState:type_name -> daemon.LocalPeerState
|
||||||
@@ -3076,48 +3458,52 @@ var file_daemon_proto_depIdxs = []int32{
|
|||||||
17, // 9: daemon.FullStatus.relays:type_name -> daemon.RelayState
|
17, // 9: daemon.FullStatus.relays:type_name -> daemon.RelayState
|
||||||
18, // 10: daemon.FullStatus.dns_servers:type_name -> daemon.NSGroupState
|
18, // 10: daemon.FullStatus.dns_servers:type_name -> daemon.NSGroupState
|
||||||
25, // 11: daemon.ListNetworksResponse.routes:type_name -> daemon.Network
|
25, // 11: daemon.ListNetworksResponse.routes:type_name -> daemon.Network
|
||||||
41, // 12: daemon.Network.resolvedIPs:type_name -> daemon.Network.ResolvedIPsEntry
|
45, // 12: daemon.Network.resolvedIPs:type_name -> daemon.Network.ResolvedIPsEntry
|
||||||
0, // 13: daemon.GetLogLevelResponse.level:type_name -> daemon.LogLevel
|
0, // 13: daemon.GetLogLevelResponse.level:type_name -> daemon.LogLevel
|
||||||
0, // 14: daemon.SetLogLevelRequest.level:type_name -> daemon.LogLevel
|
0, // 14: daemon.SetLogLevelRequest.level:type_name -> daemon.LogLevel
|
||||||
32, // 15: daemon.ListStatesResponse.states:type_name -> daemon.State
|
32, // 15: daemon.ListStatesResponse.states:type_name -> daemon.State
|
||||||
24, // 16: daemon.Network.ResolvedIPsEntry.value:type_name -> daemon.IPList
|
41, // 16: daemon.TracePacketRequest.tcp_flags:type_name -> daemon.TCPFlags
|
||||||
1, // 17: daemon.DaemonService.Login:input_type -> daemon.LoginRequest
|
43, // 17: daemon.TracePacketResponse.stages:type_name -> daemon.TraceStage
|
||||||
3, // 18: daemon.DaemonService.WaitSSOLogin:input_type -> daemon.WaitSSOLoginRequest
|
24, // 18: daemon.Network.ResolvedIPsEntry.value:type_name -> daemon.IPList
|
||||||
5, // 19: daemon.DaemonService.Up:input_type -> daemon.UpRequest
|
1, // 19: daemon.DaemonService.Login:input_type -> daemon.LoginRequest
|
||||||
7, // 20: daemon.DaemonService.Status:input_type -> daemon.StatusRequest
|
3, // 20: daemon.DaemonService.WaitSSOLogin:input_type -> daemon.WaitSSOLoginRequest
|
||||||
9, // 21: daemon.DaemonService.Down:input_type -> daemon.DownRequest
|
5, // 21: daemon.DaemonService.Up:input_type -> daemon.UpRequest
|
||||||
11, // 22: daemon.DaemonService.GetConfig:input_type -> daemon.GetConfigRequest
|
7, // 22: daemon.DaemonService.Status:input_type -> daemon.StatusRequest
|
||||||
20, // 23: daemon.DaemonService.ListNetworks:input_type -> daemon.ListNetworksRequest
|
9, // 23: daemon.DaemonService.Down:input_type -> daemon.DownRequest
|
||||||
22, // 24: daemon.DaemonService.SelectNetworks:input_type -> daemon.SelectNetworksRequest
|
11, // 24: daemon.DaemonService.GetConfig:input_type -> daemon.GetConfigRequest
|
||||||
22, // 25: daemon.DaemonService.DeselectNetworks:input_type -> daemon.SelectNetworksRequest
|
20, // 25: daemon.DaemonService.ListNetworks:input_type -> daemon.ListNetworksRequest
|
||||||
26, // 26: daemon.DaemonService.DebugBundle:input_type -> daemon.DebugBundleRequest
|
22, // 26: daemon.DaemonService.SelectNetworks:input_type -> daemon.SelectNetworksRequest
|
||||||
28, // 27: daemon.DaemonService.GetLogLevel:input_type -> daemon.GetLogLevelRequest
|
22, // 27: daemon.DaemonService.DeselectNetworks:input_type -> daemon.SelectNetworksRequest
|
||||||
30, // 28: daemon.DaemonService.SetLogLevel:input_type -> daemon.SetLogLevelRequest
|
26, // 28: daemon.DaemonService.DebugBundle:input_type -> daemon.DebugBundleRequest
|
||||||
33, // 29: daemon.DaemonService.ListStates:input_type -> daemon.ListStatesRequest
|
28, // 29: daemon.DaemonService.GetLogLevel:input_type -> daemon.GetLogLevelRequest
|
||||||
35, // 30: daemon.DaemonService.CleanState:input_type -> daemon.CleanStateRequest
|
30, // 30: daemon.DaemonService.SetLogLevel:input_type -> daemon.SetLogLevelRequest
|
||||||
37, // 31: daemon.DaemonService.DeleteState:input_type -> daemon.DeleteStateRequest
|
33, // 31: daemon.DaemonService.ListStates:input_type -> daemon.ListStatesRequest
|
||||||
39, // 32: daemon.DaemonService.SetNetworkMapPersistence:input_type -> daemon.SetNetworkMapPersistenceRequest
|
35, // 32: daemon.DaemonService.CleanState:input_type -> daemon.CleanStateRequest
|
||||||
2, // 33: daemon.DaemonService.Login:output_type -> daemon.LoginResponse
|
37, // 33: daemon.DaemonService.DeleteState:input_type -> daemon.DeleteStateRequest
|
||||||
4, // 34: daemon.DaemonService.WaitSSOLogin:output_type -> daemon.WaitSSOLoginResponse
|
39, // 34: daemon.DaemonService.SetNetworkMapPersistence:input_type -> daemon.SetNetworkMapPersistenceRequest
|
||||||
6, // 35: daemon.DaemonService.Up:output_type -> daemon.UpResponse
|
42, // 35: daemon.DaemonService.TracePacket:input_type -> daemon.TracePacketRequest
|
||||||
8, // 36: daemon.DaemonService.Status:output_type -> daemon.StatusResponse
|
2, // 36: daemon.DaemonService.Login:output_type -> daemon.LoginResponse
|
||||||
10, // 37: daemon.DaemonService.Down:output_type -> daemon.DownResponse
|
4, // 37: daemon.DaemonService.WaitSSOLogin:output_type -> daemon.WaitSSOLoginResponse
|
||||||
12, // 38: daemon.DaemonService.GetConfig:output_type -> daemon.GetConfigResponse
|
6, // 38: daemon.DaemonService.Up:output_type -> daemon.UpResponse
|
||||||
21, // 39: daemon.DaemonService.ListNetworks:output_type -> daemon.ListNetworksResponse
|
8, // 39: daemon.DaemonService.Status:output_type -> daemon.StatusResponse
|
||||||
23, // 40: daemon.DaemonService.SelectNetworks:output_type -> daemon.SelectNetworksResponse
|
10, // 40: daemon.DaemonService.Down:output_type -> daemon.DownResponse
|
||||||
23, // 41: daemon.DaemonService.DeselectNetworks:output_type -> daemon.SelectNetworksResponse
|
12, // 41: daemon.DaemonService.GetConfig:output_type -> daemon.GetConfigResponse
|
||||||
27, // 42: daemon.DaemonService.DebugBundle:output_type -> daemon.DebugBundleResponse
|
21, // 42: daemon.DaemonService.ListNetworks:output_type -> daemon.ListNetworksResponse
|
||||||
29, // 43: daemon.DaemonService.GetLogLevel:output_type -> daemon.GetLogLevelResponse
|
23, // 43: daemon.DaemonService.SelectNetworks:output_type -> daemon.SelectNetworksResponse
|
||||||
31, // 44: daemon.DaemonService.SetLogLevel:output_type -> daemon.SetLogLevelResponse
|
23, // 44: daemon.DaemonService.DeselectNetworks:output_type -> daemon.SelectNetworksResponse
|
||||||
34, // 45: daemon.DaemonService.ListStates:output_type -> daemon.ListStatesResponse
|
27, // 45: daemon.DaemonService.DebugBundle:output_type -> daemon.DebugBundleResponse
|
||||||
36, // 46: daemon.DaemonService.CleanState:output_type -> daemon.CleanStateResponse
|
29, // 46: daemon.DaemonService.GetLogLevel:output_type -> daemon.GetLogLevelResponse
|
||||||
38, // 47: daemon.DaemonService.DeleteState:output_type -> daemon.DeleteStateResponse
|
31, // 47: daemon.DaemonService.SetLogLevel:output_type -> daemon.SetLogLevelResponse
|
||||||
40, // 48: daemon.DaemonService.SetNetworkMapPersistence:output_type -> daemon.SetNetworkMapPersistenceResponse
|
34, // 48: daemon.DaemonService.ListStates:output_type -> daemon.ListStatesResponse
|
||||||
33, // [33:49] is the sub-list for method output_type
|
36, // 49: daemon.DaemonService.CleanState:output_type -> daemon.CleanStateResponse
|
||||||
17, // [17:33] is the sub-list for method input_type
|
38, // 50: daemon.DaemonService.DeleteState:output_type -> daemon.DeleteStateResponse
|
||||||
17, // [17:17] is the sub-list for extension type_name
|
40, // 51: daemon.DaemonService.SetNetworkMapPersistence:output_type -> daemon.SetNetworkMapPersistenceResponse
|
||||||
17, // [17:17] is the sub-list for extension extendee
|
44, // 52: daemon.DaemonService.TracePacket:output_type -> daemon.TracePacketResponse
|
||||||
0, // [0:17] is the sub-list for field type_name
|
36, // [36:53] is the sub-list for method output_type
|
||||||
|
19, // [19:36] is the sub-list for method input_type
|
||||||
|
19, // [19:19] is the sub-list for extension type_name
|
||||||
|
19, // [19:19] is the sub-list for extension extendee
|
||||||
|
0, // [0:19] is the sub-list for field type_name
|
||||||
}
|
}
|
||||||
|
|
||||||
func init() { file_daemon_proto_init() }
|
func init() { file_daemon_proto_init() }
|
||||||
@@ -3606,15 +3992,65 @@ func file_daemon_proto_init() {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
file_daemon_proto_msgTypes[40].Exporter = func(v interface{}, i int) interface{} {
|
||||||
|
switch v := v.(*TCPFlags); i {
|
||||||
|
case 0:
|
||||||
|
return &v.state
|
||||||
|
case 1:
|
||||||
|
return &v.sizeCache
|
||||||
|
case 2:
|
||||||
|
return &v.unknownFields
|
||||||
|
default:
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
file_daemon_proto_msgTypes[41].Exporter = func(v interface{}, i int) interface{} {
|
||||||
|
switch v := v.(*TracePacketRequest); i {
|
||||||
|
case 0:
|
||||||
|
return &v.state
|
||||||
|
case 1:
|
||||||
|
return &v.sizeCache
|
||||||
|
case 2:
|
||||||
|
return &v.unknownFields
|
||||||
|
default:
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
file_daemon_proto_msgTypes[42].Exporter = func(v interface{}, i int) interface{} {
|
||||||
|
switch v := v.(*TraceStage); i {
|
||||||
|
case 0:
|
||||||
|
return &v.state
|
||||||
|
case 1:
|
||||||
|
return &v.sizeCache
|
||||||
|
case 2:
|
||||||
|
return &v.unknownFields
|
||||||
|
default:
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
file_daemon_proto_msgTypes[43].Exporter = func(v interface{}, i int) interface{} {
|
||||||
|
switch v := v.(*TracePacketResponse); i {
|
||||||
|
case 0:
|
||||||
|
return &v.state
|
||||||
|
case 1:
|
||||||
|
return &v.sizeCache
|
||||||
|
case 2:
|
||||||
|
return &v.unknownFields
|
||||||
|
default:
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
file_daemon_proto_msgTypes[0].OneofWrappers = []interface{}{}
|
file_daemon_proto_msgTypes[0].OneofWrappers = []interface{}{}
|
||||||
|
file_daemon_proto_msgTypes[41].OneofWrappers = []interface{}{}
|
||||||
|
file_daemon_proto_msgTypes[42].OneofWrappers = []interface{}{}
|
||||||
type x struct{}
|
type x struct{}
|
||||||
out := protoimpl.TypeBuilder{
|
out := protoimpl.TypeBuilder{
|
||||||
File: protoimpl.DescBuilder{
|
File: protoimpl.DescBuilder{
|
||||||
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
|
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
|
||||||
RawDescriptor: file_daemon_proto_rawDesc,
|
RawDescriptor: file_daemon_proto_rawDesc,
|
||||||
NumEnums: 1,
|
NumEnums: 1,
|
||||||
NumMessages: 41,
|
NumMessages: 45,
|
||||||
NumExtensions: 0,
|
NumExtensions: 0,
|
||||||
NumServices: 1,
|
NumServices: 1,
|
||||||
},
|
},
|
||||||
|
|||||||
@@ -57,6 +57,8 @@ service DaemonService {
|
|||||||
|
|
||||||
// SetNetworkMapPersistence enables or disables network map persistence
|
// SetNetworkMapPersistence enables or disables network map persistence
|
||||||
rpc SetNetworkMapPersistence(SetNetworkMapPersistenceRequest) returns (SetNetworkMapPersistenceResponse) {}
|
rpc SetNetworkMapPersistence(SetNetworkMapPersistenceRequest) returns (SetNetworkMapPersistenceResponse) {}
|
||||||
|
|
||||||
|
rpc TracePacket(TracePacketRequest) returns (TracePacketResponse) {}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -356,3 +358,36 @@ message SetNetworkMapPersistenceRequest {
|
|||||||
}
|
}
|
||||||
|
|
||||||
message SetNetworkMapPersistenceResponse {}
|
message SetNetworkMapPersistenceResponse {}
|
||||||
|
|
||||||
|
message TCPFlags {
|
||||||
|
bool syn = 1;
|
||||||
|
bool ack = 2;
|
||||||
|
bool fin = 3;
|
||||||
|
bool rst = 4;
|
||||||
|
bool psh = 5;
|
||||||
|
bool urg = 6;
|
||||||
|
}
|
||||||
|
|
||||||
|
message TracePacketRequest {
|
||||||
|
string source_ip = 1;
|
||||||
|
string destination_ip = 2;
|
||||||
|
string protocol = 3;
|
||||||
|
uint32 source_port = 4;
|
||||||
|
uint32 destination_port = 5;
|
||||||
|
string direction = 6;
|
||||||
|
optional TCPFlags tcp_flags = 7;
|
||||||
|
optional uint32 icmp_type = 8;
|
||||||
|
optional uint32 icmp_code = 9;
|
||||||
|
}
|
||||||
|
|
||||||
|
message TraceStage {
|
||||||
|
string name = 1;
|
||||||
|
string message = 2;
|
||||||
|
bool allowed = 3;
|
||||||
|
optional string forwarding_details = 4;
|
||||||
|
}
|
||||||
|
|
||||||
|
message TracePacketResponse {
|
||||||
|
repeated TraceStage stages = 1;
|
||||||
|
bool final_disposition = 2;
|
||||||
|
}
|
||||||
|
|||||||
@@ -51,6 +51,7 @@ type DaemonServiceClient interface {
|
|||||||
DeleteState(ctx context.Context, in *DeleteStateRequest, opts ...grpc.CallOption) (*DeleteStateResponse, error)
|
DeleteState(ctx context.Context, in *DeleteStateRequest, opts ...grpc.CallOption) (*DeleteStateResponse, error)
|
||||||
// SetNetworkMapPersistence enables or disables network map persistence
|
// SetNetworkMapPersistence enables or disables network map persistence
|
||||||
SetNetworkMapPersistence(ctx context.Context, in *SetNetworkMapPersistenceRequest, opts ...grpc.CallOption) (*SetNetworkMapPersistenceResponse, error)
|
SetNetworkMapPersistence(ctx context.Context, in *SetNetworkMapPersistenceRequest, opts ...grpc.CallOption) (*SetNetworkMapPersistenceResponse, error)
|
||||||
|
TracePacket(ctx context.Context, in *TracePacketRequest, opts ...grpc.CallOption) (*TracePacketResponse, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type daemonServiceClient struct {
|
type daemonServiceClient struct {
|
||||||
@@ -205,6 +206,15 @@ func (c *daemonServiceClient) SetNetworkMapPersistence(ctx context.Context, in *
|
|||||||
return out, nil
|
return out, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *daemonServiceClient) TracePacket(ctx context.Context, in *TracePacketRequest, opts ...grpc.CallOption) (*TracePacketResponse, error) {
|
||||||
|
out := new(TracePacketResponse)
|
||||||
|
err := c.cc.Invoke(ctx, "/daemon.DaemonService/TracePacket", in, out, opts...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return out, nil
|
||||||
|
}
|
||||||
|
|
||||||
// DaemonServiceServer is the server API for DaemonService service.
|
// DaemonServiceServer is the server API for DaemonService service.
|
||||||
// All implementations must embed UnimplementedDaemonServiceServer
|
// All implementations must embed UnimplementedDaemonServiceServer
|
||||||
// for forward compatibility
|
// for forward compatibility
|
||||||
@@ -242,6 +252,7 @@ type DaemonServiceServer interface {
|
|||||||
DeleteState(context.Context, *DeleteStateRequest) (*DeleteStateResponse, error)
|
DeleteState(context.Context, *DeleteStateRequest) (*DeleteStateResponse, error)
|
||||||
// SetNetworkMapPersistence enables or disables network map persistence
|
// SetNetworkMapPersistence enables or disables network map persistence
|
||||||
SetNetworkMapPersistence(context.Context, *SetNetworkMapPersistenceRequest) (*SetNetworkMapPersistenceResponse, error)
|
SetNetworkMapPersistence(context.Context, *SetNetworkMapPersistenceRequest) (*SetNetworkMapPersistenceResponse, error)
|
||||||
|
TracePacket(context.Context, *TracePacketRequest) (*TracePacketResponse, error)
|
||||||
mustEmbedUnimplementedDaemonServiceServer()
|
mustEmbedUnimplementedDaemonServiceServer()
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -297,6 +308,9 @@ func (UnimplementedDaemonServiceServer) DeleteState(context.Context, *DeleteStat
|
|||||||
func (UnimplementedDaemonServiceServer) SetNetworkMapPersistence(context.Context, *SetNetworkMapPersistenceRequest) (*SetNetworkMapPersistenceResponse, error) {
|
func (UnimplementedDaemonServiceServer) SetNetworkMapPersistence(context.Context, *SetNetworkMapPersistenceRequest) (*SetNetworkMapPersistenceResponse, error) {
|
||||||
return nil, status.Errorf(codes.Unimplemented, "method SetNetworkMapPersistence not implemented")
|
return nil, status.Errorf(codes.Unimplemented, "method SetNetworkMapPersistence not implemented")
|
||||||
}
|
}
|
||||||
|
func (UnimplementedDaemonServiceServer) TracePacket(context.Context, *TracePacketRequest) (*TracePacketResponse, error) {
|
||||||
|
return nil, status.Errorf(codes.Unimplemented, "method TracePacket not implemented")
|
||||||
|
}
|
||||||
func (UnimplementedDaemonServiceServer) mustEmbedUnimplementedDaemonServiceServer() {}
|
func (UnimplementedDaemonServiceServer) mustEmbedUnimplementedDaemonServiceServer() {}
|
||||||
|
|
||||||
// UnsafeDaemonServiceServer may be embedded to opt out of forward compatibility for this service.
|
// UnsafeDaemonServiceServer may be embedded to opt out of forward compatibility for this service.
|
||||||
@@ -598,6 +612,24 @@ func _DaemonService_SetNetworkMapPersistence_Handler(srv interface{}, ctx contex
|
|||||||
return interceptor(ctx, in, info, handler)
|
return interceptor(ctx, in, info, handler)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func _DaemonService_TracePacket_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
|
||||||
|
in := new(TracePacketRequest)
|
||||||
|
if err := dec(in); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if interceptor == nil {
|
||||||
|
return srv.(DaemonServiceServer).TracePacket(ctx, in)
|
||||||
|
}
|
||||||
|
info := &grpc.UnaryServerInfo{
|
||||||
|
Server: srv,
|
||||||
|
FullMethod: "/daemon.DaemonService/TracePacket",
|
||||||
|
}
|
||||||
|
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||||
|
return srv.(DaemonServiceServer).TracePacket(ctx, req.(*TracePacketRequest))
|
||||||
|
}
|
||||||
|
return interceptor(ctx, in, info, handler)
|
||||||
|
}
|
||||||
|
|
||||||
// DaemonService_ServiceDesc is the grpc.ServiceDesc for DaemonService service.
|
// DaemonService_ServiceDesc is the grpc.ServiceDesc for DaemonService service.
|
||||||
// It's only intended for direct use with grpc.RegisterService,
|
// It's only intended for direct use with grpc.RegisterService,
|
||||||
// and not to be introspected or modified (even as a copy)
|
// and not to be introspected or modified (even as a copy)
|
||||||
@@ -669,6 +701,10 @@ var DaemonService_ServiceDesc = grpc.ServiceDesc{
|
|||||||
MethodName: "SetNetworkMapPersistence",
|
MethodName: "SetNetworkMapPersistence",
|
||||||
Handler: _DaemonService_SetNetworkMapPersistence_Handler,
|
Handler: _DaemonService_SetNetworkMapPersistence_Handler,
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
MethodName: "TracePacket",
|
||||||
|
Handler: _DaemonService_TracePacket_Handler,
|
||||||
|
},
|
||||||
},
|
},
|
||||||
Streams: []grpc.StreamDesc{},
|
Streams: []grpc.StreamDesc{},
|
||||||
Metadata: "daemon.proto",
|
Metadata: "daemon.proto",
|
||||||
|
|||||||
@@ -538,7 +538,24 @@ func (s *Server) SetLogLevel(_ context.Context, req *proto.SetLogLevelRequest) (
|
|||||||
}
|
}
|
||||||
|
|
||||||
log.SetLevel(level)
|
log.SetLevel(level)
|
||||||
|
|
||||||
|
if s.connectClient == nil {
|
||||||
|
return nil, fmt.Errorf("connect client not initialized")
|
||||||
|
}
|
||||||
|
engine := s.connectClient.Engine()
|
||||||
|
if engine == nil {
|
||||||
|
return nil, fmt.Errorf("engine not initialized")
|
||||||
|
}
|
||||||
|
|
||||||
|
fwManager := engine.GetFirewallManager()
|
||||||
|
if fwManager == nil {
|
||||||
|
return nil, fmt.Errorf("firewall manager not initialized")
|
||||||
|
}
|
||||||
|
|
||||||
|
fwManager.SetLogLevel(level)
|
||||||
|
|
||||||
log.Infof("Log level set to %s", level.String())
|
log.Infof("Log level set to %s", level.String())
|
||||||
|
|
||||||
return &proto.SetLogLevelResponse{}, nil
|
return &proto.SetLogLevelResponse{}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
123
client/server/trace.go
Normal file
123
client/server/trace.go
Normal file
@@ -0,0 +1,123 @@
|
|||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
|
||||||
|
fw "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
|
"github.com/netbirdio/netbird/client/firewall/uspfilter"
|
||||||
|
"github.com/netbirdio/netbird/client/proto"
|
||||||
|
)
|
||||||
|
|
||||||
|
type packetTracer interface {
|
||||||
|
TracePacketFromBuilder(builder *uspfilter.PacketBuilder) (*uspfilter.PacketTrace, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) TracePacket(_ context.Context, req *proto.TracePacketRequest) (*proto.TracePacketResponse, error) {
|
||||||
|
s.mutex.Lock()
|
||||||
|
defer s.mutex.Unlock()
|
||||||
|
|
||||||
|
if s.connectClient == nil {
|
||||||
|
return nil, fmt.Errorf("connect client not initialized")
|
||||||
|
}
|
||||||
|
engine := s.connectClient.Engine()
|
||||||
|
if engine == nil {
|
||||||
|
return nil, fmt.Errorf("engine not initialized")
|
||||||
|
}
|
||||||
|
|
||||||
|
fwManager := engine.GetFirewallManager()
|
||||||
|
if fwManager == nil {
|
||||||
|
return nil, fmt.Errorf("firewall manager not initialized")
|
||||||
|
}
|
||||||
|
|
||||||
|
tracer, ok := fwManager.(packetTracer)
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("firewall manager does not support packet tracing")
|
||||||
|
}
|
||||||
|
|
||||||
|
srcIP := net.ParseIP(req.GetSourceIp())
|
||||||
|
if req.GetSourceIp() == "self" {
|
||||||
|
srcIP = engine.GetWgAddr()
|
||||||
|
}
|
||||||
|
|
||||||
|
dstIP := net.ParseIP(req.GetDestinationIp())
|
||||||
|
if req.GetDestinationIp() == "self" {
|
||||||
|
dstIP = engine.GetWgAddr()
|
||||||
|
}
|
||||||
|
|
||||||
|
if srcIP == nil || dstIP == nil {
|
||||||
|
return nil, fmt.Errorf("invalid IP address")
|
||||||
|
}
|
||||||
|
|
||||||
|
var tcpState *uspfilter.TCPState
|
||||||
|
if flags := req.GetTcpFlags(); flags != nil {
|
||||||
|
tcpState = &uspfilter.TCPState{
|
||||||
|
SYN: flags.GetSyn(),
|
||||||
|
ACK: flags.GetAck(),
|
||||||
|
FIN: flags.GetFin(),
|
||||||
|
RST: flags.GetRst(),
|
||||||
|
PSH: flags.GetPsh(),
|
||||||
|
URG: flags.GetUrg(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var dir fw.RuleDirection
|
||||||
|
switch req.GetDirection() {
|
||||||
|
case "in":
|
||||||
|
dir = fw.RuleDirectionIN
|
||||||
|
case "out":
|
||||||
|
dir = fw.RuleDirectionOUT
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("invalid direction")
|
||||||
|
}
|
||||||
|
|
||||||
|
var protocol fw.Protocol
|
||||||
|
switch req.GetProtocol() {
|
||||||
|
case "tcp":
|
||||||
|
protocol = fw.ProtocolTCP
|
||||||
|
case "udp":
|
||||||
|
protocol = fw.ProtocolUDP
|
||||||
|
case "icmp":
|
||||||
|
protocol = fw.ProtocolICMP
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("invalid protocolcol")
|
||||||
|
}
|
||||||
|
|
||||||
|
builder := &uspfilter.PacketBuilder{
|
||||||
|
SrcIP: srcIP,
|
||||||
|
DstIP: dstIP,
|
||||||
|
Protocol: protocol,
|
||||||
|
SrcPort: uint16(req.GetSourcePort()),
|
||||||
|
DstPort: uint16(req.GetDestinationPort()),
|
||||||
|
Direction: dir,
|
||||||
|
TCPState: tcpState,
|
||||||
|
ICMPType: uint8(req.GetIcmpType()),
|
||||||
|
ICMPCode: uint8(req.GetIcmpCode()),
|
||||||
|
}
|
||||||
|
trace, err := tracer.TracePacketFromBuilder(builder)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("trace packet: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
resp := &proto.TracePacketResponse{}
|
||||||
|
|
||||||
|
for _, result := range trace.Results {
|
||||||
|
stage := &proto.TraceStage{
|
||||||
|
Name: result.Stage.String(),
|
||||||
|
Message: result.Message,
|
||||||
|
Allowed: result.Allowed,
|
||||||
|
}
|
||||||
|
if result.ForwarderAction != nil {
|
||||||
|
details := fmt.Sprintf("%s to %s", result.ForwarderAction.Action, result.ForwarderAction.RemoteAddr)
|
||||||
|
stage.ForwardingDetails = &details
|
||||||
|
}
|
||||||
|
resp.Stages = append(resp.Stages, stage)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(trace.Results) > 0 {
|
||||||
|
resp.FinalDisposition = trace.Results[len(trace.Results)-1].Allowed
|
||||||
|
}
|
||||||
|
|
||||||
|
return resp, nil
|
||||||
|
}
|
||||||
4
go.mod
4
go.mod
@@ -102,6 +102,7 @@ require (
|
|||||||
gorm.io/driver/postgres v1.5.7
|
gorm.io/driver/postgres v1.5.7
|
||||||
gorm.io/driver/sqlite v1.5.7
|
gorm.io/driver/sqlite v1.5.7
|
||||||
gorm.io/gorm v1.25.12
|
gorm.io/gorm v1.25.12
|
||||||
|
gvisor.dev/gvisor v0.0.0-20231020174304-b8a429915ff1
|
||||||
)
|
)
|
||||||
|
|
||||||
require (
|
require (
|
||||||
@@ -237,7 +238,6 @@ require (
|
|||||||
gopkg.in/square/go-jose.v2 v2.6.0 // indirect
|
gopkg.in/square/go-jose.v2 v2.6.0 // indirect
|
||||||
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 // indirect
|
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 // indirect
|
||||||
gopkg.in/tomb.v2 v2.0.0-20161208151619-d5d1b5820637 // indirect
|
gopkg.in/tomb.v2 v2.0.0-20161208151619-d5d1b5820637 // indirect
|
||||||
gvisor.dev/gvisor v0.0.0-20231020174304-db3d49b921f9 // indirect
|
|
||||||
k8s.io/apimachinery v0.26.2 // indirect
|
k8s.io/apimachinery v0.26.2 // indirect
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -245,7 +245,7 @@ replace github.com/kardianos/service => github.com/netbirdio/service v0.0.0-2024
|
|||||||
|
|
||||||
replace github.com/getlantern/systray => github.com/netbirdio/systray v0.0.0-20231030152038-ef1ed2a27949
|
replace github.com/getlantern/systray => github.com/netbirdio/systray v0.0.0-20231030152038-ef1ed2a27949
|
||||||
|
|
||||||
replace golang.zx2c4.com/wireguard => github.com/netbirdio/wireguard-go v0.0.0-20241125150134-f9cdce5e32e9
|
replace golang.zx2c4.com/wireguard => github.com/netbirdio/wireguard-go v0.0.0-20241230120307-6a676aebaaf6
|
||||||
|
|
||||||
replace github.com/cloudflare/circl => github.com/cunicu/circl v0.0.0-20230801113412-fec58fc7b5f6
|
replace github.com/cloudflare/circl => github.com/cunicu/circl v0.0.0-20230801113412-fec58fc7b5f6
|
||||||
|
|
||||||
|
|||||||
8
go.sum
8
go.sum
@@ -535,8 +535,8 @@ github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9ax
|
|||||||
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
|
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
|
||||||
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241010133937-e0df50df217d h1:bRq5TKgC7Iq20pDiuC54yXaWnAVeS5PdGpSokFTlR28=
|
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241010133937-e0df50df217d h1:bRq5TKgC7Iq20pDiuC54yXaWnAVeS5PdGpSokFTlR28=
|
||||||
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241010133937-e0df50df217d/go.mod h1:5/sjFmLb8O96B5737VCqhHyGRzNFIaN/Bu7ZodXc3qQ=
|
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241010133937-e0df50df217d/go.mod h1:5/sjFmLb8O96B5737VCqhHyGRzNFIaN/Bu7ZodXc3qQ=
|
||||||
github.com/netbirdio/wireguard-go v0.0.0-20241125150134-f9cdce5e32e9 h1:Pu/7EukijT09ynHUOzQYW7cC3M/BKU8O4qyN/TvTGoY=
|
github.com/netbirdio/wireguard-go v0.0.0-20241230120307-6a676aebaaf6 h1:X5h5QgP7uHAv78FWgHV8+WYLjHxK9v3ilkVXT1cpCrQ=
|
||||||
github.com/netbirdio/wireguard-go v0.0.0-20241125150134-f9cdce5e32e9/go.mod h1:tkCQ4FQXmpAgYVh++1cq16/dH4QJtmvpRv19DWGAHSA=
|
github.com/netbirdio/wireguard-go v0.0.0-20241230120307-6a676aebaaf6/go.mod h1:tkCQ4FQXmpAgYVh++1cq16/dH4QJtmvpRv19DWGAHSA=
|
||||||
github.com/nicksnyder/go-i18n/v2 v2.4.0 h1:3IcvPOAvnCKwNm0TB0dLDTuawWEj+ax/RERNC+diLMM=
|
github.com/nicksnyder/go-i18n/v2 v2.4.0 h1:3IcvPOAvnCKwNm0TB0dLDTuawWEj+ax/RERNC+diLMM=
|
||||||
github.com/nicksnyder/go-i18n/v2 v2.4.0/go.mod h1:nxYSZE9M0bf3Y70gPQjN9ha7XNHX7gMc814+6wVyEI4=
|
github.com/nicksnyder/go-i18n/v2 v2.4.0/go.mod h1:nxYSZE9M0bf3Y70gPQjN9ha7XNHX7gMc814+6wVyEI4=
|
||||||
github.com/nxadm/tail v1.4.4/go.mod h1:kenIhsEOeOJmVchQTgglprH7qJGnHDVpk1VPCcaMI8A=
|
github.com/nxadm/tail v1.4.4/go.mod h1:kenIhsEOeOJmVchQTgglprH7qJGnHDVpk1VPCcaMI8A=
|
||||||
@@ -1250,8 +1250,8 @@ gorm.io/gorm v1.25.12 h1:I0u8i2hWQItBq1WfE0o2+WuL9+8L21K9e2HHSTE/0f8=
|
|||||||
gorm.io/gorm v1.25.12/go.mod h1:xh7N7RHfYlNc5EmcI/El95gXusucDrQnHXe0+CgWcLQ=
|
gorm.io/gorm v1.25.12/go.mod h1:xh7N7RHfYlNc5EmcI/El95gXusucDrQnHXe0+CgWcLQ=
|
||||||
gotest.tools/v3 v3.5.0 h1:Ljk6PdHdOhAb5aDMWXjDLMMhph+BpztA4v1QdqEW2eY=
|
gotest.tools/v3 v3.5.0 h1:Ljk6PdHdOhAb5aDMWXjDLMMhph+BpztA4v1QdqEW2eY=
|
||||||
gotest.tools/v3 v3.5.0/go.mod h1:isy3WKz7GK6uNw/sbHzfKBLvlvXwUyV06n6brMxxopU=
|
gotest.tools/v3 v3.5.0/go.mod h1:isy3WKz7GK6uNw/sbHzfKBLvlvXwUyV06n6brMxxopU=
|
||||||
gvisor.dev/gvisor v0.0.0-20231020174304-db3d49b921f9 h1:sCEaoA7ZmkuFwa2IR61pl4+RYZPwCJOiaSYT0k+BRf8=
|
gvisor.dev/gvisor v0.0.0-20231020174304-b8a429915ff1 h1:qDCwdCWECGnwQSQC01Dpnp09fRHxJs9PbktotUqG+hs=
|
||||||
gvisor.dev/gvisor v0.0.0-20231020174304-db3d49b921f9/go.mod h1:8hmigyCdYtw5xJGfQDJzSH5Ju8XEIDBnpyi8+O6GRt8=
|
gvisor.dev/gvisor v0.0.0-20231020174304-b8a429915ff1/go.mod h1:8hmigyCdYtw5xJGfQDJzSH5Ju8XEIDBnpyi8+O6GRt8=
|
||||||
honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
|
honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
|
||||||
honnef.co/go/tools v0.0.0-20190106161140-3f1c8253044a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
|
honnef.co/go/tools v0.0.0-20190106161140-3f1c8253044a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
|
||||||
honnef.co/go/tools v0.0.0-20190418001031-e561f6794a2a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
|
honnef.co/go/tools v0.0.0-20190418001031-e561f6794a2a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
|
||||||
|
|||||||
@@ -1,4 +1,7 @@
|
|||||||
package rest
|
//go:build integration
|
||||||
|
// +build integration
|
||||||
|
|
||||||
|
package rest_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
@@ -7,10 +10,12 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/server/http/api"
|
|
||||||
"github.com/netbirdio/netbird/management/server/http/util"
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/management/client/rest"
|
||||||
|
"github.com/netbirdio/netbird/management/server/http/api"
|
||||||
|
"github.com/netbirdio/netbird/management/server/http/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@@ -33,7 +38,7 @@ var (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func TestAccounts_List_200(t *testing.T) {
|
func TestAccounts_List_200(t *testing.T) {
|
||||||
withMockClient(func(c *Client, mux *http.ServeMux) {
|
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||||
mux.HandleFunc("/api/accounts", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("/api/accounts", func(w http.ResponseWriter, r *http.Request) {
|
||||||
retBytes, _ := json.Marshal([]api.Account{testAccount})
|
retBytes, _ := json.Marshal([]api.Account{testAccount})
|
||||||
_, err := w.Write(retBytes)
|
_, err := w.Write(retBytes)
|
||||||
@@ -47,7 +52,7 @@ func TestAccounts_List_200(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestAccounts_List_Err(t *testing.T) {
|
func TestAccounts_List_Err(t *testing.T) {
|
||||||
withMockClient(func(c *Client, mux *http.ServeMux) {
|
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||||
mux.HandleFunc("/api/accounts", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("/api/accounts", func(w http.ResponseWriter, r *http.Request) {
|
||||||
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
|
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
|
||||||
w.WriteHeader(400)
|
w.WriteHeader(400)
|
||||||
@@ -62,7 +67,7 @@ func TestAccounts_List_Err(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestAccounts_Update_200(t *testing.T) {
|
func TestAccounts_Update_200(t *testing.T) {
|
||||||
withMockClient(func(c *Client, mux *http.ServeMux) {
|
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||||
mux.HandleFunc("/api/accounts/Test", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("/api/accounts/Test", func(w http.ResponseWriter, r *http.Request) {
|
||||||
assert.Equal(t, "PUT", r.Method)
|
assert.Equal(t, "PUT", r.Method)
|
||||||
reqBytes, err := io.ReadAll(r.Body)
|
reqBytes, err := io.ReadAll(r.Body)
|
||||||
@@ -87,7 +92,7 @@ func TestAccounts_Update_200(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestAccounts_Update_Err(t *testing.T) {
|
func TestAccounts_Update_Err(t *testing.T) {
|
||||||
withMockClient(func(c *Client, mux *http.ServeMux) {
|
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||||
mux.HandleFunc("/api/accounts/Test", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("/api/accounts/Test", func(w http.ResponseWriter, r *http.Request) {
|
||||||
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
|
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
|
||||||
w.WriteHeader(400)
|
w.WriteHeader(400)
|
||||||
@@ -106,7 +111,7 @@ func TestAccounts_Update_Err(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestAccounts_Delete_200(t *testing.T) {
|
func TestAccounts_Delete_200(t *testing.T) {
|
||||||
withMockClient(func(c *Client, mux *http.ServeMux) {
|
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||||
mux.HandleFunc("/api/accounts/Test", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("/api/accounts/Test", func(w http.ResponseWriter, r *http.Request) {
|
||||||
assert.Equal(t, "DELETE", r.Method)
|
assert.Equal(t, "DELETE", r.Method)
|
||||||
w.WriteHeader(200)
|
w.WriteHeader(200)
|
||||||
@@ -117,7 +122,7 @@ func TestAccounts_Delete_200(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestAccounts_Delete_Err(t *testing.T) {
|
func TestAccounts_Delete_Err(t *testing.T) {
|
||||||
withMockClient(func(c *Client, mux *http.ServeMux) {
|
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||||
mux.HandleFunc("/api/accounts/Test", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("/api/accounts/Test", func(w http.ResponseWriter, r *http.Request) {
|
||||||
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404})
|
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404})
|
||||||
w.WriteHeader(404)
|
w.WriteHeader(404)
|
||||||
@@ -131,7 +136,7 @@ func TestAccounts_Delete_Err(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestAccounts_Integration_List(t *testing.T) {
|
func TestAccounts_Integration_List(t *testing.T) {
|
||||||
withBlackBoxServer(t, func(c *Client) {
|
withBlackBoxServer(t, func(c *rest.Client) {
|
||||||
accounts, err := c.Accounts.List(context.Background())
|
accounts, err := c.Accounts.List(context.Background())
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Len(t, accounts, 1)
|
assert.Len(t, accounts, 1)
|
||||||
@@ -141,7 +146,7 @@ func TestAccounts_Integration_List(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestAccounts_Integration_Update(t *testing.T) {
|
func TestAccounts_Integration_Update(t *testing.T) {
|
||||||
withBlackBoxServer(t, func(c *Client) {
|
withBlackBoxServer(t, func(c *rest.Client) {
|
||||||
accounts, err := c.Accounts.List(context.Background())
|
accounts, err := c.Accounts.List(context.Background())
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Len(t, accounts, 1)
|
assert.Len(t, accounts, 1)
|
||||||
@@ -157,7 +162,7 @@ func TestAccounts_Integration_Update(t *testing.T) {
|
|||||||
|
|
||||||
// Account deletion on MySQL and PostgreSQL databases causes unknown errors
|
// Account deletion on MySQL and PostgreSQL databases causes unknown errors
|
||||||
// func TestAccounts_Integration_Delete(t *testing.T) {
|
// func TestAccounts_Integration_Delete(t *testing.T) {
|
||||||
// withBlackBoxServer(t, func(c *Client) {
|
// withBlackBoxServer(t, func(c *rest.Client) {
|
||||||
// accounts, err := c.Accounts.List(context.Background())
|
// accounts, err := c.Accounts.List(context.Background())
|
||||||
// require.NoError(t, err)
|
// require.NoError(t, err)
|
||||||
// assert.Len(t, accounts, 1)
|
// assert.Len(t, accounts, 1)
|
||||||
|
|||||||
@@ -1,18 +1,22 @@
|
|||||||
package rest
|
//go:build integration
|
||||||
|
// +build integration
|
||||||
|
|
||||||
|
package rest_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/management/client/rest"
|
||||||
"github.com/netbirdio/netbird/management/server/http/testing/testing_tools"
|
"github.com/netbirdio/netbird/management/server/http/testing/testing_tools"
|
||||||
)
|
)
|
||||||
|
|
||||||
func withMockClient(callback func(*Client, *http.ServeMux)) {
|
func withMockClient(callback func(*rest.Client, *http.ServeMux)) {
|
||||||
mux := &http.ServeMux{}
|
mux := &http.ServeMux{}
|
||||||
server := httptest.NewServer(mux)
|
server := httptest.NewServer(mux)
|
||||||
defer server.Close()
|
defer server.Close()
|
||||||
c := New(server.URL, "ABC")
|
c := rest.New(server.URL, "ABC")
|
||||||
callback(c, mux)
|
callback(c, mux)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -20,11 +24,11 @@ func ptr[T any, PT *T](x T) PT {
|
|||||||
return &x
|
return &x
|
||||||
}
|
}
|
||||||
|
|
||||||
func withBlackBoxServer(t *testing.T, callback func(*Client)) {
|
func withBlackBoxServer(t *testing.T, callback func(*rest.Client)) {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
handler, _, _ := testing_tools.BuildApiBlackBoxWithDBState(t, "../../server/testdata/store.sql", nil, false)
|
handler, _, _ := testing_tools.BuildApiBlackBoxWithDBState(t, "../../server/testdata/store.sql", nil, false)
|
||||||
server := httptest.NewServer(handler)
|
server := httptest.NewServer(handler)
|
||||||
defer server.Close()
|
defer server.Close()
|
||||||
c := New(server.URL, "nbp_apTmlmUXHSC4PKmHwtIZNaGr8eqcVI2gMURp")
|
c := rest.New(server.URL, "nbp_apTmlmUXHSC4PKmHwtIZNaGr8eqcVI2gMURp")
|
||||||
callback(c)
|
callback(c)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,4 +1,7 @@
|
|||||||
package rest
|
//go:build integration
|
||||||
|
// +build integration
|
||||||
|
|
||||||
|
package rest_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
@@ -7,10 +10,12 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/server/http/api"
|
|
||||||
"github.com/netbirdio/netbird/management/server/http/util"
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/management/client/rest"
|
||||||
|
"github.com/netbirdio/netbird/management/server/http/api"
|
||||||
|
"github.com/netbirdio/netbird/management/server/http/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@@ -25,7 +30,7 @@ var (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func TestDNSNameserverGroup_List_200(t *testing.T) {
|
func TestDNSNameserverGroup_List_200(t *testing.T) {
|
||||||
withMockClient(func(c *Client, mux *http.ServeMux) {
|
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||||
mux.HandleFunc("/api/dns/nameservers", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("/api/dns/nameservers", func(w http.ResponseWriter, r *http.Request) {
|
||||||
retBytes, _ := json.Marshal([]api.NameserverGroup{testNameserverGroup})
|
retBytes, _ := json.Marshal([]api.NameserverGroup{testNameserverGroup})
|
||||||
_, err := w.Write(retBytes)
|
_, err := w.Write(retBytes)
|
||||||
@@ -39,7 +44,7 @@ func TestDNSNameserverGroup_List_200(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestDNSNameserverGroup_List_Err(t *testing.T) {
|
func TestDNSNameserverGroup_List_Err(t *testing.T) {
|
||||||
withMockClient(func(c *Client, mux *http.ServeMux) {
|
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||||
mux.HandleFunc("/api/dns/nameservers", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("/api/dns/nameservers", func(w http.ResponseWriter, r *http.Request) {
|
||||||
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
|
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
|
||||||
w.WriteHeader(400)
|
w.WriteHeader(400)
|
||||||
@@ -54,7 +59,7 @@ func TestDNSNameserverGroup_List_Err(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestDNSNameserverGroup_Get_200(t *testing.T) {
|
func TestDNSNameserverGroup_Get_200(t *testing.T) {
|
||||||
withMockClient(func(c *Client, mux *http.ServeMux) {
|
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||||
mux.HandleFunc("/api/dns/nameservers/Test", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("/api/dns/nameservers/Test", func(w http.ResponseWriter, r *http.Request) {
|
||||||
retBytes, _ := json.Marshal(testNameserverGroup)
|
retBytes, _ := json.Marshal(testNameserverGroup)
|
||||||
_, err := w.Write(retBytes)
|
_, err := w.Write(retBytes)
|
||||||
@@ -67,7 +72,7 @@ func TestDNSNameserverGroup_Get_200(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestDNSNameserverGroup_Get_Err(t *testing.T) {
|
func TestDNSNameserverGroup_Get_Err(t *testing.T) {
|
||||||
withMockClient(func(c *Client, mux *http.ServeMux) {
|
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||||
mux.HandleFunc("/api/dns/nameservers/Test", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("/api/dns/nameservers/Test", func(w http.ResponseWriter, r *http.Request) {
|
||||||
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
|
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
|
||||||
w.WriteHeader(400)
|
w.WriteHeader(400)
|
||||||
@@ -82,7 +87,7 @@ func TestDNSNameserverGroup_Get_Err(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestDNSNameserverGroup_Create_200(t *testing.T) {
|
func TestDNSNameserverGroup_Create_200(t *testing.T) {
|
||||||
withMockClient(func(c *Client, mux *http.ServeMux) {
|
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||||
mux.HandleFunc("/api/dns/nameservers", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("/api/dns/nameservers", func(w http.ResponseWriter, r *http.Request) {
|
||||||
assert.Equal(t, "POST", r.Method)
|
assert.Equal(t, "POST", r.Method)
|
||||||
reqBytes, err := io.ReadAll(r.Body)
|
reqBytes, err := io.ReadAll(r.Body)
|
||||||
@@ -104,7 +109,7 @@ func TestDNSNameserverGroup_Create_200(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestDNSNameserverGroup_Create_Err(t *testing.T) {
|
func TestDNSNameserverGroup_Create_Err(t *testing.T) {
|
||||||
withMockClient(func(c *Client, mux *http.ServeMux) {
|
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||||
mux.HandleFunc("/api/dns/nameservers", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("/api/dns/nameservers", func(w http.ResponseWriter, r *http.Request) {
|
||||||
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
|
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
|
||||||
w.WriteHeader(400)
|
w.WriteHeader(400)
|
||||||
@@ -121,7 +126,7 @@ func TestDNSNameserverGroup_Create_Err(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestDNSNameserverGroup_Update_200(t *testing.T) {
|
func TestDNSNameserverGroup_Update_200(t *testing.T) {
|
||||||
withMockClient(func(c *Client, mux *http.ServeMux) {
|
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||||
mux.HandleFunc("/api/dns/nameservers/Test", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("/api/dns/nameservers/Test", func(w http.ResponseWriter, r *http.Request) {
|
||||||
assert.Equal(t, "PUT", r.Method)
|
assert.Equal(t, "PUT", r.Method)
|
||||||
reqBytes, err := io.ReadAll(r.Body)
|
reqBytes, err := io.ReadAll(r.Body)
|
||||||
@@ -143,7 +148,7 @@ func TestDNSNameserverGroup_Update_200(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestDNSNameserverGroup_Update_Err(t *testing.T) {
|
func TestDNSNameserverGroup_Update_Err(t *testing.T) {
|
||||||
withMockClient(func(c *Client, mux *http.ServeMux) {
|
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||||
mux.HandleFunc("/api/dns/nameservers/Test", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("/api/dns/nameservers/Test", func(w http.ResponseWriter, r *http.Request) {
|
||||||
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
|
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
|
||||||
w.WriteHeader(400)
|
w.WriteHeader(400)
|
||||||
@@ -160,7 +165,7 @@ func TestDNSNameserverGroup_Update_Err(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestDNSNameserverGroup_Delete_200(t *testing.T) {
|
func TestDNSNameserverGroup_Delete_200(t *testing.T) {
|
||||||
withMockClient(func(c *Client, mux *http.ServeMux) {
|
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||||
mux.HandleFunc("/api/dns/nameservers/Test", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("/api/dns/nameservers/Test", func(w http.ResponseWriter, r *http.Request) {
|
||||||
assert.Equal(t, "DELETE", r.Method)
|
assert.Equal(t, "DELETE", r.Method)
|
||||||
w.WriteHeader(200)
|
w.WriteHeader(200)
|
||||||
@@ -171,7 +176,7 @@ func TestDNSNameserverGroup_Delete_200(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestDNSNameserverGroup_Delete_Err(t *testing.T) {
|
func TestDNSNameserverGroup_Delete_Err(t *testing.T) {
|
||||||
withMockClient(func(c *Client, mux *http.ServeMux) {
|
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||||
mux.HandleFunc("/api/dns/nameservers/Test", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("/api/dns/nameservers/Test", func(w http.ResponseWriter, r *http.Request) {
|
||||||
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404})
|
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404})
|
||||||
w.WriteHeader(404)
|
w.WriteHeader(404)
|
||||||
@@ -185,7 +190,7 @@ func TestDNSNameserverGroup_Delete_Err(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestDNSSettings_Get_200(t *testing.T) {
|
func TestDNSSettings_Get_200(t *testing.T) {
|
||||||
withMockClient(func(c *Client, mux *http.ServeMux) {
|
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||||
mux.HandleFunc("/api/dns/settings", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("/api/dns/settings", func(w http.ResponseWriter, r *http.Request) {
|
||||||
retBytes, _ := json.Marshal(testSettings)
|
retBytes, _ := json.Marshal(testSettings)
|
||||||
_, err := w.Write(retBytes)
|
_, err := w.Write(retBytes)
|
||||||
@@ -198,7 +203,7 @@ func TestDNSSettings_Get_200(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestDNSSettings_Get_Err(t *testing.T) {
|
func TestDNSSettings_Get_Err(t *testing.T) {
|
||||||
withMockClient(func(c *Client, mux *http.ServeMux) {
|
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||||
mux.HandleFunc("/api/dns/settings", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("/api/dns/settings", func(w http.ResponseWriter, r *http.Request) {
|
||||||
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
|
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
|
||||||
w.WriteHeader(400)
|
w.WriteHeader(400)
|
||||||
@@ -213,7 +218,7 @@ func TestDNSSettings_Get_Err(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestDNSSettings_Update_200(t *testing.T) {
|
func TestDNSSettings_Update_200(t *testing.T) {
|
||||||
withMockClient(func(c *Client, mux *http.ServeMux) {
|
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||||
mux.HandleFunc("/api/dns/settings", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("/api/dns/settings", func(w http.ResponseWriter, r *http.Request) {
|
||||||
assert.Equal(t, "PUT", r.Method)
|
assert.Equal(t, "PUT", r.Method)
|
||||||
reqBytes, err := io.ReadAll(r.Body)
|
reqBytes, err := io.ReadAll(r.Body)
|
||||||
@@ -235,7 +240,7 @@ func TestDNSSettings_Update_200(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestDNSSettings_Update_Err(t *testing.T) {
|
func TestDNSSettings_Update_Err(t *testing.T) {
|
||||||
withMockClient(func(c *Client, mux *http.ServeMux) {
|
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||||
mux.HandleFunc("/api/dns/settings", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("/api/dns/settings", func(w http.ResponseWriter, r *http.Request) {
|
||||||
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
|
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
|
||||||
w.WriteHeader(400)
|
w.WriteHeader(400)
|
||||||
@@ -267,7 +272,7 @@ func TestDNS_Integration(t *testing.T) {
|
|||||||
Primary: true,
|
Primary: true,
|
||||||
SearchDomainsEnabled: false,
|
SearchDomainsEnabled: false,
|
||||||
}
|
}
|
||||||
withBlackBoxServer(t, func(c *Client) {
|
withBlackBoxServer(t, func(c *rest.Client) {
|
||||||
// Create
|
// Create
|
||||||
nsGroup, err := c.DNS.CreateNameserverGroup(context.Background(), nsGroupReq)
|
nsGroup, err := c.DNS.CreateNameserverGroup(context.Background(), nsGroupReq)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|||||||
@@ -1,4 +1,7 @@
|
|||||||
package rest
|
//go:build integration
|
||||||
|
// +build integration
|
||||||
|
|
||||||
|
package rest_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
@@ -6,10 +9,12 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/server/http/api"
|
|
||||||
"github.com/netbirdio/netbird/management/server/http/util"
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/management/client/rest"
|
||||||
|
"github.com/netbirdio/netbird/management/server/http/api"
|
||||||
|
"github.com/netbirdio/netbird/management/server/http/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@@ -20,7 +25,7 @@ var (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func TestEvents_List_200(t *testing.T) {
|
func TestEvents_List_200(t *testing.T) {
|
||||||
withMockClient(func(c *Client, mux *http.ServeMux) {
|
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||||
mux.HandleFunc("/api/events", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("/api/events", func(w http.ResponseWriter, r *http.Request) {
|
||||||
retBytes, _ := json.Marshal([]api.Event{testEvent})
|
retBytes, _ := json.Marshal([]api.Event{testEvent})
|
||||||
_, err := w.Write(retBytes)
|
_, err := w.Write(retBytes)
|
||||||
@@ -34,7 +39,7 @@ func TestEvents_List_200(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestEvents_List_Err(t *testing.T) {
|
func TestEvents_List_Err(t *testing.T) {
|
||||||
withMockClient(func(c *Client, mux *http.ServeMux) {
|
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||||
mux.HandleFunc("/api/events", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("/api/events", func(w http.ResponseWriter, r *http.Request) {
|
||||||
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
|
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
|
||||||
w.WriteHeader(400)
|
w.WriteHeader(400)
|
||||||
@@ -49,7 +54,7 @@ func TestEvents_List_Err(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestEvents_Integration(t *testing.T) {
|
func TestEvents_Integration(t *testing.T) {
|
||||||
withBlackBoxServer(t, func(c *Client) {
|
withBlackBoxServer(t, func(c *rest.Client) {
|
||||||
// Do something that would trigger any event
|
// Do something that would trigger any event
|
||||||
_, err := c.SetupKeys.Create(context.Background(), api.CreateSetupKeyRequest{
|
_, err := c.SetupKeys.Create(context.Background(), api.CreateSetupKeyRequest{
|
||||||
Ephemeral: ptr(true),
|
Ephemeral: ptr(true),
|
||||||
|
|||||||
@@ -1,4 +1,7 @@
|
|||||||
package rest
|
//go:build integration
|
||||||
|
// +build integration
|
||||||
|
|
||||||
|
package rest_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
@@ -6,10 +9,12 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/server/http/api"
|
|
||||||
"github.com/netbirdio/netbird/management/server/http/util"
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/management/client/rest"
|
||||||
|
"github.com/netbirdio/netbird/management/server/http/api"
|
||||||
|
"github.com/netbirdio/netbird/management/server/http/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@@ -25,7 +30,7 @@ var (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func TestGeo_ListCountries_200(t *testing.T) {
|
func TestGeo_ListCountries_200(t *testing.T) {
|
||||||
withMockClient(func(c *Client, mux *http.ServeMux) {
|
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||||
mux.HandleFunc("/api/locations/countries", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("/api/locations/countries", func(w http.ResponseWriter, r *http.Request) {
|
||||||
retBytes, _ := json.Marshal([]api.Country{testCountry})
|
retBytes, _ := json.Marshal([]api.Country{testCountry})
|
||||||
_, err := w.Write(retBytes)
|
_, err := w.Write(retBytes)
|
||||||
@@ -39,7 +44,7 @@ func TestGeo_ListCountries_200(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestGeo_ListCountries_Err(t *testing.T) {
|
func TestGeo_ListCountries_Err(t *testing.T) {
|
||||||
withMockClient(func(c *Client, mux *http.ServeMux) {
|
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||||
mux.HandleFunc("/api/locations/countries", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("/api/locations/countries", func(w http.ResponseWriter, r *http.Request) {
|
||||||
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
|
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
|
||||||
w.WriteHeader(400)
|
w.WriteHeader(400)
|
||||||
@@ -54,7 +59,7 @@ func TestGeo_ListCountries_Err(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestGeo_ListCountryCities_200(t *testing.T) {
|
func TestGeo_ListCountryCities_200(t *testing.T) {
|
||||||
withMockClient(func(c *Client, mux *http.ServeMux) {
|
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||||
mux.HandleFunc("/api/locations/countries/Test/cities", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("/api/locations/countries/Test/cities", func(w http.ResponseWriter, r *http.Request) {
|
||||||
retBytes, _ := json.Marshal([]api.City{testCity})
|
retBytes, _ := json.Marshal([]api.City{testCity})
|
||||||
_, err := w.Write(retBytes)
|
_, err := w.Write(retBytes)
|
||||||
@@ -68,7 +73,7 @@ func TestGeo_ListCountryCities_200(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestGeo_ListCountryCities_Err(t *testing.T) {
|
func TestGeo_ListCountryCities_Err(t *testing.T) {
|
||||||
withMockClient(func(c *Client, mux *http.ServeMux) {
|
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||||
mux.HandleFunc("/api/locations/countries/Test/cities", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("/api/locations/countries/Test/cities", func(w http.ResponseWriter, r *http.Request) {
|
||||||
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
|
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
|
||||||
w.WriteHeader(400)
|
w.WriteHeader(400)
|
||||||
@@ -84,7 +89,7 @@ func TestGeo_ListCountryCities_Err(t *testing.T) {
|
|||||||
|
|
||||||
func TestGeo_Integration(t *testing.T) {
|
func TestGeo_Integration(t *testing.T) {
|
||||||
// Blackbox is initialized with empty GeoLocations
|
// Blackbox is initialized with empty GeoLocations
|
||||||
withBlackBoxServer(t, func(c *Client) {
|
withBlackBoxServer(t, func(c *rest.Client) {
|
||||||
countries, err := c.GeoLocation.ListCountries(context.Background())
|
countries, err := c.GeoLocation.ListCountries(context.Background())
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Empty(t, countries)
|
assert.Empty(t, countries)
|
||||||
|
|||||||
@@ -1,4 +1,7 @@
|
|||||||
package rest
|
//go:build integration
|
||||||
|
// +build integration
|
||||||
|
|
||||||
|
package rest_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
@@ -7,10 +10,12 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/server/http/api"
|
|
||||||
"github.com/netbirdio/netbird/management/server/http/util"
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/management/client/rest"
|
||||||
|
"github.com/netbirdio/netbird/management/server/http/api"
|
||||||
|
"github.com/netbirdio/netbird/management/server/http/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@@ -22,7 +27,7 @@ var (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func TestGroups_List_200(t *testing.T) {
|
func TestGroups_List_200(t *testing.T) {
|
||||||
withMockClient(func(c *Client, mux *http.ServeMux) {
|
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||||
mux.HandleFunc("/api/groups", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("/api/groups", func(w http.ResponseWriter, r *http.Request) {
|
||||||
retBytes, _ := json.Marshal([]api.Group{testGroup})
|
retBytes, _ := json.Marshal([]api.Group{testGroup})
|
||||||
_, err := w.Write(retBytes)
|
_, err := w.Write(retBytes)
|
||||||
@@ -36,7 +41,7 @@ func TestGroups_List_200(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestGroups_List_Err(t *testing.T) {
|
func TestGroups_List_Err(t *testing.T) {
|
||||||
withMockClient(func(c *Client, mux *http.ServeMux) {
|
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||||
mux.HandleFunc("/api/groups", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("/api/groups", func(w http.ResponseWriter, r *http.Request) {
|
||||||
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
|
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
|
||||||
w.WriteHeader(400)
|
w.WriteHeader(400)
|
||||||
@@ -51,7 +56,7 @@ func TestGroups_List_Err(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestGroups_Get_200(t *testing.T) {
|
func TestGroups_Get_200(t *testing.T) {
|
||||||
withMockClient(func(c *Client, mux *http.ServeMux) {
|
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||||
mux.HandleFunc("/api/groups/Test", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("/api/groups/Test", func(w http.ResponseWriter, r *http.Request) {
|
||||||
retBytes, _ := json.Marshal(testGroup)
|
retBytes, _ := json.Marshal(testGroup)
|
||||||
_, err := w.Write(retBytes)
|
_, err := w.Write(retBytes)
|
||||||
@@ -64,7 +69,7 @@ func TestGroups_Get_200(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestGroups_Get_Err(t *testing.T) {
|
func TestGroups_Get_Err(t *testing.T) {
|
||||||
withMockClient(func(c *Client, mux *http.ServeMux) {
|
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||||
mux.HandleFunc("/api/groups/Test", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("/api/groups/Test", func(w http.ResponseWriter, r *http.Request) {
|
||||||
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
|
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
|
||||||
w.WriteHeader(400)
|
w.WriteHeader(400)
|
||||||
@@ -79,7 +84,7 @@ func TestGroups_Get_Err(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestGroups_Create_200(t *testing.T) {
|
func TestGroups_Create_200(t *testing.T) {
|
||||||
withMockClient(func(c *Client, mux *http.ServeMux) {
|
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||||
mux.HandleFunc("/api/groups", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("/api/groups", func(w http.ResponseWriter, r *http.Request) {
|
||||||
assert.Equal(t, "POST", r.Method)
|
assert.Equal(t, "POST", r.Method)
|
||||||
reqBytes, err := io.ReadAll(r.Body)
|
reqBytes, err := io.ReadAll(r.Body)
|
||||||
@@ -101,7 +106,7 @@ func TestGroups_Create_200(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestGroups_Create_Err(t *testing.T) {
|
func TestGroups_Create_Err(t *testing.T) {
|
||||||
withMockClient(func(c *Client, mux *http.ServeMux) {
|
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||||
mux.HandleFunc("/api/groups", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("/api/groups", func(w http.ResponseWriter, r *http.Request) {
|
||||||
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
|
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
|
||||||
w.WriteHeader(400)
|
w.WriteHeader(400)
|
||||||
@@ -118,7 +123,7 @@ func TestGroups_Create_Err(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestGroups_Update_200(t *testing.T) {
|
func TestGroups_Update_200(t *testing.T) {
|
||||||
withMockClient(func(c *Client, mux *http.ServeMux) {
|
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||||
mux.HandleFunc("/api/groups/Test", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("/api/groups/Test", func(w http.ResponseWriter, r *http.Request) {
|
||||||
assert.Equal(t, "PUT", r.Method)
|
assert.Equal(t, "PUT", r.Method)
|
||||||
reqBytes, err := io.ReadAll(r.Body)
|
reqBytes, err := io.ReadAll(r.Body)
|
||||||
@@ -140,7 +145,7 @@ func TestGroups_Update_200(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestGroups_Update_Err(t *testing.T) {
|
func TestGroups_Update_Err(t *testing.T) {
|
||||||
withMockClient(func(c *Client, mux *http.ServeMux) {
|
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||||
mux.HandleFunc("/api/groups/Test", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("/api/groups/Test", func(w http.ResponseWriter, r *http.Request) {
|
||||||
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
|
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
|
||||||
w.WriteHeader(400)
|
w.WriteHeader(400)
|
||||||
@@ -157,7 +162,7 @@ func TestGroups_Update_Err(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestGroups_Delete_200(t *testing.T) {
|
func TestGroups_Delete_200(t *testing.T) {
|
||||||
withMockClient(func(c *Client, mux *http.ServeMux) {
|
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||||
mux.HandleFunc("/api/groups/Test", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("/api/groups/Test", func(w http.ResponseWriter, r *http.Request) {
|
||||||
assert.Equal(t, "DELETE", r.Method)
|
assert.Equal(t, "DELETE", r.Method)
|
||||||
w.WriteHeader(200)
|
w.WriteHeader(200)
|
||||||
@@ -168,7 +173,7 @@ func TestGroups_Delete_200(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestGroups_Delete_Err(t *testing.T) {
|
func TestGroups_Delete_Err(t *testing.T) {
|
||||||
withMockClient(func(c *Client, mux *http.ServeMux) {
|
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||||
mux.HandleFunc("/api/groups/Test", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("/api/groups/Test", func(w http.ResponseWriter, r *http.Request) {
|
||||||
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404})
|
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404})
|
||||||
w.WriteHeader(404)
|
w.WriteHeader(404)
|
||||||
@@ -182,7 +187,7 @@ func TestGroups_Delete_Err(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestGroups_Integration(t *testing.T) {
|
func TestGroups_Integration(t *testing.T) {
|
||||||
withBlackBoxServer(t, func(c *Client) {
|
withBlackBoxServer(t, func(c *rest.Client) {
|
||||||
groups, err := c.Groups.List(context.Background())
|
groups, err := c.Groups.List(context.Background())
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Len(t, groups, 1)
|
assert.Len(t, groups, 1)
|
||||||
|
|||||||
@@ -1,4 +1,7 @@
|
|||||||
package rest
|
//go:build integration
|
||||||
|
// +build integration
|
||||||
|
|
||||||
|
package rest_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
@@ -7,10 +10,12 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/server/http/api"
|
|
||||||
"github.com/netbirdio/netbird/management/server/http/util"
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/management/client/rest"
|
||||||
|
"github.com/netbirdio/netbird/management/server/http/api"
|
||||||
|
"github.com/netbirdio/netbird/management/server/http/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@@ -30,7 +35,7 @@ var (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func TestNetworks_List_200(t *testing.T) {
|
func TestNetworks_List_200(t *testing.T) {
|
||||||
withMockClient(func(c *Client, mux *http.ServeMux) {
|
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||||
mux.HandleFunc("/api/networks", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("/api/networks", func(w http.ResponseWriter, r *http.Request) {
|
||||||
retBytes, _ := json.Marshal([]api.Network{testNetwork})
|
retBytes, _ := json.Marshal([]api.Network{testNetwork})
|
||||||
_, err := w.Write(retBytes)
|
_, err := w.Write(retBytes)
|
||||||
@@ -44,7 +49,7 @@ func TestNetworks_List_200(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestNetworks_List_Err(t *testing.T) {
|
func TestNetworks_List_Err(t *testing.T) {
|
||||||
withMockClient(func(c *Client, mux *http.ServeMux) {
|
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||||
mux.HandleFunc("/api/networks", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("/api/networks", func(w http.ResponseWriter, r *http.Request) {
|
||||||
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
|
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
|
||||||
w.WriteHeader(400)
|
w.WriteHeader(400)
|
||||||
@@ -59,7 +64,7 @@ func TestNetworks_List_Err(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestNetworks_Get_200(t *testing.T) {
|
func TestNetworks_Get_200(t *testing.T) {
|
||||||
withMockClient(func(c *Client, mux *http.ServeMux) {
|
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||||
mux.HandleFunc("/api/networks/Test", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("/api/networks/Test", func(w http.ResponseWriter, r *http.Request) {
|
||||||
retBytes, _ := json.Marshal(testNetwork)
|
retBytes, _ := json.Marshal(testNetwork)
|
||||||
_, err := w.Write(retBytes)
|
_, err := w.Write(retBytes)
|
||||||
@@ -72,7 +77,7 @@ func TestNetworks_Get_200(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestNetworks_Get_Err(t *testing.T) {
|
func TestNetworks_Get_Err(t *testing.T) {
|
||||||
withMockClient(func(c *Client, mux *http.ServeMux) {
|
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||||
mux.HandleFunc("/api/networks/Test", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("/api/networks/Test", func(w http.ResponseWriter, r *http.Request) {
|
||||||
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
|
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
|
||||||
w.WriteHeader(400)
|
w.WriteHeader(400)
|
||||||
@@ -87,7 +92,7 @@ func TestNetworks_Get_Err(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestNetworks_Create_200(t *testing.T) {
|
func TestNetworks_Create_200(t *testing.T) {
|
||||||
withMockClient(func(c *Client, mux *http.ServeMux) {
|
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||||
mux.HandleFunc("/api/networks", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("/api/networks", func(w http.ResponseWriter, r *http.Request) {
|
||||||
assert.Equal(t, "POST", r.Method)
|
assert.Equal(t, "POST", r.Method)
|
||||||
reqBytes, err := io.ReadAll(r.Body)
|
reqBytes, err := io.ReadAll(r.Body)
|
||||||
@@ -109,7 +114,7 @@ func TestNetworks_Create_200(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestNetworks_Create_Err(t *testing.T) {
|
func TestNetworks_Create_Err(t *testing.T) {
|
||||||
withMockClient(func(c *Client, mux *http.ServeMux) {
|
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||||
mux.HandleFunc("/api/networks", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("/api/networks", func(w http.ResponseWriter, r *http.Request) {
|
||||||
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
|
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
|
||||||
w.WriteHeader(400)
|
w.WriteHeader(400)
|
||||||
@@ -126,7 +131,7 @@ func TestNetworks_Create_Err(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestNetworks_Update_200(t *testing.T) {
|
func TestNetworks_Update_200(t *testing.T) {
|
||||||
withMockClient(func(c *Client, mux *http.ServeMux) {
|
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||||
mux.HandleFunc("/api/networks/Test", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("/api/networks/Test", func(w http.ResponseWriter, r *http.Request) {
|
||||||
assert.Equal(t, "PUT", r.Method)
|
assert.Equal(t, "PUT", r.Method)
|
||||||
reqBytes, err := io.ReadAll(r.Body)
|
reqBytes, err := io.ReadAll(r.Body)
|
||||||
@@ -148,7 +153,7 @@ func TestNetworks_Update_200(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestNetworks_Update_Err(t *testing.T) {
|
func TestNetworks_Update_Err(t *testing.T) {
|
||||||
withMockClient(func(c *Client, mux *http.ServeMux) {
|
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||||
mux.HandleFunc("/api/networks/Test", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("/api/networks/Test", func(w http.ResponseWriter, r *http.Request) {
|
||||||
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
|
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
|
||||||
w.WriteHeader(400)
|
w.WriteHeader(400)
|
||||||
@@ -165,7 +170,7 @@ func TestNetworks_Update_Err(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestNetworks_Delete_200(t *testing.T) {
|
func TestNetworks_Delete_200(t *testing.T) {
|
||||||
withMockClient(func(c *Client, mux *http.ServeMux) {
|
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||||
mux.HandleFunc("/api/networks/Test", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("/api/networks/Test", func(w http.ResponseWriter, r *http.Request) {
|
||||||
assert.Equal(t, "DELETE", r.Method)
|
assert.Equal(t, "DELETE", r.Method)
|
||||||
w.WriteHeader(200)
|
w.WriteHeader(200)
|
||||||
@@ -176,7 +181,7 @@ func TestNetworks_Delete_200(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestNetworks_Delete_Err(t *testing.T) {
|
func TestNetworks_Delete_Err(t *testing.T) {
|
||||||
withMockClient(func(c *Client, mux *http.ServeMux) {
|
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||||
mux.HandleFunc("/api/networks/Test", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("/api/networks/Test", func(w http.ResponseWriter, r *http.Request) {
|
||||||
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404})
|
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404})
|
||||||
w.WriteHeader(404)
|
w.WriteHeader(404)
|
||||||
@@ -190,7 +195,7 @@ func TestNetworks_Delete_Err(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestNetworks_Integration(t *testing.T) {
|
func TestNetworks_Integration(t *testing.T) {
|
||||||
withBlackBoxServer(t, func(c *Client) {
|
withBlackBoxServer(t, func(c *rest.Client) {
|
||||||
network, err := c.Networks.Create(context.Background(), api.NetworkRequest{
|
network, err := c.Networks.Create(context.Background(), api.NetworkRequest{
|
||||||
Description: ptr("TestNetwork"),
|
Description: ptr("TestNetwork"),
|
||||||
Name: "Test",
|
Name: "Test",
|
||||||
@@ -216,7 +221,7 @@ func TestNetworks_Integration(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestNetworkResources_List_200(t *testing.T) {
|
func TestNetworkResources_List_200(t *testing.T) {
|
||||||
withMockClient(func(c *Client, mux *http.ServeMux) {
|
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||||
mux.HandleFunc("/api/networks/Meow/resources", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("/api/networks/Meow/resources", func(w http.ResponseWriter, r *http.Request) {
|
||||||
retBytes, _ := json.Marshal([]api.NetworkResource{testNetworkResource})
|
retBytes, _ := json.Marshal([]api.NetworkResource{testNetworkResource})
|
||||||
_, err := w.Write(retBytes)
|
_, err := w.Write(retBytes)
|
||||||
@@ -230,7 +235,7 @@ func TestNetworkResources_List_200(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestNetworkResources_List_Err(t *testing.T) {
|
func TestNetworkResources_List_Err(t *testing.T) {
|
||||||
withMockClient(func(c *Client, mux *http.ServeMux) {
|
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||||
mux.HandleFunc("/api/networks/Meow/resources", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("/api/networks/Meow/resources", func(w http.ResponseWriter, r *http.Request) {
|
||||||
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
|
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
|
||||||
w.WriteHeader(400)
|
w.WriteHeader(400)
|
||||||
@@ -245,7 +250,7 @@ func TestNetworkResources_List_Err(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestNetworkResources_Get_200(t *testing.T) {
|
func TestNetworkResources_Get_200(t *testing.T) {
|
||||||
withMockClient(func(c *Client, mux *http.ServeMux) {
|
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||||
mux.HandleFunc("/api/networks/Meow/resources/Test", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("/api/networks/Meow/resources/Test", func(w http.ResponseWriter, r *http.Request) {
|
||||||
retBytes, _ := json.Marshal(testNetworkResource)
|
retBytes, _ := json.Marshal(testNetworkResource)
|
||||||
_, err := w.Write(retBytes)
|
_, err := w.Write(retBytes)
|
||||||
@@ -258,7 +263,7 @@ func TestNetworkResources_Get_200(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestNetworkResources_Get_Err(t *testing.T) {
|
func TestNetworkResources_Get_Err(t *testing.T) {
|
||||||
withMockClient(func(c *Client, mux *http.ServeMux) {
|
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||||
mux.HandleFunc("/api/networks/Meow/resources/Test", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("/api/networks/Meow/resources/Test", func(w http.ResponseWriter, r *http.Request) {
|
||||||
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
|
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
|
||||||
w.WriteHeader(400)
|
w.WriteHeader(400)
|
||||||
@@ -273,7 +278,7 @@ func TestNetworkResources_Get_Err(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestNetworkResources_Create_200(t *testing.T) {
|
func TestNetworkResources_Create_200(t *testing.T) {
|
||||||
withMockClient(func(c *Client, mux *http.ServeMux) {
|
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||||
mux.HandleFunc("/api/networks/Meow/resources", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("/api/networks/Meow/resources", func(w http.ResponseWriter, r *http.Request) {
|
||||||
assert.Equal(t, "POST", r.Method)
|
assert.Equal(t, "POST", r.Method)
|
||||||
reqBytes, err := io.ReadAll(r.Body)
|
reqBytes, err := io.ReadAll(r.Body)
|
||||||
@@ -295,7 +300,7 @@ func TestNetworkResources_Create_200(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestNetworkResources_Create_Err(t *testing.T) {
|
func TestNetworkResources_Create_Err(t *testing.T) {
|
||||||
withMockClient(func(c *Client, mux *http.ServeMux) {
|
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||||
mux.HandleFunc("/api/networks/Meow/resources", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("/api/networks/Meow/resources", func(w http.ResponseWriter, r *http.Request) {
|
||||||
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
|
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
|
||||||
w.WriteHeader(400)
|
w.WriteHeader(400)
|
||||||
@@ -312,7 +317,7 @@ func TestNetworkResources_Create_Err(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestNetworkResources_Update_200(t *testing.T) {
|
func TestNetworkResources_Update_200(t *testing.T) {
|
||||||
withMockClient(func(c *Client, mux *http.ServeMux) {
|
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||||
mux.HandleFunc("/api/networks/Meow/resources/Test", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("/api/networks/Meow/resources/Test", func(w http.ResponseWriter, r *http.Request) {
|
||||||
assert.Equal(t, "PUT", r.Method)
|
assert.Equal(t, "PUT", r.Method)
|
||||||
reqBytes, err := io.ReadAll(r.Body)
|
reqBytes, err := io.ReadAll(r.Body)
|
||||||
@@ -334,7 +339,7 @@ func TestNetworkResources_Update_200(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestNetworkResources_Update_Err(t *testing.T) {
|
func TestNetworkResources_Update_Err(t *testing.T) {
|
||||||
withMockClient(func(c *Client, mux *http.ServeMux) {
|
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||||
mux.HandleFunc("/api/networks/Meow/resources/Test", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("/api/networks/Meow/resources/Test", func(w http.ResponseWriter, r *http.Request) {
|
||||||
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
|
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
|
||||||
w.WriteHeader(400)
|
w.WriteHeader(400)
|
||||||
@@ -351,7 +356,7 @@ func TestNetworkResources_Update_Err(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestNetworkResources_Delete_200(t *testing.T) {
|
func TestNetworkResources_Delete_200(t *testing.T) {
|
||||||
withMockClient(func(c *Client, mux *http.ServeMux) {
|
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||||
mux.HandleFunc("/api/networks/Meow/resources/Test", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("/api/networks/Meow/resources/Test", func(w http.ResponseWriter, r *http.Request) {
|
||||||
assert.Equal(t, "DELETE", r.Method)
|
assert.Equal(t, "DELETE", r.Method)
|
||||||
w.WriteHeader(200)
|
w.WriteHeader(200)
|
||||||
@@ -362,7 +367,7 @@ func TestNetworkResources_Delete_200(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestNetworkResources_Delete_Err(t *testing.T) {
|
func TestNetworkResources_Delete_Err(t *testing.T) {
|
||||||
withMockClient(func(c *Client, mux *http.ServeMux) {
|
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||||
mux.HandleFunc("/api/networks/Meow/resources/Test", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("/api/networks/Meow/resources/Test", func(w http.ResponseWriter, r *http.Request) {
|
||||||
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404})
|
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404})
|
||||||
w.WriteHeader(404)
|
w.WriteHeader(404)
|
||||||
@@ -376,7 +381,7 @@ func TestNetworkResources_Delete_Err(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestNetworkResources_Integration(t *testing.T) {
|
func TestNetworkResources_Integration(t *testing.T) {
|
||||||
withBlackBoxServer(t, func(c *Client) {
|
withBlackBoxServer(t, func(c *rest.Client) {
|
||||||
_, err := c.Networks.Resources("TestNetwork").Create(context.Background(), api.NetworkResourceRequest{
|
_, err := c.Networks.Resources("TestNetwork").Create(context.Background(), api.NetworkResourceRequest{
|
||||||
Address: "test.com",
|
Address: "test.com",
|
||||||
Description: ptr("Description"),
|
Description: ptr("Description"),
|
||||||
@@ -403,7 +408,7 @@ func TestNetworkResources_Integration(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestNetworkRouters_List_200(t *testing.T) {
|
func TestNetworkRouters_List_200(t *testing.T) {
|
||||||
withMockClient(func(c *Client, mux *http.ServeMux) {
|
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||||
mux.HandleFunc("/api/networks/Meow/routers", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("/api/networks/Meow/routers", func(w http.ResponseWriter, r *http.Request) {
|
||||||
retBytes, _ := json.Marshal([]api.NetworkRouter{testNetworkRouter})
|
retBytes, _ := json.Marshal([]api.NetworkRouter{testNetworkRouter})
|
||||||
_, err := w.Write(retBytes)
|
_, err := w.Write(retBytes)
|
||||||
@@ -417,7 +422,7 @@ func TestNetworkRouters_List_200(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestNetworkRouters_List_Err(t *testing.T) {
|
func TestNetworkRouters_List_Err(t *testing.T) {
|
||||||
withMockClient(func(c *Client, mux *http.ServeMux) {
|
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||||
mux.HandleFunc("/api/networks/Meow/routers", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("/api/networks/Meow/routers", func(w http.ResponseWriter, r *http.Request) {
|
||||||
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
|
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
|
||||||
w.WriteHeader(400)
|
w.WriteHeader(400)
|
||||||
@@ -432,7 +437,7 @@ func TestNetworkRouters_List_Err(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestNetworkRouters_Get_200(t *testing.T) {
|
func TestNetworkRouters_Get_200(t *testing.T) {
|
||||||
withMockClient(func(c *Client, mux *http.ServeMux) {
|
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||||
mux.HandleFunc("/api/networks/Meow/routers/Test", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("/api/networks/Meow/routers/Test", func(w http.ResponseWriter, r *http.Request) {
|
||||||
retBytes, _ := json.Marshal(testNetworkRouter)
|
retBytes, _ := json.Marshal(testNetworkRouter)
|
||||||
_, err := w.Write(retBytes)
|
_, err := w.Write(retBytes)
|
||||||
@@ -445,7 +450,7 @@ func TestNetworkRouters_Get_200(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestNetworkRouters_Get_Err(t *testing.T) {
|
func TestNetworkRouters_Get_Err(t *testing.T) {
|
||||||
withMockClient(func(c *Client, mux *http.ServeMux) {
|
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||||
mux.HandleFunc("/api/networks/Meow/routers/Test", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("/api/networks/Meow/routers/Test", func(w http.ResponseWriter, r *http.Request) {
|
||||||
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
|
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
|
||||||
w.WriteHeader(400)
|
w.WriteHeader(400)
|
||||||
@@ -460,7 +465,7 @@ func TestNetworkRouters_Get_Err(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestNetworkRouters_Create_200(t *testing.T) {
|
func TestNetworkRouters_Create_200(t *testing.T) {
|
||||||
withMockClient(func(c *Client, mux *http.ServeMux) {
|
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||||
mux.HandleFunc("/api/networks/Meow/routers", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("/api/networks/Meow/routers", func(w http.ResponseWriter, r *http.Request) {
|
||||||
assert.Equal(t, "POST", r.Method)
|
assert.Equal(t, "POST", r.Method)
|
||||||
reqBytes, err := io.ReadAll(r.Body)
|
reqBytes, err := io.ReadAll(r.Body)
|
||||||
@@ -482,7 +487,7 @@ func TestNetworkRouters_Create_200(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestNetworkRouters_Create_Err(t *testing.T) {
|
func TestNetworkRouters_Create_Err(t *testing.T) {
|
||||||
withMockClient(func(c *Client, mux *http.ServeMux) {
|
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||||
mux.HandleFunc("/api/networks/Meow/routers", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("/api/networks/Meow/routers", func(w http.ResponseWriter, r *http.Request) {
|
||||||
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
|
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
|
||||||
w.WriteHeader(400)
|
w.WriteHeader(400)
|
||||||
@@ -499,7 +504,7 @@ func TestNetworkRouters_Create_Err(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestNetworkRouters_Update_200(t *testing.T) {
|
func TestNetworkRouters_Update_200(t *testing.T) {
|
||||||
withMockClient(func(c *Client, mux *http.ServeMux) {
|
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||||
mux.HandleFunc("/api/networks/Meow/routers/Test", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("/api/networks/Meow/routers/Test", func(w http.ResponseWriter, r *http.Request) {
|
||||||
assert.Equal(t, "PUT", r.Method)
|
assert.Equal(t, "PUT", r.Method)
|
||||||
reqBytes, err := io.ReadAll(r.Body)
|
reqBytes, err := io.ReadAll(r.Body)
|
||||||
@@ -521,7 +526,7 @@ func TestNetworkRouters_Update_200(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestNetworkRouters_Update_Err(t *testing.T) {
|
func TestNetworkRouters_Update_Err(t *testing.T) {
|
||||||
withMockClient(func(c *Client, mux *http.ServeMux) {
|
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||||
mux.HandleFunc("/api/networks/Meow/routers/Test", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("/api/networks/Meow/routers/Test", func(w http.ResponseWriter, r *http.Request) {
|
||||||
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
|
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
|
||||||
w.WriteHeader(400)
|
w.WriteHeader(400)
|
||||||
@@ -538,7 +543,7 @@ func TestNetworkRouters_Update_Err(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestNetworkRouters_Delete_200(t *testing.T) {
|
func TestNetworkRouters_Delete_200(t *testing.T) {
|
||||||
withMockClient(func(c *Client, mux *http.ServeMux) {
|
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||||
mux.HandleFunc("/api/networks/Meow/routers/Test", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("/api/networks/Meow/routers/Test", func(w http.ResponseWriter, r *http.Request) {
|
||||||
assert.Equal(t, "DELETE", r.Method)
|
assert.Equal(t, "DELETE", r.Method)
|
||||||
w.WriteHeader(200)
|
w.WriteHeader(200)
|
||||||
@@ -549,7 +554,7 @@ func TestNetworkRouters_Delete_200(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestNetworkRouters_Delete_Err(t *testing.T) {
|
func TestNetworkRouters_Delete_Err(t *testing.T) {
|
||||||
withMockClient(func(c *Client, mux *http.ServeMux) {
|
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||||
mux.HandleFunc("/api/networks/Meow/routers/Test", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("/api/networks/Meow/routers/Test", func(w http.ResponseWriter, r *http.Request) {
|
||||||
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404})
|
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404})
|
||||||
w.WriteHeader(404)
|
w.WriteHeader(404)
|
||||||
@@ -563,7 +568,7 @@ func TestNetworkRouters_Delete_Err(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestNetworkRouters_Integration(t *testing.T) {
|
func TestNetworkRouters_Integration(t *testing.T) {
|
||||||
withBlackBoxServer(t, func(c *Client) {
|
withBlackBoxServer(t, func(c *rest.Client) {
|
||||||
_, err := c.Networks.Routers("TestNetwork").Create(context.Background(), api.NetworkRouterRequest{
|
_, err := c.Networks.Routers("TestNetwork").Create(context.Background(), api.NetworkRouterRequest{
|
||||||
Enabled: false,
|
Enabled: false,
|
||||||
Masquerade: false,
|
Masquerade: false,
|
||||||
|
|||||||
@@ -1,4 +1,7 @@
|
|||||||
package rest
|
//go:build integration
|
||||||
|
// +build integration
|
||||||
|
|
||||||
|
package rest_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
@@ -7,10 +10,12 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/server/http/api"
|
|
||||||
"github.com/netbirdio/netbird/management/server/http/util"
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/management/client/rest"
|
||||||
|
"github.com/netbirdio/netbird/management/server/http/api"
|
||||||
|
"github.com/netbirdio/netbird/management/server/http/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@@ -24,7 +29,7 @@ var (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func TestPeers_List_200(t *testing.T) {
|
func TestPeers_List_200(t *testing.T) {
|
||||||
withMockClient(func(c *Client, mux *http.ServeMux) {
|
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||||
mux.HandleFunc("/api/peers", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("/api/peers", func(w http.ResponseWriter, r *http.Request) {
|
||||||
retBytes, _ := json.Marshal([]api.Peer{testPeer})
|
retBytes, _ := json.Marshal([]api.Peer{testPeer})
|
||||||
_, err := w.Write(retBytes)
|
_, err := w.Write(retBytes)
|
||||||
@@ -38,7 +43,7 @@ func TestPeers_List_200(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestPeers_List_Err(t *testing.T) {
|
func TestPeers_List_Err(t *testing.T) {
|
||||||
withMockClient(func(c *Client, mux *http.ServeMux) {
|
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||||
mux.HandleFunc("/api/peers", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("/api/peers", func(w http.ResponseWriter, r *http.Request) {
|
||||||
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
|
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
|
||||||
w.WriteHeader(400)
|
w.WriteHeader(400)
|
||||||
@@ -53,7 +58,7 @@ func TestPeers_List_Err(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestPeers_Get_200(t *testing.T) {
|
func TestPeers_Get_200(t *testing.T) {
|
||||||
withMockClient(func(c *Client, mux *http.ServeMux) {
|
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||||
mux.HandleFunc("/api/peers/Test", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("/api/peers/Test", func(w http.ResponseWriter, r *http.Request) {
|
||||||
retBytes, _ := json.Marshal(testPeer)
|
retBytes, _ := json.Marshal(testPeer)
|
||||||
_, err := w.Write(retBytes)
|
_, err := w.Write(retBytes)
|
||||||
@@ -66,7 +71,7 @@ func TestPeers_Get_200(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestPeers_Get_Err(t *testing.T) {
|
func TestPeers_Get_Err(t *testing.T) {
|
||||||
withMockClient(func(c *Client, mux *http.ServeMux) {
|
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||||
mux.HandleFunc("/api/peers/Test", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("/api/peers/Test", func(w http.ResponseWriter, r *http.Request) {
|
||||||
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
|
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
|
||||||
w.WriteHeader(400)
|
w.WriteHeader(400)
|
||||||
@@ -81,7 +86,7 @@ func TestPeers_Get_Err(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestPeers_Update_200(t *testing.T) {
|
func TestPeers_Update_200(t *testing.T) {
|
||||||
withMockClient(func(c *Client, mux *http.ServeMux) {
|
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||||
mux.HandleFunc("/api/peers/Test", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("/api/peers/Test", func(w http.ResponseWriter, r *http.Request) {
|
||||||
assert.Equal(t, "PUT", r.Method)
|
assert.Equal(t, "PUT", r.Method)
|
||||||
reqBytes, err := io.ReadAll(r.Body)
|
reqBytes, err := io.ReadAll(r.Body)
|
||||||
@@ -103,7 +108,7 @@ func TestPeers_Update_200(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestPeers_Update_Err(t *testing.T) {
|
func TestPeers_Update_Err(t *testing.T) {
|
||||||
withMockClient(func(c *Client, mux *http.ServeMux) {
|
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||||
mux.HandleFunc("/api/peers/Test", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("/api/peers/Test", func(w http.ResponseWriter, r *http.Request) {
|
||||||
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
|
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
|
||||||
w.WriteHeader(400)
|
w.WriteHeader(400)
|
||||||
@@ -120,7 +125,7 @@ func TestPeers_Update_Err(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestPeers_Delete_200(t *testing.T) {
|
func TestPeers_Delete_200(t *testing.T) {
|
||||||
withMockClient(func(c *Client, mux *http.ServeMux) {
|
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||||
mux.HandleFunc("/api/peers/Test", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("/api/peers/Test", func(w http.ResponseWriter, r *http.Request) {
|
||||||
assert.Equal(t, "DELETE", r.Method)
|
assert.Equal(t, "DELETE", r.Method)
|
||||||
w.WriteHeader(200)
|
w.WriteHeader(200)
|
||||||
@@ -131,7 +136,7 @@ func TestPeers_Delete_200(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestPeers_Delete_Err(t *testing.T) {
|
func TestPeers_Delete_Err(t *testing.T) {
|
||||||
withMockClient(func(c *Client, mux *http.ServeMux) {
|
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||||
mux.HandleFunc("/api/peers/Test", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("/api/peers/Test", func(w http.ResponseWriter, r *http.Request) {
|
||||||
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404})
|
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404})
|
||||||
w.WriteHeader(404)
|
w.WriteHeader(404)
|
||||||
@@ -145,7 +150,7 @@ func TestPeers_Delete_Err(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestPeers_ListAccessiblePeers_200(t *testing.T) {
|
func TestPeers_ListAccessiblePeers_200(t *testing.T) {
|
||||||
withMockClient(func(c *Client, mux *http.ServeMux) {
|
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||||
mux.HandleFunc("/api/peers/Test/accessible-peers", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("/api/peers/Test/accessible-peers", func(w http.ResponseWriter, r *http.Request) {
|
||||||
retBytes, _ := json.Marshal([]api.Peer{testPeer})
|
retBytes, _ := json.Marshal([]api.Peer{testPeer})
|
||||||
_, err := w.Write(retBytes)
|
_, err := w.Write(retBytes)
|
||||||
@@ -159,7 +164,7 @@ func TestPeers_ListAccessiblePeers_200(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestPeers_ListAccessiblePeers_Err(t *testing.T) {
|
func TestPeers_ListAccessiblePeers_Err(t *testing.T) {
|
||||||
withMockClient(func(c *Client, mux *http.ServeMux) {
|
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||||
mux.HandleFunc("/api/peers/Test/accessible-peers", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("/api/peers/Test/accessible-peers", func(w http.ResponseWriter, r *http.Request) {
|
||||||
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
|
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
|
||||||
w.WriteHeader(400)
|
w.WriteHeader(400)
|
||||||
@@ -174,7 +179,7 @@ func TestPeers_ListAccessiblePeers_Err(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestPeers_Integration(t *testing.T) {
|
func TestPeers_Integration(t *testing.T) {
|
||||||
withBlackBoxServer(t, func(c *Client) {
|
withBlackBoxServer(t, func(c *rest.Client) {
|
||||||
peers, err := c.Peers.List(context.Background())
|
peers, err := c.Peers.List(context.Background())
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.NotEmpty(t, peers)
|
require.NotEmpty(t, peers)
|
||||||
|
|||||||
@@ -1,4 +1,7 @@
|
|||||||
package rest
|
//go:build integration
|
||||||
|
// +build integration
|
||||||
|
|
||||||
|
package rest_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
@@ -7,10 +10,12 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/server/http/api"
|
|
||||||
"github.com/netbirdio/netbird/management/server/http/util"
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/management/client/rest"
|
||||||
|
"github.com/netbirdio/netbird/management/server/http/api"
|
||||||
|
"github.com/netbirdio/netbird/management/server/http/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@@ -22,7 +27,7 @@ var (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func TestPolicies_List_200(t *testing.T) {
|
func TestPolicies_List_200(t *testing.T) {
|
||||||
withMockClient(func(c *Client, mux *http.ServeMux) {
|
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||||
mux.HandleFunc("/api/policies", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("/api/policies", func(w http.ResponseWriter, r *http.Request) {
|
||||||
retBytes, _ := json.Marshal([]api.Policy{testPolicy})
|
retBytes, _ := json.Marshal([]api.Policy{testPolicy})
|
||||||
_, err := w.Write(retBytes)
|
_, err := w.Write(retBytes)
|
||||||
@@ -36,7 +41,7 @@ func TestPolicies_List_200(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestPolicies_List_Err(t *testing.T) {
|
func TestPolicies_List_Err(t *testing.T) {
|
||||||
withMockClient(func(c *Client, mux *http.ServeMux) {
|
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||||
mux.HandleFunc("/api/policies", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("/api/policies", func(w http.ResponseWriter, r *http.Request) {
|
||||||
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
|
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
|
||||||
w.WriteHeader(400)
|
w.WriteHeader(400)
|
||||||
@@ -51,7 +56,7 @@ func TestPolicies_List_Err(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestPolicies_Get_200(t *testing.T) {
|
func TestPolicies_Get_200(t *testing.T) {
|
||||||
withMockClient(func(c *Client, mux *http.ServeMux) {
|
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||||
mux.HandleFunc("/api/policies/Test", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("/api/policies/Test", func(w http.ResponseWriter, r *http.Request) {
|
||||||
retBytes, _ := json.Marshal(testPolicy)
|
retBytes, _ := json.Marshal(testPolicy)
|
||||||
_, err := w.Write(retBytes)
|
_, err := w.Write(retBytes)
|
||||||
@@ -64,7 +69,7 @@ func TestPolicies_Get_200(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestPolicies_Get_Err(t *testing.T) {
|
func TestPolicies_Get_Err(t *testing.T) {
|
||||||
withMockClient(func(c *Client, mux *http.ServeMux) {
|
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||||
mux.HandleFunc("/api/policies/Test", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("/api/policies/Test", func(w http.ResponseWriter, r *http.Request) {
|
||||||
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
|
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
|
||||||
w.WriteHeader(400)
|
w.WriteHeader(400)
|
||||||
@@ -79,7 +84,7 @@ func TestPolicies_Get_Err(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestPolicies_Create_200(t *testing.T) {
|
func TestPolicies_Create_200(t *testing.T) {
|
||||||
withMockClient(func(c *Client, mux *http.ServeMux) {
|
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||||
mux.HandleFunc("/api/policies", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("/api/policies", func(w http.ResponseWriter, r *http.Request) {
|
||||||
assert.Equal(t, "POST", r.Method)
|
assert.Equal(t, "POST", r.Method)
|
||||||
reqBytes, err := io.ReadAll(r.Body)
|
reqBytes, err := io.ReadAll(r.Body)
|
||||||
@@ -101,7 +106,7 @@ func TestPolicies_Create_200(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestPolicies_Create_Err(t *testing.T) {
|
func TestPolicies_Create_Err(t *testing.T) {
|
||||||
withMockClient(func(c *Client, mux *http.ServeMux) {
|
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||||
mux.HandleFunc("/api/policies", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("/api/policies", func(w http.ResponseWriter, r *http.Request) {
|
||||||
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
|
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
|
||||||
w.WriteHeader(400)
|
w.WriteHeader(400)
|
||||||
@@ -118,7 +123,7 @@ func TestPolicies_Create_Err(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestPolicies_Update_200(t *testing.T) {
|
func TestPolicies_Update_200(t *testing.T) {
|
||||||
withMockClient(func(c *Client, mux *http.ServeMux) {
|
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||||
mux.HandleFunc("/api/policies/Test", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("/api/policies/Test", func(w http.ResponseWriter, r *http.Request) {
|
||||||
assert.Equal(t, "PUT", r.Method)
|
assert.Equal(t, "PUT", r.Method)
|
||||||
reqBytes, err := io.ReadAll(r.Body)
|
reqBytes, err := io.ReadAll(r.Body)
|
||||||
@@ -140,7 +145,7 @@ func TestPolicies_Update_200(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestPolicies_Update_Err(t *testing.T) {
|
func TestPolicies_Update_Err(t *testing.T) {
|
||||||
withMockClient(func(c *Client, mux *http.ServeMux) {
|
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||||
mux.HandleFunc("/api/policies/Test", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("/api/policies/Test", func(w http.ResponseWriter, r *http.Request) {
|
||||||
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
|
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
|
||||||
w.WriteHeader(400)
|
w.WriteHeader(400)
|
||||||
@@ -157,7 +162,7 @@ func TestPolicies_Update_Err(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestPolicies_Delete_200(t *testing.T) {
|
func TestPolicies_Delete_200(t *testing.T) {
|
||||||
withMockClient(func(c *Client, mux *http.ServeMux) {
|
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||||
mux.HandleFunc("/api/policies/Test", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("/api/policies/Test", func(w http.ResponseWriter, r *http.Request) {
|
||||||
assert.Equal(t, "DELETE", r.Method)
|
assert.Equal(t, "DELETE", r.Method)
|
||||||
w.WriteHeader(200)
|
w.WriteHeader(200)
|
||||||
@@ -168,7 +173,7 @@ func TestPolicies_Delete_200(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestPolicies_Delete_Err(t *testing.T) {
|
func TestPolicies_Delete_Err(t *testing.T) {
|
||||||
withMockClient(func(c *Client, mux *http.ServeMux) {
|
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||||
mux.HandleFunc("/api/policies/Test", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("/api/policies/Test", func(w http.ResponseWriter, r *http.Request) {
|
||||||
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404})
|
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404})
|
||||||
w.WriteHeader(404)
|
w.WriteHeader(404)
|
||||||
@@ -182,7 +187,7 @@ func TestPolicies_Delete_Err(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestPolicies_Integration(t *testing.T) {
|
func TestPolicies_Integration(t *testing.T) {
|
||||||
withBlackBoxServer(t, func(c *Client) {
|
withBlackBoxServer(t, func(c *rest.Client) {
|
||||||
policies, err := c.Policies.List(context.Background())
|
policies, err := c.Policies.List(context.Background())
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.NotEmpty(t, policies)
|
require.NotEmpty(t, policies)
|
||||||
|
|||||||
@@ -1,4 +1,7 @@
|
|||||||
package rest
|
//go:build integration
|
||||||
|
// +build integration
|
||||||
|
|
||||||
|
package rest_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
@@ -7,10 +10,12 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/server/http/api"
|
|
||||||
"github.com/netbirdio/netbird/management/server/http/util"
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/management/client/rest"
|
||||||
|
"github.com/netbirdio/netbird/management/server/http/api"
|
||||||
|
"github.com/netbirdio/netbird/management/server/http/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@@ -21,7 +26,7 @@ var (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func TestPostureChecks_List_200(t *testing.T) {
|
func TestPostureChecks_List_200(t *testing.T) {
|
||||||
withMockClient(func(c *Client, mux *http.ServeMux) {
|
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||||
mux.HandleFunc("/api/posture-checks", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("/api/posture-checks", func(w http.ResponseWriter, r *http.Request) {
|
||||||
retBytes, _ := json.Marshal([]api.PostureCheck{testPostureCheck})
|
retBytes, _ := json.Marshal([]api.PostureCheck{testPostureCheck})
|
||||||
_, err := w.Write(retBytes)
|
_, err := w.Write(retBytes)
|
||||||
@@ -35,7 +40,7 @@ func TestPostureChecks_List_200(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestPostureChecks_List_Err(t *testing.T) {
|
func TestPostureChecks_List_Err(t *testing.T) {
|
||||||
withMockClient(func(c *Client, mux *http.ServeMux) {
|
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||||
mux.HandleFunc("/api/posture-checks", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("/api/posture-checks", func(w http.ResponseWriter, r *http.Request) {
|
||||||
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
|
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
|
||||||
w.WriteHeader(400)
|
w.WriteHeader(400)
|
||||||
@@ -50,7 +55,7 @@ func TestPostureChecks_List_Err(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestPostureChecks_Get_200(t *testing.T) {
|
func TestPostureChecks_Get_200(t *testing.T) {
|
||||||
withMockClient(func(c *Client, mux *http.ServeMux) {
|
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||||
mux.HandleFunc("/api/posture-checks/Test", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("/api/posture-checks/Test", func(w http.ResponseWriter, r *http.Request) {
|
||||||
retBytes, _ := json.Marshal(testPostureCheck)
|
retBytes, _ := json.Marshal(testPostureCheck)
|
||||||
_, err := w.Write(retBytes)
|
_, err := w.Write(retBytes)
|
||||||
@@ -63,7 +68,7 @@ func TestPostureChecks_Get_200(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestPostureChecks_Get_Err(t *testing.T) {
|
func TestPostureChecks_Get_Err(t *testing.T) {
|
||||||
withMockClient(func(c *Client, mux *http.ServeMux) {
|
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||||
mux.HandleFunc("/api/posture-checks/Test", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("/api/posture-checks/Test", func(w http.ResponseWriter, r *http.Request) {
|
||||||
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
|
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
|
||||||
w.WriteHeader(400)
|
w.WriteHeader(400)
|
||||||
@@ -78,7 +83,7 @@ func TestPostureChecks_Get_Err(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestPostureChecks_Create_200(t *testing.T) {
|
func TestPostureChecks_Create_200(t *testing.T) {
|
||||||
withMockClient(func(c *Client, mux *http.ServeMux) {
|
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||||
mux.HandleFunc("/api/posture-checks", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("/api/posture-checks", func(w http.ResponseWriter, r *http.Request) {
|
||||||
assert.Equal(t, "POST", r.Method)
|
assert.Equal(t, "POST", r.Method)
|
||||||
reqBytes, err := io.ReadAll(r.Body)
|
reqBytes, err := io.ReadAll(r.Body)
|
||||||
@@ -100,7 +105,7 @@ func TestPostureChecks_Create_200(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestPostureChecks_Create_Err(t *testing.T) {
|
func TestPostureChecks_Create_Err(t *testing.T) {
|
||||||
withMockClient(func(c *Client, mux *http.ServeMux) {
|
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||||
mux.HandleFunc("/api/posture-checks", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("/api/posture-checks", func(w http.ResponseWriter, r *http.Request) {
|
||||||
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
|
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
|
||||||
w.WriteHeader(400)
|
w.WriteHeader(400)
|
||||||
@@ -117,7 +122,7 @@ func TestPostureChecks_Create_Err(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestPostureChecks_Update_200(t *testing.T) {
|
func TestPostureChecks_Update_200(t *testing.T) {
|
||||||
withMockClient(func(c *Client, mux *http.ServeMux) {
|
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||||
mux.HandleFunc("/api/posture-checks/Test", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("/api/posture-checks/Test", func(w http.ResponseWriter, r *http.Request) {
|
||||||
assert.Equal(t, "PUT", r.Method)
|
assert.Equal(t, "PUT", r.Method)
|
||||||
reqBytes, err := io.ReadAll(r.Body)
|
reqBytes, err := io.ReadAll(r.Body)
|
||||||
@@ -139,7 +144,7 @@ func TestPostureChecks_Update_200(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestPostureChecks_Update_Err(t *testing.T) {
|
func TestPostureChecks_Update_Err(t *testing.T) {
|
||||||
withMockClient(func(c *Client, mux *http.ServeMux) {
|
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||||
mux.HandleFunc("/api/posture-checks/Test", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("/api/posture-checks/Test", func(w http.ResponseWriter, r *http.Request) {
|
||||||
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
|
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
|
||||||
w.WriteHeader(400)
|
w.WriteHeader(400)
|
||||||
@@ -156,7 +161,7 @@ func TestPostureChecks_Update_Err(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestPostureChecks_Delete_200(t *testing.T) {
|
func TestPostureChecks_Delete_200(t *testing.T) {
|
||||||
withMockClient(func(c *Client, mux *http.ServeMux) {
|
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||||
mux.HandleFunc("/api/posture-checks/Test", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("/api/posture-checks/Test", func(w http.ResponseWriter, r *http.Request) {
|
||||||
assert.Equal(t, "DELETE", r.Method)
|
assert.Equal(t, "DELETE", r.Method)
|
||||||
w.WriteHeader(200)
|
w.WriteHeader(200)
|
||||||
@@ -167,7 +172,7 @@ func TestPostureChecks_Delete_200(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestPostureChecks_Delete_Err(t *testing.T) {
|
func TestPostureChecks_Delete_Err(t *testing.T) {
|
||||||
withMockClient(func(c *Client, mux *http.ServeMux) {
|
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||||
mux.HandleFunc("/api/posture-checks/Test", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("/api/posture-checks/Test", func(w http.ResponseWriter, r *http.Request) {
|
||||||
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404})
|
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404})
|
||||||
w.WriteHeader(404)
|
w.WriteHeader(404)
|
||||||
@@ -181,7 +186,7 @@ func TestPostureChecks_Delete_Err(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestPostureChecks_Integration(t *testing.T) {
|
func TestPostureChecks_Integration(t *testing.T) {
|
||||||
withBlackBoxServer(t, func(c *Client) {
|
withBlackBoxServer(t, func(c *rest.Client) {
|
||||||
check, err := c.PostureChecks.Create(context.Background(), api.PostureCheckUpdate{
|
check, err := c.PostureChecks.Create(context.Background(), api.PostureCheckUpdate{
|
||||||
Name: "Test",
|
Name: "Test",
|
||||||
Description: "Testing",
|
Description: "Testing",
|
||||||
|
|||||||
@@ -1,4 +1,7 @@
|
|||||||
package rest
|
//go:build integration
|
||||||
|
// +build integration
|
||||||
|
|
||||||
|
package rest_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
@@ -7,10 +10,12 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/server/http/api"
|
|
||||||
"github.com/netbirdio/netbird/management/server/http/util"
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/management/client/rest"
|
||||||
|
"github.com/netbirdio/netbird/management/server/http/api"
|
||||||
|
"github.com/netbirdio/netbird/management/server/http/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@@ -21,7 +26,7 @@ var (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func TestRoutes_List_200(t *testing.T) {
|
func TestRoutes_List_200(t *testing.T) {
|
||||||
withMockClient(func(c *Client, mux *http.ServeMux) {
|
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||||
mux.HandleFunc("/api/routes", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("/api/routes", func(w http.ResponseWriter, r *http.Request) {
|
||||||
retBytes, _ := json.Marshal([]api.Route{testRoute})
|
retBytes, _ := json.Marshal([]api.Route{testRoute})
|
||||||
_, err := w.Write(retBytes)
|
_, err := w.Write(retBytes)
|
||||||
@@ -35,7 +40,7 @@ func TestRoutes_List_200(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestRoutes_List_Err(t *testing.T) {
|
func TestRoutes_List_Err(t *testing.T) {
|
||||||
withMockClient(func(c *Client, mux *http.ServeMux) {
|
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||||
mux.HandleFunc("/api/routes", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("/api/routes", func(w http.ResponseWriter, r *http.Request) {
|
||||||
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
|
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
|
||||||
w.WriteHeader(400)
|
w.WriteHeader(400)
|
||||||
@@ -50,7 +55,7 @@ func TestRoutes_List_Err(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestRoutes_Get_200(t *testing.T) {
|
func TestRoutes_Get_200(t *testing.T) {
|
||||||
withMockClient(func(c *Client, mux *http.ServeMux) {
|
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||||
mux.HandleFunc("/api/routes/Test", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("/api/routes/Test", func(w http.ResponseWriter, r *http.Request) {
|
||||||
retBytes, _ := json.Marshal(testRoute)
|
retBytes, _ := json.Marshal(testRoute)
|
||||||
_, err := w.Write(retBytes)
|
_, err := w.Write(retBytes)
|
||||||
@@ -63,7 +68,7 @@ func TestRoutes_Get_200(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestRoutes_Get_Err(t *testing.T) {
|
func TestRoutes_Get_Err(t *testing.T) {
|
||||||
withMockClient(func(c *Client, mux *http.ServeMux) {
|
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||||
mux.HandleFunc("/api/routes/Test", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("/api/routes/Test", func(w http.ResponseWriter, r *http.Request) {
|
||||||
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
|
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
|
||||||
w.WriteHeader(400)
|
w.WriteHeader(400)
|
||||||
@@ -78,7 +83,7 @@ func TestRoutes_Get_Err(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestRoutes_Create_200(t *testing.T) {
|
func TestRoutes_Create_200(t *testing.T) {
|
||||||
withMockClient(func(c *Client, mux *http.ServeMux) {
|
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||||
mux.HandleFunc("/api/routes", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("/api/routes", func(w http.ResponseWriter, r *http.Request) {
|
||||||
assert.Equal(t, "POST", r.Method)
|
assert.Equal(t, "POST", r.Method)
|
||||||
reqBytes, err := io.ReadAll(r.Body)
|
reqBytes, err := io.ReadAll(r.Body)
|
||||||
@@ -100,7 +105,7 @@ func TestRoutes_Create_200(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestRoutes_Create_Err(t *testing.T) {
|
func TestRoutes_Create_Err(t *testing.T) {
|
||||||
withMockClient(func(c *Client, mux *http.ServeMux) {
|
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||||
mux.HandleFunc("/api/routes", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("/api/routes", func(w http.ResponseWriter, r *http.Request) {
|
||||||
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
|
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
|
||||||
w.WriteHeader(400)
|
w.WriteHeader(400)
|
||||||
@@ -117,7 +122,7 @@ func TestRoutes_Create_Err(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestRoutes_Update_200(t *testing.T) {
|
func TestRoutes_Update_200(t *testing.T) {
|
||||||
withMockClient(func(c *Client, mux *http.ServeMux) {
|
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||||
mux.HandleFunc("/api/routes/Test", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("/api/routes/Test", func(w http.ResponseWriter, r *http.Request) {
|
||||||
assert.Equal(t, "PUT", r.Method)
|
assert.Equal(t, "PUT", r.Method)
|
||||||
reqBytes, err := io.ReadAll(r.Body)
|
reqBytes, err := io.ReadAll(r.Body)
|
||||||
@@ -139,7 +144,7 @@ func TestRoutes_Update_200(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestRoutes_Update_Err(t *testing.T) {
|
func TestRoutes_Update_Err(t *testing.T) {
|
||||||
withMockClient(func(c *Client, mux *http.ServeMux) {
|
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||||
mux.HandleFunc("/api/routes/Test", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("/api/routes/Test", func(w http.ResponseWriter, r *http.Request) {
|
||||||
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
|
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
|
||||||
w.WriteHeader(400)
|
w.WriteHeader(400)
|
||||||
@@ -156,7 +161,7 @@ func TestRoutes_Update_Err(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestRoutes_Delete_200(t *testing.T) {
|
func TestRoutes_Delete_200(t *testing.T) {
|
||||||
withMockClient(func(c *Client, mux *http.ServeMux) {
|
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||||
mux.HandleFunc("/api/routes/Test", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("/api/routes/Test", func(w http.ResponseWriter, r *http.Request) {
|
||||||
assert.Equal(t, "DELETE", r.Method)
|
assert.Equal(t, "DELETE", r.Method)
|
||||||
w.WriteHeader(200)
|
w.WriteHeader(200)
|
||||||
@@ -167,7 +172,7 @@ func TestRoutes_Delete_200(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestRoutes_Delete_Err(t *testing.T) {
|
func TestRoutes_Delete_Err(t *testing.T) {
|
||||||
withMockClient(func(c *Client, mux *http.ServeMux) {
|
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||||
mux.HandleFunc("/api/routes/Test", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("/api/routes/Test", func(w http.ResponseWriter, r *http.Request) {
|
||||||
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404})
|
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404})
|
||||||
w.WriteHeader(404)
|
w.WriteHeader(404)
|
||||||
@@ -181,7 +186,7 @@ func TestRoutes_Delete_Err(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestRoutes_Integration(t *testing.T) {
|
func TestRoutes_Integration(t *testing.T) {
|
||||||
withBlackBoxServer(t, func(c *Client) {
|
withBlackBoxServer(t, func(c *rest.Client) {
|
||||||
route, err := c.Routes.Create(context.Background(), api.RouteRequest{
|
route, err := c.Routes.Create(context.Background(), api.RouteRequest{
|
||||||
Description: "Meow",
|
Description: "Meow",
|
||||||
Enabled: false,
|
Enabled: false,
|
||||||
|
|||||||
@@ -1,4 +1,7 @@
|
|||||||
package rest
|
//go:build integration
|
||||||
|
// +build integration
|
||||||
|
|
||||||
|
package rest_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
@@ -7,10 +10,12 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/server/http/api"
|
|
||||||
"github.com/netbirdio/netbird/management/server/http/util"
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/management/client/rest"
|
||||||
|
"github.com/netbirdio/netbird/management/server/http/api"
|
||||||
|
"github.com/netbirdio/netbird/management/server/http/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@@ -31,7 +36,7 @@ var (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func TestSetupKeys_List_200(t *testing.T) {
|
func TestSetupKeys_List_200(t *testing.T) {
|
||||||
withMockClient(func(c *Client, mux *http.ServeMux) {
|
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||||
mux.HandleFunc("/api/setup-keys", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("/api/setup-keys", func(w http.ResponseWriter, r *http.Request) {
|
||||||
retBytes, _ := json.Marshal([]api.SetupKey{testSetupKey})
|
retBytes, _ := json.Marshal([]api.SetupKey{testSetupKey})
|
||||||
_, err := w.Write(retBytes)
|
_, err := w.Write(retBytes)
|
||||||
@@ -45,7 +50,7 @@ func TestSetupKeys_List_200(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestSetupKeys_List_Err(t *testing.T) {
|
func TestSetupKeys_List_Err(t *testing.T) {
|
||||||
withMockClient(func(c *Client, mux *http.ServeMux) {
|
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||||
mux.HandleFunc("/api/setup-keys", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("/api/setup-keys", func(w http.ResponseWriter, r *http.Request) {
|
||||||
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
|
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
|
||||||
w.WriteHeader(400)
|
w.WriteHeader(400)
|
||||||
@@ -60,7 +65,7 @@ func TestSetupKeys_List_Err(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestSetupKeys_Get_200(t *testing.T) {
|
func TestSetupKeys_Get_200(t *testing.T) {
|
||||||
withMockClient(func(c *Client, mux *http.ServeMux) {
|
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||||
mux.HandleFunc("/api/setup-keys/Test", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("/api/setup-keys/Test", func(w http.ResponseWriter, r *http.Request) {
|
||||||
retBytes, _ := json.Marshal(testSetupKey)
|
retBytes, _ := json.Marshal(testSetupKey)
|
||||||
_, err := w.Write(retBytes)
|
_, err := w.Write(retBytes)
|
||||||
@@ -73,7 +78,7 @@ func TestSetupKeys_Get_200(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestSetupKeys_Get_Err(t *testing.T) {
|
func TestSetupKeys_Get_Err(t *testing.T) {
|
||||||
withMockClient(func(c *Client, mux *http.ServeMux) {
|
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||||
mux.HandleFunc("/api/setup-keys/Test", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("/api/setup-keys/Test", func(w http.ResponseWriter, r *http.Request) {
|
||||||
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
|
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
|
||||||
w.WriteHeader(400)
|
w.WriteHeader(400)
|
||||||
@@ -88,7 +93,7 @@ func TestSetupKeys_Get_Err(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestSetupKeys_Create_200(t *testing.T) {
|
func TestSetupKeys_Create_200(t *testing.T) {
|
||||||
withMockClient(func(c *Client, mux *http.ServeMux) {
|
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||||
mux.HandleFunc("/api/setup-keys", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("/api/setup-keys", func(w http.ResponseWriter, r *http.Request) {
|
||||||
assert.Equal(t, "POST", r.Method)
|
assert.Equal(t, "POST", r.Method)
|
||||||
reqBytes, err := io.ReadAll(r.Body)
|
reqBytes, err := io.ReadAll(r.Body)
|
||||||
@@ -110,7 +115,7 @@ func TestSetupKeys_Create_200(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestSetupKeys_Create_Err(t *testing.T) {
|
func TestSetupKeys_Create_Err(t *testing.T) {
|
||||||
withMockClient(func(c *Client, mux *http.ServeMux) {
|
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||||
mux.HandleFunc("/api/setup-keys", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("/api/setup-keys", func(w http.ResponseWriter, r *http.Request) {
|
||||||
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
|
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
|
||||||
w.WriteHeader(400)
|
w.WriteHeader(400)
|
||||||
@@ -127,7 +132,7 @@ func TestSetupKeys_Create_Err(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestSetupKeys_Update_200(t *testing.T) {
|
func TestSetupKeys_Update_200(t *testing.T) {
|
||||||
withMockClient(func(c *Client, mux *http.ServeMux) {
|
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||||
mux.HandleFunc("/api/setup-keys/Test", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("/api/setup-keys/Test", func(w http.ResponseWriter, r *http.Request) {
|
||||||
assert.Equal(t, "PUT", r.Method)
|
assert.Equal(t, "PUT", r.Method)
|
||||||
reqBytes, err := io.ReadAll(r.Body)
|
reqBytes, err := io.ReadAll(r.Body)
|
||||||
@@ -149,7 +154,7 @@ func TestSetupKeys_Update_200(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestSetupKeys_Update_Err(t *testing.T) {
|
func TestSetupKeys_Update_Err(t *testing.T) {
|
||||||
withMockClient(func(c *Client, mux *http.ServeMux) {
|
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||||
mux.HandleFunc("/api/setup-keys/Test", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("/api/setup-keys/Test", func(w http.ResponseWriter, r *http.Request) {
|
||||||
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
|
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
|
||||||
w.WriteHeader(400)
|
w.WriteHeader(400)
|
||||||
@@ -166,7 +171,7 @@ func TestSetupKeys_Update_Err(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestSetupKeys_Delete_200(t *testing.T) {
|
func TestSetupKeys_Delete_200(t *testing.T) {
|
||||||
withMockClient(func(c *Client, mux *http.ServeMux) {
|
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||||
mux.HandleFunc("/api/setup-keys/Test", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("/api/setup-keys/Test", func(w http.ResponseWriter, r *http.Request) {
|
||||||
assert.Equal(t, "DELETE", r.Method)
|
assert.Equal(t, "DELETE", r.Method)
|
||||||
w.WriteHeader(200)
|
w.WriteHeader(200)
|
||||||
@@ -177,7 +182,7 @@ func TestSetupKeys_Delete_200(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestSetupKeys_Delete_Err(t *testing.T) {
|
func TestSetupKeys_Delete_Err(t *testing.T) {
|
||||||
withMockClient(func(c *Client, mux *http.ServeMux) {
|
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||||
mux.HandleFunc("/api/setup-keys/Test", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("/api/setup-keys/Test", func(w http.ResponseWriter, r *http.Request) {
|
||||||
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404})
|
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404})
|
||||||
w.WriteHeader(404)
|
w.WriteHeader(404)
|
||||||
@@ -191,7 +196,7 @@ func TestSetupKeys_Delete_Err(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestSetupKeys_Integration(t *testing.T) {
|
func TestSetupKeys_Integration(t *testing.T) {
|
||||||
withBlackBoxServer(t, func(c *Client) {
|
withBlackBoxServer(t, func(c *rest.Client) {
|
||||||
group, err := c.Groups.Create(context.Background(), api.GroupRequest{
|
group, err := c.Groups.Create(context.Background(), api.GroupRequest{
|
||||||
Name: "Test",
|
Name: "Test",
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -1,4 +1,7 @@
|
|||||||
package rest
|
//go:build integration
|
||||||
|
// +build integration
|
||||||
|
|
||||||
|
package rest_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
@@ -8,10 +11,12 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/server/http/api"
|
|
||||||
"github.com/netbirdio/netbird/management/server/http/util"
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/management/client/rest"
|
||||||
|
"github.com/netbirdio/netbird/management/server/http/api"
|
||||||
|
"github.com/netbirdio/netbird/management/server/http/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@@ -31,7 +36,7 @@ var (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func TestTokens_List_200(t *testing.T) {
|
func TestTokens_List_200(t *testing.T) {
|
||||||
withMockClient(func(c *Client, mux *http.ServeMux) {
|
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||||
mux.HandleFunc("/api/users/meow/tokens", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("/api/users/meow/tokens", func(w http.ResponseWriter, r *http.Request) {
|
||||||
retBytes, _ := json.Marshal([]api.PersonalAccessToken{testToken})
|
retBytes, _ := json.Marshal([]api.PersonalAccessToken{testToken})
|
||||||
_, err := w.Write(retBytes)
|
_, err := w.Write(retBytes)
|
||||||
@@ -45,7 +50,7 @@ func TestTokens_List_200(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestTokens_List_Err(t *testing.T) {
|
func TestTokens_List_Err(t *testing.T) {
|
||||||
withMockClient(func(c *Client, mux *http.ServeMux) {
|
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||||
mux.HandleFunc("/api/users/meow/tokens", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("/api/users/meow/tokens", func(w http.ResponseWriter, r *http.Request) {
|
||||||
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
|
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
|
||||||
w.WriteHeader(400)
|
w.WriteHeader(400)
|
||||||
@@ -60,7 +65,7 @@ func TestTokens_List_Err(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestTokens_Get_200(t *testing.T) {
|
func TestTokens_Get_200(t *testing.T) {
|
||||||
withMockClient(func(c *Client, mux *http.ServeMux) {
|
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||||
mux.HandleFunc("/api/users/meow/tokens/Test", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("/api/users/meow/tokens/Test", func(w http.ResponseWriter, r *http.Request) {
|
||||||
retBytes, _ := json.Marshal(testToken)
|
retBytes, _ := json.Marshal(testToken)
|
||||||
_, err := w.Write(retBytes)
|
_, err := w.Write(retBytes)
|
||||||
@@ -73,7 +78,7 @@ func TestTokens_Get_200(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestTokens_Get_Err(t *testing.T) {
|
func TestTokens_Get_Err(t *testing.T) {
|
||||||
withMockClient(func(c *Client, mux *http.ServeMux) {
|
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||||
mux.HandleFunc("/api/users/meow/tokens/Test", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("/api/users/meow/tokens/Test", func(w http.ResponseWriter, r *http.Request) {
|
||||||
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
|
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
|
||||||
w.WriteHeader(400)
|
w.WriteHeader(400)
|
||||||
@@ -88,7 +93,7 @@ func TestTokens_Get_Err(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestTokens_Create_200(t *testing.T) {
|
func TestTokens_Create_200(t *testing.T) {
|
||||||
withMockClient(func(c *Client, mux *http.ServeMux) {
|
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||||
mux.HandleFunc("/api/users/meow/tokens", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("/api/users/meow/tokens", func(w http.ResponseWriter, r *http.Request) {
|
||||||
assert.Equal(t, "POST", r.Method)
|
assert.Equal(t, "POST", r.Method)
|
||||||
reqBytes, err := io.ReadAll(r.Body)
|
reqBytes, err := io.ReadAll(r.Body)
|
||||||
@@ -110,7 +115,7 @@ func TestTokens_Create_200(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestTokens_Create_Err(t *testing.T) {
|
func TestTokens_Create_Err(t *testing.T) {
|
||||||
withMockClient(func(c *Client, mux *http.ServeMux) {
|
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||||
mux.HandleFunc("/api/users/meow/tokens", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("/api/users/meow/tokens", func(w http.ResponseWriter, r *http.Request) {
|
||||||
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
|
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
|
||||||
w.WriteHeader(400)
|
w.WriteHeader(400)
|
||||||
@@ -127,7 +132,7 @@ func TestTokens_Create_Err(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestTokens_Delete_200(t *testing.T) {
|
func TestTokens_Delete_200(t *testing.T) {
|
||||||
withMockClient(func(c *Client, mux *http.ServeMux) {
|
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||||
mux.HandleFunc("/api/users/meow/tokens/Test", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("/api/users/meow/tokens/Test", func(w http.ResponseWriter, r *http.Request) {
|
||||||
assert.Equal(t, "DELETE", r.Method)
|
assert.Equal(t, "DELETE", r.Method)
|
||||||
w.WriteHeader(200)
|
w.WriteHeader(200)
|
||||||
@@ -138,7 +143,7 @@ func TestTokens_Delete_200(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestTokens_Delete_Err(t *testing.T) {
|
func TestTokens_Delete_Err(t *testing.T) {
|
||||||
withMockClient(func(c *Client, mux *http.ServeMux) {
|
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||||
mux.HandleFunc("/api/users/meow/tokens/Test", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("/api/users/meow/tokens/Test", func(w http.ResponseWriter, r *http.Request) {
|
||||||
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404})
|
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404})
|
||||||
w.WriteHeader(404)
|
w.WriteHeader(404)
|
||||||
@@ -152,7 +157,7 @@ func TestTokens_Delete_Err(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestTokens_Integration(t *testing.T) {
|
func TestTokens_Integration(t *testing.T) {
|
||||||
withBlackBoxServer(t, func(c *Client) {
|
withBlackBoxServer(t, func(c *rest.Client) {
|
||||||
tokenClear, err := c.Tokens.Create(context.Background(), "a23efe53-63fb-11ec-90d6-0242ac120003", api.PersonalAccessTokenRequest{
|
tokenClear, err := c.Tokens.Create(context.Background(), "a23efe53-63fb-11ec-90d6-0242ac120003", api.PersonalAccessTokenRequest{
|
||||||
Name: "Test",
|
Name: "Test",
|
||||||
ExpiresIn: 365,
|
ExpiresIn: 365,
|
||||||
|
|||||||
@@ -1,4 +1,7 @@
|
|||||||
package rest
|
//go:build integration
|
||||||
|
// +build integration
|
||||||
|
|
||||||
|
package rest_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
@@ -8,10 +11,12 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/server/http/api"
|
|
||||||
"github.com/netbirdio/netbird/management/server/http/util"
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/management/client/rest"
|
||||||
|
"github.com/netbirdio/netbird/management/server/http/api"
|
||||||
|
"github.com/netbirdio/netbird/management/server/http/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@@ -34,7 +39,7 @@ var (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func TestUsers_List_200(t *testing.T) {
|
func TestUsers_List_200(t *testing.T) {
|
||||||
withMockClient(func(c *Client, mux *http.ServeMux) {
|
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||||
mux.HandleFunc("/api/users", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("/api/users", func(w http.ResponseWriter, r *http.Request) {
|
||||||
retBytes, _ := json.Marshal([]api.User{testUser})
|
retBytes, _ := json.Marshal([]api.User{testUser})
|
||||||
_, err := w.Write(retBytes)
|
_, err := w.Write(retBytes)
|
||||||
@@ -48,7 +53,7 @@ func TestUsers_List_200(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestUsers_List_Err(t *testing.T) {
|
func TestUsers_List_Err(t *testing.T) {
|
||||||
withMockClient(func(c *Client, mux *http.ServeMux) {
|
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||||
mux.HandleFunc("/api/users", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("/api/users", func(w http.ResponseWriter, r *http.Request) {
|
||||||
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
|
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
|
||||||
w.WriteHeader(400)
|
w.WriteHeader(400)
|
||||||
@@ -63,7 +68,7 @@ func TestUsers_List_Err(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestUsers_Create_200(t *testing.T) {
|
func TestUsers_Create_200(t *testing.T) {
|
||||||
withMockClient(func(c *Client, mux *http.ServeMux) {
|
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||||
mux.HandleFunc("/api/users", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("/api/users", func(w http.ResponseWriter, r *http.Request) {
|
||||||
assert.Equal(t, "POST", r.Method)
|
assert.Equal(t, "POST", r.Method)
|
||||||
reqBytes, err := io.ReadAll(r.Body)
|
reqBytes, err := io.ReadAll(r.Body)
|
||||||
@@ -85,7 +90,7 @@ func TestUsers_Create_200(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestUsers_Create_Err(t *testing.T) {
|
func TestUsers_Create_Err(t *testing.T) {
|
||||||
withMockClient(func(c *Client, mux *http.ServeMux) {
|
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||||
mux.HandleFunc("/api/users", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("/api/users", func(w http.ResponseWriter, r *http.Request) {
|
||||||
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
|
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
|
||||||
w.WriteHeader(400)
|
w.WriteHeader(400)
|
||||||
@@ -102,7 +107,7 @@ func TestUsers_Create_Err(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestUsers_Update_200(t *testing.T) {
|
func TestUsers_Update_200(t *testing.T) {
|
||||||
withMockClient(func(c *Client, mux *http.ServeMux) {
|
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||||
mux.HandleFunc("/api/users/Test", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("/api/users/Test", func(w http.ResponseWriter, r *http.Request) {
|
||||||
assert.Equal(t, "PUT", r.Method)
|
assert.Equal(t, "PUT", r.Method)
|
||||||
reqBytes, err := io.ReadAll(r.Body)
|
reqBytes, err := io.ReadAll(r.Body)
|
||||||
@@ -125,7 +130,7 @@ func TestUsers_Update_200(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestUsers_Update_Err(t *testing.T) {
|
func TestUsers_Update_Err(t *testing.T) {
|
||||||
withMockClient(func(c *Client, mux *http.ServeMux) {
|
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||||
mux.HandleFunc("/api/users/Test", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("/api/users/Test", func(w http.ResponseWriter, r *http.Request) {
|
||||||
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
|
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
|
||||||
w.WriteHeader(400)
|
w.WriteHeader(400)
|
||||||
@@ -142,7 +147,7 @@ func TestUsers_Update_Err(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestUsers_Delete_200(t *testing.T) {
|
func TestUsers_Delete_200(t *testing.T) {
|
||||||
withMockClient(func(c *Client, mux *http.ServeMux) {
|
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||||
mux.HandleFunc("/api/users/Test", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("/api/users/Test", func(w http.ResponseWriter, r *http.Request) {
|
||||||
assert.Equal(t, "DELETE", r.Method)
|
assert.Equal(t, "DELETE", r.Method)
|
||||||
w.WriteHeader(200)
|
w.WriteHeader(200)
|
||||||
@@ -153,7 +158,7 @@ func TestUsers_Delete_200(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestUsers_Delete_Err(t *testing.T) {
|
func TestUsers_Delete_Err(t *testing.T) {
|
||||||
withMockClient(func(c *Client, mux *http.ServeMux) {
|
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||||
mux.HandleFunc("/api/users/Test", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("/api/users/Test", func(w http.ResponseWriter, r *http.Request) {
|
||||||
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404})
|
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404})
|
||||||
w.WriteHeader(404)
|
w.WriteHeader(404)
|
||||||
@@ -167,7 +172,7 @@ func TestUsers_Delete_Err(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestUsers_ResendInvitation_200(t *testing.T) {
|
func TestUsers_ResendInvitation_200(t *testing.T) {
|
||||||
withMockClient(func(c *Client, mux *http.ServeMux) {
|
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||||
mux.HandleFunc("/api/users/Test/invite", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("/api/users/Test/invite", func(w http.ResponseWriter, r *http.Request) {
|
||||||
assert.Equal(t, "POST", r.Method)
|
assert.Equal(t, "POST", r.Method)
|
||||||
w.WriteHeader(200)
|
w.WriteHeader(200)
|
||||||
@@ -178,7 +183,7 @@ func TestUsers_ResendInvitation_200(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestUsers_ResendInvitation_Err(t *testing.T) {
|
func TestUsers_ResendInvitation_Err(t *testing.T) {
|
||||||
withMockClient(func(c *Client, mux *http.ServeMux) {
|
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||||
mux.HandleFunc("/api/users/Test/invite", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("/api/users/Test/invite", func(w http.ResponseWriter, r *http.Request) {
|
||||||
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404})
|
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404})
|
||||||
w.WriteHeader(404)
|
w.WriteHeader(404)
|
||||||
@@ -192,7 +197,7 @@ func TestUsers_ResendInvitation_Err(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestUsers_Integration(t *testing.T) {
|
func TestUsers_Integration(t *testing.T) {
|
||||||
withBlackBoxServer(t, func(c *Client) {
|
withBlackBoxServer(t, func(c *rest.Client) {
|
||||||
user, err := c.Users.Create(context.Background(), api.UserCreateRequest{
|
user, err := c.Users.Create(context.Background(), api.UserCreateRequest{
|
||||||
AutoGroups: []string{},
|
AutoGroups: []string{},
|
||||||
Email: ptr("test@example.com"),
|
Email: ptr("test@example.com"),
|
||||||
|
|||||||
@@ -15,6 +15,7 @@ import (
|
|||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
"google.golang.org/grpc/codes"
|
"google.golang.org/grpc/codes"
|
||||||
|
"google.golang.org/grpc/peer"
|
||||||
"google.golang.org/grpc/status"
|
"google.golang.org/grpc/status"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/encryption"
|
"github.com/netbirdio/netbird/encryption"
|
||||||
@@ -114,6 +115,18 @@ func NewServer(
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *GRPCServer) GetServerKey(ctx context.Context, req *proto.Empty) (*proto.ServerKeyResponse, error) {
|
func (s *GRPCServer) GetServerKey(ctx context.Context, req *proto.Empty) (*proto.ServerKeyResponse, error) {
|
||||||
|
ip := ""
|
||||||
|
p, ok := peer.FromContext(ctx)
|
||||||
|
if ok {
|
||||||
|
ip = p.Addr.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
log.WithContext(ctx).Tracef("GetServerKey request from %s", ip)
|
||||||
|
start := time.Now()
|
||||||
|
defer func() {
|
||||||
|
log.WithContext(ctx).Tracef("GetServerKey from %s took %v", ip, time.Since(start))
|
||||||
|
}()
|
||||||
|
|
||||||
// todo introduce something more meaningful with the key expiration/rotation
|
// todo introduce something more meaningful with the key expiration/rotation
|
||||||
if s.appMetrics != nil {
|
if s.appMetrics != nil {
|
||||||
s.appMetrics.GRPCMetrics().CountGetKeyRequest()
|
s.appMetrics.GRPCMetrics().CountGetKeyRequest()
|
||||||
@@ -717,6 +730,12 @@ func (s *GRPCServer) sendInitialSync(ctx context.Context, peerKey wgtypes.Key, p
|
|||||||
// This is used for initiating an Oauth 2 device authorization grant flow
|
// This is used for initiating an Oauth 2 device authorization grant flow
|
||||||
// which will be used by our clients to Login
|
// which will be used by our clients to Login
|
||||||
func (s *GRPCServer) GetDeviceAuthorizationFlow(ctx context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error) {
|
func (s *GRPCServer) GetDeviceAuthorizationFlow(ctx context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error) {
|
||||||
|
log.WithContext(ctx).Tracef("GetDeviceAuthorizationFlow request for pubKey: %s", req.WgPubKey)
|
||||||
|
start := time.Now()
|
||||||
|
defer func() {
|
||||||
|
log.WithContext(ctx).Tracef("GetDeviceAuthorizationFlow for pubKey: %s took %v", req.WgPubKey, time.Since(start))
|
||||||
|
}()
|
||||||
|
|
||||||
peerKey, err := wgtypes.ParseKey(req.GetWgPubKey())
|
peerKey, err := wgtypes.ParseKey(req.GetWgPubKey())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
errMSG := fmt.Sprintf("error while parsing peer's Wireguard public key %s on GetDeviceAuthorizationFlow request.", req.WgPubKey)
|
errMSG := fmt.Sprintf("error while parsing peer's Wireguard public key %s on GetDeviceAuthorizationFlow request.", req.WgPubKey)
|
||||||
@@ -769,6 +788,12 @@ func (s *GRPCServer) GetDeviceAuthorizationFlow(ctx context.Context, req *proto.
|
|||||||
// This is used for initiating an Oauth 2 pkce authorization grant flow
|
// This is used for initiating an Oauth 2 pkce authorization grant flow
|
||||||
// which will be used by our clients to Login
|
// which will be used by our clients to Login
|
||||||
func (s *GRPCServer) GetPKCEAuthorizationFlow(ctx context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error) {
|
func (s *GRPCServer) GetPKCEAuthorizationFlow(ctx context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error) {
|
||||||
|
log.WithContext(ctx).Tracef("GetPKCEAuthorizationFlow request for pubKey: %s", req.WgPubKey)
|
||||||
|
start := time.Now()
|
||||||
|
defer func() {
|
||||||
|
log.WithContext(ctx).Tracef("GetPKCEAuthorizationFlow for pubKey %s took %v", req.WgPubKey, time.Since(start))
|
||||||
|
}()
|
||||||
|
|
||||||
peerKey, err := wgtypes.ParseKey(req.GetWgPubKey())
|
peerKey, err := wgtypes.ParseKey(req.GetWgPubKey())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
errMSG := fmt.Sprintf("error while parsing peer's Wireguard public key %s on GetPKCEAuthorizationFlow request.", req.WgPubKey)
|
errMSG := fmt.Sprintf("error while parsing peer's Wireguard public key %s on GetPKCEAuthorizationFlow request.", req.WgPubKey)
|
||||||
|
|||||||
@@ -141,7 +141,6 @@ type Client struct {
|
|||||||
muInstanceURL sync.Mutex
|
muInstanceURL sync.Mutex
|
||||||
|
|
||||||
onDisconnectListener func(string)
|
onDisconnectListener func(string)
|
||||||
onConnectedListener func()
|
|
||||||
listenerMutex sync.Mutex
|
listenerMutex sync.Mutex
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -190,7 +189,6 @@ func (c *Client) Connect() error {
|
|||||||
|
|
||||||
c.wgReadLoop.Add(1)
|
c.wgReadLoop.Add(1)
|
||||||
go c.readLoop(c.relayConn)
|
go c.readLoop(c.relayConn)
|
||||||
go c.notifyConnected()
|
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -238,12 +236,6 @@ func (c *Client) SetOnDisconnectListener(fn func(string)) {
|
|||||||
c.onDisconnectListener = fn
|
c.onDisconnectListener = fn
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) SetOnConnectedListener(fn func()) {
|
|
||||||
c.listenerMutex.Lock()
|
|
||||||
defer c.listenerMutex.Unlock()
|
|
||||||
c.onConnectedListener = fn
|
|
||||||
}
|
|
||||||
|
|
||||||
// HasConns returns true if there are connections.
|
// HasConns returns true if there are connections.
|
||||||
func (c *Client) HasConns() bool {
|
func (c *Client) HasConns() bool {
|
||||||
c.mu.Lock()
|
c.mu.Lock()
|
||||||
@@ -559,16 +551,6 @@ func (c *Client) notifyDisconnected() {
|
|||||||
go c.onDisconnectListener(c.connectionURL)
|
go c.onDisconnectListener(c.connectionURL)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) notifyConnected() {
|
|
||||||
c.listenerMutex.Lock()
|
|
||||||
defer c.listenerMutex.Unlock()
|
|
||||||
|
|
||||||
if c.onConnectedListener == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
go c.onConnectedListener()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Client) writeCloseMsg() {
|
func (c *Client) writeCloseMsg() {
|
||||||
msg := messages.MarshalCloseMsg()
|
msg := messages.MarshalCloseMsg()
|
||||||
_, err := c.relayConn.Write(msg)
|
_, err := c.relayConn.Write(msg)
|
||||||
|
|||||||
@@ -14,8 +14,9 @@ var (
|
|||||||
|
|
||||||
// Guard manage the reconnection tries to the Relay server in case of disconnection event.
|
// Guard manage the reconnection tries to the Relay server in case of disconnection event.
|
||||||
type Guard struct {
|
type Guard struct {
|
||||||
// OnNewRelayClient is a channel that is used to notify the relay client about a new relay client instance.
|
// OnNewRelayClient is a channel that is used to notify the relay manager about a new relay client instance.
|
||||||
OnNewRelayClient chan *Client
|
OnNewRelayClient chan *Client
|
||||||
|
OnReconnected chan struct{}
|
||||||
serverPicker *ServerPicker
|
serverPicker *ServerPicker
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -23,6 +24,7 @@ type Guard struct {
|
|||||||
func NewGuard(sp *ServerPicker) *Guard {
|
func NewGuard(sp *ServerPicker) *Guard {
|
||||||
g := &Guard{
|
g := &Guard{
|
||||||
OnNewRelayClient: make(chan *Client, 1),
|
OnNewRelayClient: make(chan *Client, 1),
|
||||||
|
OnReconnected: make(chan struct{}, 1),
|
||||||
serverPicker: sp,
|
serverPicker: sp,
|
||||||
}
|
}
|
||||||
return g
|
return g
|
||||||
@@ -39,14 +41,13 @@ func NewGuard(sp *ServerPicker) *Guard {
|
|||||||
// - relayClient: The relay client instance that was disconnected.
|
// - relayClient: The relay client instance that was disconnected.
|
||||||
// todo prevent multiple reconnection instances. In the current usage it should not happen, but it is better to prevent
|
// todo prevent multiple reconnection instances. In the current usage it should not happen, but it is better to prevent
|
||||||
func (g *Guard) StartReconnectTrys(ctx context.Context, relayClient *Client) {
|
func (g *Guard) StartReconnectTrys(ctx context.Context, relayClient *Client) {
|
||||||
if relayClient == nil {
|
// try to reconnect to the same server
|
||||||
goto RETRY
|
if ok := g.tryToQuickReconnect(ctx, relayClient); ok {
|
||||||
}
|
g.notifyReconnected()
|
||||||
if g.isServerURLStillValid(relayClient) && g.quickReconnect(ctx, relayClient) {
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
RETRY:
|
// start a ticker to pick a new server
|
||||||
ticker := exponentTicker(ctx)
|
ticker := exponentTicker(ctx)
|
||||||
defer ticker.Stop()
|
defer ticker.Stop()
|
||||||
|
|
||||||
@@ -64,6 +65,28 @@ RETRY:
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (g *Guard) tryToQuickReconnect(parentCtx context.Context, rc *Client) bool {
|
||||||
|
if rc == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if !g.isServerURLStillValid(rc) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if cancelled := waiteBeforeRetry(parentCtx); !cancelled {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Infof("try to reconnect to Relay server: %s", rc.connectionURL)
|
||||||
|
|
||||||
|
if err := rc.Connect(); err != nil {
|
||||||
|
log.Errorf("failed to reconnect to relay server: %s", err)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
func (g *Guard) retry(ctx context.Context) error {
|
func (g *Guard) retry(ctx context.Context) error {
|
||||||
log.Infof("try to pick up a new Relay server")
|
log.Infof("try to pick up a new Relay server")
|
||||||
relayClient, err := g.serverPicker.PickServer(ctx)
|
relayClient, err := g.serverPicker.PickServer(ctx)
|
||||||
@@ -78,23 +101,6 @@ func (g *Guard) retry(ctx context.Context) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (g *Guard) quickReconnect(parentCtx context.Context, rc *Client) bool {
|
|
||||||
ctx, cancel := context.WithTimeout(parentCtx, 1500*time.Millisecond)
|
|
||||||
defer cancel()
|
|
||||||
<-ctx.Done()
|
|
||||||
|
|
||||||
if parentCtx.Err() != nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
log.Infof("try to reconnect to Relay server: %s", rc.connectionURL)
|
|
||||||
|
|
||||||
if err := rc.Connect(); err != nil {
|
|
||||||
log.Errorf("failed to reconnect to relay server: %s", err)
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
func (g *Guard) drainRelayClientChan() {
|
func (g *Guard) drainRelayClientChan() {
|
||||||
select {
|
select {
|
||||||
case <-g.OnNewRelayClient:
|
case <-g.OnNewRelayClient:
|
||||||
@@ -111,6 +117,13 @@ func (g *Guard) isServerURLStillValid(rc *Client) bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (g *Guard) notifyReconnected() {
|
||||||
|
select {
|
||||||
|
case g.OnReconnected <- struct{}{}:
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func exponentTicker(ctx context.Context) *backoff.Ticker {
|
func exponentTicker(ctx context.Context) *backoff.Ticker {
|
||||||
bo := backoff.WithContext(&backoff.ExponentialBackOff{
|
bo := backoff.WithContext(&backoff.ExponentialBackOff{
|
||||||
InitialInterval: 2 * time.Second,
|
InitialInterval: 2 * time.Second,
|
||||||
@@ -121,3 +134,15 @@ func exponentTicker(ctx context.Context) *backoff.Ticker {
|
|||||||
|
|
||||||
return backoff.NewTicker(bo)
|
return backoff.NewTicker(bo)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func waiteBeforeRetry(ctx context.Context) bool {
|
||||||
|
timer := time.NewTimer(1500 * time.Millisecond)
|
||||||
|
defer timer.Stop()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-timer.C:
|
||||||
|
return true
|
||||||
|
case <-ctx.Done():
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user