mirror of
https://github.com/netbirdio/netbird.git
synced 2026-05-24 17:49:55 +00:00
Compare commits
3 Commits
pcp-androi
...
refactor-c
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
0bcc5b617d | ||
|
|
449a80d5ff | ||
|
|
fc5540926f |
@@ -43,16 +43,16 @@ func init() {
|
||||
ipsFilterMap = make(map[string]struct{})
|
||||
prefixNamesFilterMap = make(map[string]struct{})
|
||||
statusCmd.PersistentFlags().BoolVarP(&detailFlag, "detail", "d", false, "display detailed status information in human-readable format")
|
||||
statusCmd.PersistentFlags().BoolVarP(&jsonFlag, "json", "j", false, "display detailed status information in json format")
|
||||
statusCmd.PersistentFlags().BoolVarP(&yamlFlag, "yaml", "y", false, "display detailed status information in yaml format")
|
||||
statusCmd.PersistentFlags().BoolVarP(&ipv4Flag, "ipv4", "4", false, "display only NetBird IPv4 of this peer, e.g., --ipv4 will output 100.64.0.33")
|
||||
statusCmd.PersistentFlags().BoolVarP(&ipv6Flag, "ipv6", "6", false, "display only NetBird IPv6 of this peer")
|
||||
statusCmd.PersistentFlags().BoolVar(&jsonFlag, "json", false, "display detailed status information in json format")
|
||||
statusCmd.PersistentFlags().BoolVar(&yamlFlag, "yaml", false, "display detailed status information in yaml format")
|
||||
statusCmd.PersistentFlags().BoolVar(&ipv4Flag, "ipv4", false, "display only NetBird IPv4 of this peer, e.g., --ipv4 will output 100.64.0.33")
|
||||
statusCmd.PersistentFlags().BoolVar(&ipv6Flag, "ipv6", false, "display only NetBird IPv6 of this peer")
|
||||
statusCmd.MarkFlagsMutuallyExclusive("detail", "json", "yaml", "ipv4", "ipv6")
|
||||
statusCmd.PersistentFlags().StringSliceVarP(&ipsFilter, "filter-by-ips", "I", []string{}, "filters the detailed output by a list of one or more IPs (v4 or v6), e.g., --filter-by-ips 100.64.0.100,fd00::1")
|
||||
statusCmd.PersistentFlags().StringSliceVarP(&prefixNamesFilter, "filter-by-names", "N", []string{}, "filters the detailed output by a list of one or more peer FQDN or hostnames, e.g., --filter-by-names peer-a,peer-b.netbird.cloud")
|
||||
statusCmd.PersistentFlags().StringVarP(&statusFilter, "filter-by-status", "S", "", "filters the detailed output by connection status(idle|connecting|connected), e.g., --filter-by-status connected")
|
||||
statusCmd.PersistentFlags().StringVarP(&connectionTypeFilter, "filter-by-connection-type", "T", "", "filters the detailed output by connection type (P2P|Relayed), e.g., --filter-by-connection-type P2P")
|
||||
statusCmd.PersistentFlags().StringVarP(&checkFlag, "check", "C", "", "run a health check and exit with code 0 on success, 1 on failure (live|ready|startup)")
|
||||
statusCmd.PersistentFlags().StringSliceVar(&ipsFilter, "filter-by-ips", []string{}, "filters the detailed output by a list of one or more IPs (v4 or v6), e.g., --filter-by-ips 100.64.0.100,fd00::1")
|
||||
statusCmd.PersistentFlags().StringSliceVar(&prefixNamesFilter, "filter-by-names", []string{}, "filters the detailed output by a list of one or more peer FQDN or hostnames, e.g., --filter-by-names peer-a,peer-b.netbird.cloud")
|
||||
statusCmd.PersistentFlags().StringVar(&statusFilter, "filter-by-status", "", "filters the detailed output by connection status(idle|connecting|connected), e.g., --filter-by-status connected")
|
||||
statusCmd.PersistentFlags().StringVar(&connectionTypeFilter, "filter-by-connection-type", "", "filters the detailed output by connection type (P2P|Relayed), e.g., --filter-by-connection-type P2P")
|
||||
statusCmd.PersistentFlags().StringVar(&checkFlag, "check", "", "run a health check and exit with code 0 on success, 1 on failure (live|ready|startup)")
|
||||
}
|
||||
|
||||
func statusFunc(cmd *cobra.Command, args []string) error {
|
||||
|
||||
@@ -336,7 +336,7 @@ func (c *Client) ListenTCP(address string) (net.Listener, error) {
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("split host port: %w", err)
|
||||
}
|
||||
listenAddr := net.JoinHostPort(addr.String(), port)
|
||||
listenAddr := fmt.Sprintf("%s:%s", addr, port)
|
||||
|
||||
tcpAddr, err := net.ResolveTCPAddr("tcp", listenAddr)
|
||||
if err != nil {
|
||||
@@ -357,7 +357,7 @@ func (c *Client) ListenUDP(address string) (net.PacketConn, error) {
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("split host port: %w", err)
|
||||
}
|
||||
listenAddr := net.JoinHostPort(addr.String(), port)
|
||||
listenAddr := fmt.Sprintf("%s:%s", addr, port)
|
||||
|
||||
udpAddr, err := net.ResolveUDPAddr("udp", listenAddr)
|
||||
if err != nil {
|
||||
|
||||
@@ -45,11 +45,8 @@ netbird.out: Most recent, anonymized stdout log file of the NetBird client.
|
||||
routes.txt: Detailed system routing table in tabular format including destination, gateway, interface, metrics, and protocol information, if --system-info flag was provided.
|
||||
interfaces.txt: Anonymized network interface information, if --system-info flag was provided.
|
||||
ip_rules.txt: Detailed IP routing rules in tabular format including priority, source, destination, interfaces, table, and action information (Linux only), if --system-info flag was provided.
|
||||
iptables.txt: Anonymized iptables (IPv4) rules with packet counters, if --system-info flag was provided.
|
||||
ip6tables.txt: Anonymized ip6tables (IPv6) rules with packet counters, if --system-info flag was provided.
|
||||
ipset.txt: Anonymized ipset list output, if --system-info flag was provided.
|
||||
nftables.txt: Anonymized nftables rules with packet counters across all families (ip, ip6, inet, etc.), if --system-info flag was provided.
|
||||
sysctls.txt: Forwarding, reverse-path filter, source-validation, and conntrack accounting sysctl values that the NetBird client may read or modify, if --system-info flag was provided (Linux only).
|
||||
iptables.txt: Anonymized iptables rules with packet counters, if --system-info flag was provided.
|
||||
nftables.txt: Anonymized nftables rules with packet counters, if --system-info flag was provided.
|
||||
resolv.conf: DNS resolver configuration from /etc/resolv.conf (Unix systems only), if --system-info flag was provided.
|
||||
scutil_dns.txt: DNS configuration from scutil --dns (macOS only), if --system-info flag was provided.
|
||||
resolved_domains.txt: Anonymized resolved domain IP addresses from the status recorder.
|
||||
@@ -168,33 +165,22 @@ The config.txt file contains anonymized configuration information of the NetBird
|
||||
Other non-sensitive configuration options are included without anonymization.
|
||||
|
||||
Firewall Rules (Linux only)
|
||||
The bundle includes the following firewall-related files:
|
||||
The bundle includes two separate firewall rule files:
|
||||
|
||||
iptables.txt:
|
||||
- IPv4 iptables ruleset with packet counters using 'iptables-save' and 'iptables -v -n -L'
|
||||
- Complete iptables ruleset with packet counters using 'iptables -v -n -L'
|
||||
- Includes all tables (filter, nat, mangle, raw, security)
|
||||
- Shows packet and byte counters for each rule
|
||||
- All IP addresses are anonymized
|
||||
- Chain names, table names, and other non-sensitive information remain unchanged
|
||||
|
||||
ip6tables.txt:
|
||||
- IPv6 ip6tables ruleset with packet counters using 'ip6tables-save' and 'ip6tables -v -n -L'
|
||||
- Same table coverage and anonymization as iptables.txt
|
||||
- Omitted when ip6tables is not installed or no IPv6 rules are present
|
||||
|
||||
ipset.txt:
|
||||
- Output of 'ipset list' (family-agnostic)
|
||||
- IP addresses are anonymized; set names and types remain unchanged
|
||||
|
||||
nftables.txt:
|
||||
- Complete nftables ruleset across all families (ip, ip6, inet, arp, bridge, netdev) via 'nft -a list ruleset'
|
||||
- Complete nftables ruleset obtained via 'nft -a list ruleset'
|
||||
- Includes rule handle numbers and packet counters
|
||||
- All IP addresses are anonymized; chain/table names remain unchanged
|
||||
|
||||
sysctls.txt:
|
||||
- Forwarding (IPv4 + IPv6, global and per-interface), reverse-path filter, source-validation, conntrack accounting, and TCP-related sysctls that netbird may read or modify
|
||||
- Per-interface keys are enumerated from /proc/sys/net/ipv{4,6}/conf
|
||||
- Interface names anonymized when --anonymize is set
|
||||
- All tables, chains, and rules are included
|
||||
- Shows packet and byte counters for each rule
|
||||
- All IP addresses are anonymized
|
||||
- Chain names, table names, and other non-sensitive information remain unchanged
|
||||
|
||||
IP Rules (Linux only)
|
||||
The ip_rules.txt file contains detailed IP routing rule information:
|
||||
@@ -426,10 +412,6 @@ func (g *BundleGenerator) addSystemInfo() {
|
||||
log.Errorf("failed to add firewall rules to debug bundle: %v", err)
|
||||
}
|
||||
|
||||
if err := g.addSysctls(); err != nil {
|
||||
log.Errorf("failed to add sysctls to debug bundle: %v", err)
|
||||
}
|
||||
|
||||
if err := g.addDNSInfo(); err != nil {
|
||||
log.Errorf("failed to add DNS info to debug bundle: %v", err)
|
||||
}
|
||||
|
||||
@@ -124,18 +124,15 @@ func getSystemdLogs(serviceName string) (string, error) {
|
||||
// addFirewallRules collects and adds firewall rules to the archive
|
||||
func (g *BundleGenerator) addFirewallRules() error {
|
||||
log.Info("Collecting firewall rules")
|
||||
g.addIPTablesRulesToBundle("iptables-save", "iptables", "iptables.txt")
|
||||
g.addIPTablesRulesToBundle("ip6tables-save", "ip6tables", "ip6tables.txt")
|
||||
|
||||
ipsetOutput, err := collectIPSets()
|
||||
iptablesRules, err := collectIPTablesRules()
|
||||
if err != nil {
|
||||
log.Warnf("Failed to collect ipset information: %v", err)
|
||||
log.Warnf("Failed to collect iptables rules: %v", err)
|
||||
} else {
|
||||
if g.anonymize {
|
||||
ipsetOutput = g.anonymizer.AnonymizeString(ipsetOutput)
|
||||
iptablesRules = g.anonymizer.AnonymizeString(iptablesRules)
|
||||
}
|
||||
if err := g.addFileToZip(strings.NewReader(ipsetOutput), "ipset.txt"); err != nil {
|
||||
log.Warnf("Failed to add ipset output to bundle: %v", err)
|
||||
if err := g.addFileToZip(strings.NewReader(iptablesRules), "iptables.txt"); err != nil {
|
||||
log.Warnf("Failed to add iptables rules to bundle: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -154,65 +151,44 @@ func (g *BundleGenerator) addFirewallRules() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// addIPTablesRulesToBundle collects iptables/ip6tables rules and writes them to the bundle.
|
||||
func (g *BundleGenerator) addIPTablesRulesToBundle(saveBin, listBin, filename string) {
|
||||
rules, err := collectIPTablesRules(saveBin, listBin)
|
||||
if err != nil {
|
||||
log.Warnf("Failed to collect %s rules: %v", listBin, err)
|
||||
return
|
||||
}
|
||||
if g.anonymize {
|
||||
rules = g.anonymizer.AnonymizeString(rules)
|
||||
}
|
||||
if err := g.addFileToZip(strings.NewReader(rules), filename); err != nil {
|
||||
log.Warnf("Failed to add %s rules to bundle: %v", listBin, err)
|
||||
}
|
||||
}
|
||||
|
||||
// collectIPTablesRules collects rules using both <saveBin> and verbose listing via <listBin>.
|
||||
// Returns an error when neither command produced any output (e.g. the binary is missing),
|
||||
// so the caller can skip writing an empty file.
|
||||
func collectIPTablesRules(saveBin, listBin string) (string, error) {
|
||||
// collectIPTablesRules collects rules using both iptables-save and verbose listing
|
||||
func collectIPTablesRules() (string, error) {
|
||||
var builder strings.Builder
|
||||
var collected bool
|
||||
var firstErr error
|
||||
|
||||
saveOutput, err := runCommand(saveBin)
|
||||
switch {
|
||||
case err != nil:
|
||||
firstErr = err
|
||||
log.Warnf("Failed to collect %s output: %v", saveBin, err)
|
||||
case strings.TrimSpace(saveOutput) == "":
|
||||
log.Debugf("%s produced no output, skipping", saveBin)
|
||||
default:
|
||||
builder.WriteString(fmt.Sprintf("=== %s output ===\n", saveBin))
|
||||
saveOutput, err := collectIPTablesSave()
|
||||
if err != nil {
|
||||
log.Warnf("Failed to collect iptables rules using iptables-save: %v", err)
|
||||
} else {
|
||||
builder.WriteString("=== iptables-save output ===\n")
|
||||
builder.WriteString(saveOutput)
|
||||
builder.WriteString("\n")
|
||||
collected = true
|
||||
}
|
||||
|
||||
listHeader := fmt.Sprintf("=== %s -v -n -L output ===\n", listBin)
|
||||
builder.WriteString(listHeader)
|
||||
ipsetOutput, err := collectIPSets()
|
||||
if err != nil {
|
||||
log.Warnf("Failed to collect ipset information: %v", err)
|
||||
} else {
|
||||
builder.WriteString("=== ipset list output ===\n")
|
||||
builder.WriteString(ipsetOutput)
|
||||
builder.WriteString("\n")
|
||||
}
|
||||
|
||||
builder.WriteString("=== iptables -v -n -L output ===\n")
|
||||
|
||||
tables := []string{"filter", "nat", "mangle", "raw", "security"}
|
||||
|
||||
for _, table := range tables {
|
||||
stats, err := runCommand(listBin, "-v", "-n", "-L", "-t", table)
|
||||
builder.WriteString(fmt.Sprintf("*%s\n", table))
|
||||
|
||||
stats, err := getTableStatistics(table)
|
||||
if err != nil {
|
||||
if firstErr == nil {
|
||||
firstErr = err
|
||||
}
|
||||
log.Warnf("Failed to get %s statistics for table %s: %v", listBin, table, err)
|
||||
log.Warnf("Failed to get statistics for table %s: %v", table, err)
|
||||
continue
|
||||
}
|
||||
builder.WriteString(fmt.Sprintf("*%s\n", table))
|
||||
builder.WriteString(stats)
|
||||
builder.WriteString("\n")
|
||||
collected = true
|
||||
}
|
||||
|
||||
if !collected {
|
||||
return "", fmt.Errorf("collect %s rules: %w", listBin, firstErr)
|
||||
}
|
||||
return builder.String(), nil
|
||||
}
|
||||
|
||||
@@ -238,15 +214,34 @@ func collectIPSets() (string, error) {
|
||||
return ipsets, nil
|
||||
}
|
||||
|
||||
// runCommand executes a command and returns its stdout, wrapping stderr in the error on failure.
|
||||
func runCommand(name string, args ...string) (string, error) {
|
||||
cmd := exec.Command(name, args...)
|
||||
// collectIPTablesSave uses iptables-save to get rule definitions
|
||||
func collectIPTablesSave() (string, error) {
|
||||
cmd := exec.Command("iptables-save")
|
||||
var stdout, stderr bytes.Buffer
|
||||
cmd.Stdout = &stdout
|
||||
cmd.Stderr = &stderr
|
||||
|
||||
if err := cmd.Run(); err != nil {
|
||||
return "", fmt.Errorf("execute %s: %w (stderr: %s)", name, err, stderr.String())
|
||||
return "", fmt.Errorf("execute iptables-save: %w (stderr: %s)", err, stderr.String())
|
||||
}
|
||||
|
||||
rules := stdout.String()
|
||||
if strings.TrimSpace(rules) == "" {
|
||||
return "", fmt.Errorf("no iptables rules found")
|
||||
}
|
||||
|
||||
return rules, nil
|
||||
}
|
||||
|
||||
// getTableStatistics gets verbose statistics for an entire table using iptables command
|
||||
func getTableStatistics(table string) (string, error) {
|
||||
cmd := exec.Command("iptables", "-v", "-n", "-L", "-t", table)
|
||||
var stdout, stderr bytes.Buffer
|
||||
cmd.Stdout = &stdout
|
||||
cmd.Stderr = &stderr
|
||||
|
||||
if err := cmd.Run(); err != nil {
|
||||
return "", fmt.Errorf("execute iptables -v -n -L: %w (stderr: %s)", err, stderr.String())
|
||||
}
|
||||
|
||||
return stdout.String(), nil
|
||||
@@ -809,91 +804,3 @@ func formatSetKeyType(keyType nftables.SetDatatype) string {
|
||||
return fmt.Sprintf("type-%v", keyType)
|
||||
}
|
||||
}
|
||||
|
||||
// addSysctls collects forwarding and netbird-managed sysctl values and writes them to the bundle.
|
||||
func (g *BundleGenerator) addSysctls() error {
|
||||
log.Info("Collecting sysctls")
|
||||
content := collectSysctls()
|
||||
if g.anonymize {
|
||||
content = g.anonymizer.AnonymizeString(content)
|
||||
}
|
||||
if err := g.addFileToZip(strings.NewReader(content), "sysctls.txt"); err != nil {
|
||||
return fmt.Errorf("add sysctls to bundle: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// collectSysctls reads every sysctl that the netbird client may modify, plus
|
||||
// global IPv4/IPv6 forwarding, and returns a formatted dump grouped by topic.
|
||||
// Per-interface values are enumerated by listing /proc/sys/net/ipv{4,6}/conf.
|
||||
func collectSysctls() string {
|
||||
var builder strings.Builder
|
||||
|
||||
writeSysctlGroup(&builder, "forwarding", []string{
|
||||
"net.ipv4.ip_forward",
|
||||
"net.ipv6.conf.all.forwarding",
|
||||
"net.ipv6.conf.default.forwarding",
|
||||
})
|
||||
writeSysctlGroup(&builder, "ipv4 per-interface forwarding", listInterfaceSysctls("ipv4", "forwarding"))
|
||||
writeSysctlGroup(&builder, "ipv6 per-interface forwarding", listInterfaceSysctls("ipv6", "forwarding"))
|
||||
writeSysctlGroup(&builder, "rp_filter", append(
|
||||
[]string{"net.ipv4.conf.all.rp_filter", "net.ipv4.conf.default.rp_filter"},
|
||||
listInterfaceSysctls("ipv4", "rp_filter")...,
|
||||
))
|
||||
writeSysctlGroup(&builder, "src_valid_mark", append(
|
||||
[]string{"net.ipv4.conf.all.src_valid_mark", "net.ipv4.conf.default.src_valid_mark"},
|
||||
listInterfaceSysctls("ipv4", "src_valid_mark")...,
|
||||
))
|
||||
writeSysctlGroup(&builder, "conntrack", []string{
|
||||
"net.netfilter.nf_conntrack_acct",
|
||||
"net.netfilter.nf_conntrack_tcp_loose",
|
||||
})
|
||||
writeSysctlGroup(&builder, "tcp", []string{
|
||||
"net.ipv4.tcp_tw_reuse",
|
||||
})
|
||||
|
||||
return builder.String()
|
||||
}
|
||||
|
||||
func writeSysctlGroup(builder *strings.Builder, title string, keys []string) {
|
||||
builder.WriteString(fmt.Sprintf("=== %s ===\n", title))
|
||||
for _, key := range keys {
|
||||
value, err := readSysctl(key)
|
||||
if err != nil {
|
||||
builder.WriteString(fmt.Sprintf("%s = <error: %v>\n", key, err))
|
||||
continue
|
||||
}
|
||||
builder.WriteString(fmt.Sprintf("%s = %s\n", key, value))
|
||||
}
|
||||
builder.WriteString("\n")
|
||||
}
|
||||
|
||||
// listInterfaceSysctls returns net.ipvX.conf.<iface>.<leaf> keys for every
|
||||
// interface present in /proc/sys/net/ipvX/conf, skipping "all" and "default"
|
||||
// (callers add those explicitly so they appear first).
|
||||
func listInterfaceSysctls(family, leaf string) []string {
|
||||
dir := fmt.Sprintf("/proc/sys/net/%s/conf", family)
|
||||
entries, err := os.ReadDir(dir)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
var keys []string
|
||||
for _, e := range entries {
|
||||
name := e.Name()
|
||||
if name == "all" || name == "default" {
|
||||
continue
|
||||
}
|
||||
keys = append(keys, fmt.Sprintf("net.%s.conf.%s.%s", family, name, leaf))
|
||||
}
|
||||
sort.Strings(keys)
|
||||
return keys
|
||||
}
|
||||
|
||||
func readSysctl(key string) (string, error) {
|
||||
path := fmt.Sprintf("/proc/sys/%s", strings.ReplaceAll(key, ".", "/"))
|
||||
value, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return strings.TrimSpace(string(value)), nil
|
||||
}
|
||||
|
||||
@@ -17,8 +17,3 @@ func (g *BundleGenerator) addIPRules() error {
|
||||
// IP rules are only supported on Linux
|
||||
return nil
|
||||
}
|
||||
|
||||
func (g *BundleGenerator) addSysctls() error {
|
||||
// Sysctl collection is only supported on Linux
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -179,10 +179,8 @@ func getDefaultGateway() (gateway net.IP, localIP net.IP, err error) {
|
||||
}
|
||||
|
||||
dst := net.IPv4zero
|
||||
if runtime.GOOS == "linux" || runtime.GOOS == "android" {
|
||||
// go-netroute v0.4.0 rejects unspecified destinations client-side on Linux/Android.
|
||||
// TODO: on android/ios, use platform APIs (ConnectivityManager.getLinkProperties /
|
||||
// NWPathMonitor) when netlink-based lookup is restricted or unavailable.
|
||||
if runtime.GOOS == "linux" {
|
||||
// go-netroute v0.4.0 rejects unspecified destinations client-side on Linux.
|
||||
dst = net.IPv4(0, 0, 0, 1)
|
||||
}
|
||||
_, gateway, localIP, err = router.Route(dst)
|
||||
@@ -205,7 +203,7 @@ func getDefaultGateway6() (gateway net.IP, localIP net.IP, err error) {
|
||||
}
|
||||
|
||||
dst := net.IPv6zero
|
||||
if runtime.GOOS == "linux" || runtime.GOOS == "android" {
|
||||
if runtime.GOOS == "linux" {
|
||||
// ::2
|
||||
dst = net.IP{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2}
|
||||
}
|
||||
|
||||
@@ -67,6 +67,10 @@ func init() {
|
||||
rootCmd.AddCommand(newTokenCommands())
|
||||
}
|
||||
|
||||
func RootCmd() *cobra.Command {
|
||||
return rootCmd
|
||||
}
|
||||
|
||||
func Execute() error {
|
||||
return rootCmd.Execute()
|
||||
}
|
||||
@@ -168,7 +172,7 @@ func initializeConfig() error {
|
||||
// serverInstances holds all server instances created during startup.
|
||||
type serverInstances struct {
|
||||
relaySrv *relayServer.Server
|
||||
mgmtSrv *mgmtServer.BaseServer
|
||||
mgmtSrv mgmtServer.Server
|
||||
signalSrv *signalServer.Server
|
||||
healthcheck *healthcheck.Server
|
||||
stunServer *stun.Server
|
||||
@@ -324,19 +328,24 @@ func setupServerHooks(servers *serverInstances, cfg *CombinedConfig) {
|
||||
return
|
||||
}
|
||||
|
||||
servers.mgmtSrv.AfterInit(func(s *mgmtServer.BaseServer) {
|
||||
grpcSrv := s.GRPCServer()
|
||||
if s, ok := servers.mgmtSrv.GetContainer(mgmtServer.ContainerKeyBaseServer); ok {
|
||||
if baseServer, ok := s.(*mgmtServer.BaseServer); ok {
|
||||
baseServer.AfterInit(func(s *mgmtServer.BaseServer) {
|
||||
grpcSrv := s.GRPCServer()
|
||||
|
||||
if servers.signalSrv != nil {
|
||||
proto.RegisterSignalExchangeServer(grpcSrv, servers.signalSrv)
|
||||
log.Infof("Signal server registered on port %s", cfg.Server.ListenAddress)
|
||||
}
|
||||
if servers.signalSrv != nil {
|
||||
proto.RegisterSignalExchangeServer(grpcSrv, servers.signalSrv)
|
||||
log.Infof("Signal server registered on port %s", cfg.Server.ListenAddress)
|
||||
}
|
||||
|
||||
s.SetHandlerFunc(createCombinedHandler(grpcSrv, s.APIHandler(), servers.relaySrv, servers.metricsServer.Meter, cfg))
|
||||
if servers.relaySrv != nil {
|
||||
log.Infof("Relay WebSocket handler added (path: /relay)")
|
||||
s.SetHandlerFunc(createCombinedHandler(grpcSrv, s.APIHandler(), servers.relaySrv, servers.metricsServer.Meter, cfg))
|
||||
if servers.relaySrv != nil {
|
||||
log.Infof("Relay WebSocket handler added (path: /relay)")
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func startServers(wg *sync.WaitGroup, srv *relayServer.Server, httpHealthcheck *healthcheck.Server, stunServer *stun.Server, metricsServer *sharedMetrics.Metrics) {
|
||||
@@ -346,38 +355,32 @@ func startServers(wg *sync.WaitGroup, srv *relayServer.Server, httpHealthcheck *
|
||||
log.Infof("Relay WebSocket multiplexed on management port (no separate relay listener)")
|
||||
}
|
||||
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
wg.Go(func() {
|
||||
log.Infof("running metrics server: %s%s", metricsServer.Addr, metricsServer.Endpoint)
|
||||
if err := metricsServer.ListenAndServe(); !errors.Is(err, http.ErrServerClosed) {
|
||||
log.Fatalf("failed to start metrics server: %v", err)
|
||||
}
|
||||
}()
|
||||
})
|
||||
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
wg.Go(func() {
|
||||
if err := httpHealthcheck.ListenAndServe(); !errors.Is(err, http.ErrServerClosed) {
|
||||
log.Fatalf("failed to start healthcheck server: %v", err)
|
||||
}
|
||||
}()
|
||||
})
|
||||
|
||||
if stunServer != nil {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
wg.Go(func() {
|
||||
if err := stunServer.Listen(); err != nil {
|
||||
if errors.Is(err, stun.ErrServerClosed) {
|
||||
return
|
||||
}
|
||||
log.Errorf("STUN server error: %v", err)
|
||||
}
|
||||
}()
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func shutdownServers(ctx context.Context, srv *relayServer.Server, httpHealthcheck *healthcheck.Server, stunServer *stun.Server, mgmtSrv *mgmtServer.BaseServer, metricsServer *sharedMetrics.Metrics) error {
|
||||
func shutdownServers(ctx context.Context, srv *relayServer.Server, httpHealthcheck *healthcheck.Server, stunServer *stun.Server, mgmtSrv mgmtServer.Server, metricsServer *sharedMetrics.Metrics) error {
|
||||
var errs error
|
||||
|
||||
if err := httpHealthcheck.Shutdown(ctx); err != nil {
|
||||
@@ -491,7 +494,7 @@ func handleTLSConfig(cfg *CombinedConfig) (*tls.Config, bool, error) {
|
||||
return nil, false, nil
|
||||
}
|
||||
|
||||
func createManagementServer(cfg *CombinedConfig, mgmtConfig *nbconfig.Config) (*mgmtServer.BaseServer, error) {
|
||||
func createManagementServer(cfg *CombinedConfig, mgmtConfig *nbconfig.Config) (mgmtServer.Server, error) {
|
||||
mgmt := cfg.Management
|
||||
|
||||
// Extract port from listen address
|
||||
@@ -502,7 +505,7 @@ func createManagementServer(cfg *CombinedConfig, mgmtConfig *nbconfig.Config) (*
|
||||
}
|
||||
mgmtPort, _ := strconv.Atoi(portStr)
|
||||
|
||||
mgmtSrv := mgmtServer.NewServer(
|
||||
mgmtSrv := newServer(
|
||||
&mgmtServer.Config{
|
||||
NbConfig: mgmtConfig,
|
||||
DNSDomain: "",
|
||||
|
||||
13
combined/cmd/server.go
Normal file
13
combined/cmd/server.go
Normal file
@@ -0,0 +1,13 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
mgmtServer "github.com/netbirdio/netbird/management/internals/server"
|
||||
)
|
||||
|
||||
var newServer = func(cfg *mgmtServer.Config) mgmtServer.Server {
|
||||
return mgmtServer.NewServer(cfg)
|
||||
}
|
||||
|
||||
func SetNewServer(fn func(*mgmtServer.Config) mgmtServer.Server) {
|
||||
newServer = fn
|
||||
}
|
||||
@@ -304,27 +304,10 @@ func (m Manager) getClusterAllowList(ctx context.Context, accountID string) ([]s
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get BYOP cluster addresses: %w", err)
|
||||
}
|
||||
publicAddresses, err := m.proxyManager.GetActiveClusterAddresses(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get public cluster addresses: %w", err)
|
||||
if len(byopAddresses) > 0 {
|
||||
return byopAddresses, nil
|
||||
}
|
||||
seen := make(map[string]struct{}, len(byopAddresses)+len(publicAddresses))
|
||||
merged := make([]string, 0, len(byopAddresses)+len(publicAddresses))
|
||||
for _, addr := range byopAddresses {
|
||||
if _, ok := seen[addr]; ok {
|
||||
continue
|
||||
}
|
||||
seen[addr] = struct{}{}
|
||||
merged = append(merged, addr)
|
||||
}
|
||||
for _, addr := range publicAddresses {
|
||||
if _, ok := seen[addr]; ok {
|
||||
continue
|
||||
}
|
||||
seen[addr] = struct{}{}
|
||||
merged = append(merged, addr)
|
||||
}
|
||||
return merged, nil
|
||||
return m.proxyManager.GetActiveClusterAddresses(ctx)
|
||||
}
|
||||
|
||||
func extractClusterFromCustomDomains(serviceDomain string, customDomains []*domain.Domain) (string, bool) {
|
||||
|
||||
@@ -40,37 +40,22 @@ func (m *mockProxyManager) ClusterSupportsCrowdSec(_ context.Context, _ string)
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestGetClusterAllowList_BYOPMergedWithPublic(t *testing.T) {
|
||||
func TestGetClusterAllowList_BYOPProxy(t *testing.T) {
|
||||
pm := &mockProxyManager{
|
||||
getActiveClusterAddressesForAccountFunc: func(_ context.Context, accID string) ([]string, error) {
|
||||
assert.Equal(t, "acc-123", accID)
|
||||
return []string{"byop.example.com"}, nil
|
||||
},
|
||||
getActiveClusterAddressesFunc: func(_ context.Context) ([]string, error) {
|
||||
return []string{"eu.proxy.netbird.io"}, nil
|
||||
t.Fatal("should not call GetActiveClusterAddresses when BYOP addresses exist")
|
||||
return nil, nil
|
||||
},
|
||||
}
|
||||
|
||||
mgr := Manager{proxyManager: pm}
|
||||
result, err := mgr.getClusterAllowList(context.Background(), "acc-123")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, []string{"byop.example.com", "eu.proxy.netbird.io"}, result)
|
||||
}
|
||||
|
||||
func TestGetClusterAllowList_DeduplicatesBYOPAndPublic(t *testing.T) {
|
||||
pm := &mockProxyManager{
|
||||
getActiveClusterAddressesForAccountFunc: func(_ context.Context, _ string) ([]string, error) {
|
||||
return []string{"shared.example.com", "byop.example.com"}, nil
|
||||
},
|
||||
getActiveClusterAddressesFunc: func(_ context.Context) ([]string, error) {
|
||||
return []string{"shared.example.com", "eu.proxy.netbird.io"}, nil
|
||||
},
|
||||
}
|
||||
|
||||
mgr := Manager{proxyManager: pm}
|
||||
result, err := mgr.getClusterAllowList(context.Background(), "acc-123")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, []string{"shared.example.com", "byop.example.com", "eu.proxy.netbird.io"}, result)
|
||||
assert.Equal(t, []string{"byop.example.com"}, result)
|
||||
}
|
||||
|
||||
func TestGetClusterAllowList_NoBYOP_FallbackToShared(t *testing.T) {
|
||||
@@ -94,6 +79,10 @@ func TestGetClusterAllowList_BYOPError_ReturnsError(t *testing.T) {
|
||||
getActiveClusterAddressesForAccountFunc: func(_ context.Context, _ string) ([]string, error) {
|
||||
return nil, errors.New("db error")
|
||||
},
|
||||
getActiveClusterAddressesFunc: func(_ context.Context) ([]string, error) {
|
||||
t.Fatal("should not call GetActiveClusterAddresses when BYOP lookup fails")
|
||||
return nil, nil
|
||||
},
|
||||
}
|
||||
|
||||
mgr := Manager{proxyManager: pm}
|
||||
@@ -103,23 +92,6 @@ func TestGetClusterAllowList_BYOPError_ReturnsError(t *testing.T) {
|
||||
assert.Contains(t, err.Error(), "BYOP cluster addresses")
|
||||
}
|
||||
|
||||
func TestGetClusterAllowList_PublicError_ReturnsError(t *testing.T) {
|
||||
pm := &mockProxyManager{
|
||||
getActiveClusterAddressesForAccountFunc: func(_ context.Context, _ string) ([]string, error) {
|
||||
return []string{"byop.example.com"}, nil
|
||||
},
|
||||
getActiveClusterAddressesFunc: func(_ context.Context) ([]string, error) {
|
||||
return nil, errors.New("db error")
|
||||
},
|
||||
}
|
||||
|
||||
mgr := Manager{proxyManager: pm}
|
||||
result, err := mgr.getClusterAllowList(context.Background(), "acc-123")
|
||||
require.Error(t, err)
|
||||
assert.Nil(t, result)
|
||||
assert.Contains(t, err.Error(), "public cluster addresses")
|
||||
}
|
||||
|
||||
func TestGetClusterAllowList_BYOPEmptySlice_FallbackToShared(t *testing.T) {
|
||||
pm := &mockProxyManager{
|
||||
getActiveClusterAddressesForAccountFunc: func(_ context.Context, _ string) ([]string, error) {
|
||||
@@ -136,19 +108,3 @@ func TestGetClusterAllowList_BYOPEmptySlice_FallbackToShared(t *testing.T) {
|
||||
assert.Equal(t, []string{"eu.proxy.netbird.io"}, result)
|
||||
}
|
||||
|
||||
func TestGetClusterAllowList_PublicEmpty_BYOPOnly(t *testing.T) {
|
||||
pm := &mockProxyManager{
|
||||
getActiveClusterAddressesForAccountFunc: func(_ context.Context, _ string) ([]string, error) {
|
||||
return []string{"byop.example.com"}, nil
|
||||
},
|
||||
getActiveClusterAddressesFunc: func(_ context.Context) ([]string, error) {
|
||||
return nil, nil
|
||||
},
|
||||
}
|
||||
|
||||
mgr := Manager{proxyManager: pm}
|
||||
result, err := mgr.getClusterAllowList(context.Background(), "acc-123")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, []string{"byop.example.com"}, result)
|
||||
}
|
||||
|
||||
|
||||
@@ -306,10 +306,6 @@ func (m *Manager) validateSubdomainRequirement(ctx context.Context, domain, clus
|
||||
func (m *Manager) persistNewService(ctx context.Context, accountID string, svc *service.Service) error {
|
||||
customPorts := m.clusterCustomPorts(ctx, svc)
|
||||
|
||||
if err := validateTargetReferences(ctx, m.store, accountID, svc.Targets); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
if svc.Domain != "" {
|
||||
if err := m.checkDomainAvailable(ctx, transaction, svc.Domain, ""); err != nil {
|
||||
@@ -325,6 +321,10 @@ func (m *Manager) persistNewService(ctx context.Context, accountID string, svc *
|
||||
return err
|
||||
}
|
||||
|
||||
if err := validateTargetReferences(ctx, transaction, accountID, svc.Targets); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := transaction.CreateService(ctx, svc); err != nil {
|
||||
return fmt.Errorf("create service: %w", err)
|
||||
}
|
||||
@@ -435,10 +435,6 @@ func (m *Manager) assignPort(ctx context.Context, tx store.Store, cluster string
|
||||
func (m *Manager) persistNewEphemeralService(ctx context.Context, accountID, peerID string, svc *service.Service) error {
|
||||
customPorts := m.clusterCustomPorts(ctx, svc)
|
||||
|
||||
if err := validateTargetReferences(ctx, m.store, accountID, svc.Targets); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
if err := m.validateEphemeralPreconditions(ctx, transaction, accountID, peerID, svc); err != nil {
|
||||
return err
|
||||
@@ -452,6 +448,10 @@ func (m *Manager) persistNewEphemeralService(ctx context.Context, accountID, pee
|
||||
return err
|
||||
}
|
||||
|
||||
if err := validateTargetReferences(ctx, transaction, accountID, svc.Targets); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := transaction.CreateService(ctx, svc); err != nil {
|
||||
return fmt.Errorf("create service: %w", err)
|
||||
}
|
||||
@@ -552,22 +552,10 @@ func (m *Manager) persistServiceUpdate(ctx context.Context, accountID string, se
|
||||
svcForCaps.ProxyCluster = effectiveCluster
|
||||
customPorts := m.clusterCustomPorts(ctx, &svcForCaps)
|
||||
|
||||
if err := validateTargetReferences(ctx, m.store, accountID, service.Targets); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Validate subdomain requirement *before* the transaction: the underlying
|
||||
// capability lookup talks to the main DB pool, and SQLite's single-connection
|
||||
// pool would self-deadlock if this ran while the tx already held the only
|
||||
// connection.
|
||||
if err := m.validateSubdomainRequirement(ctx, service.Domain, effectiveCluster); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var updateInfo serviceUpdateInfo
|
||||
|
||||
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
return m.executeServiceUpdate(ctx, transaction, accountID, service, &updateInfo, customPorts, effectiveCluster)
|
||||
return m.executeServiceUpdate(ctx, transaction, accountID, service, &updateInfo, customPorts)
|
||||
})
|
||||
|
||||
return &updateInfo, err
|
||||
@@ -597,7 +585,7 @@ func (m *Manager) resolveEffectiveCluster(ctx context.Context, accountID string,
|
||||
return existing.ProxyCluster, nil
|
||||
}
|
||||
|
||||
func (m *Manager) executeServiceUpdate(ctx context.Context, transaction store.Store, accountID string, service *service.Service, updateInfo *serviceUpdateInfo, customPorts *bool, effectiveCluster string) error {
|
||||
func (m *Manager) executeServiceUpdate(ctx context.Context, transaction store.Store, accountID string, service *service.Service, updateInfo *serviceUpdateInfo, customPorts *bool) error {
|
||||
existingService, err := transaction.GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, service.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -615,13 +603,17 @@ func (m *Manager) executeServiceUpdate(ctx context.Context, transaction store.St
|
||||
updateInfo.domainChanged = existingService.Domain != service.Domain
|
||||
|
||||
if updateInfo.domainChanged {
|
||||
if err := m.handleDomainChange(ctx, transaction, service, effectiveCluster); err != nil {
|
||||
if err := m.handleDomainChange(ctx, transaction, accountID, service); err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
service.ProxyCluster = existingService.ProxyCluster
|
||||
}
|
||||
|
||||
if err := m.validateSubdomainRequirement(ctx, service.Domain, service.ProxyCluster); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
m.preserveExistingAuthSecrets(service, existingService)
|
||||
if err := validateHeaderAuthValues(service.Auth.HeaderAuths); err != nil {
|
||||
return err
|
||||
@@ -636,6 +628,9 @@ func (m *Manager) executeServiceUpdate(ctx context.Context, transaction store.St
|
||||
if err := m.checkPortConflict(ctx, transaction, service); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := validateTargetReferences(ctx, transaction, accountID, service.Targets); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := transaction.UpdateService(ctx, service); err != nil {
|
||||
return fmt.Errorf("update service: %w", err)
|
||||
}
|
||||
@@ -643,18 +638,20 @@ func (m *Manager) executeServiceUpdate(ctx context.Context, transaction store.St
|
||||
return nil
|
||||
}
|
||||
|
||||
// handleDomainChange validates the new domain is free inside the transaction
|
||||
// and applies the pre-resolved cluster (computed outside the tx by
|
||||
// resolveEffectiveCluster). It must NOT call clusterDeriver here: that talks
|
||||
// to the main DB pool and would self-deadlock under SQLite (max_open_conns=1)
|
||||
// because the transaction already holds the only connection.
|
||||
func (m *Manager) handleDomainChange(ctx context.Context, transaction store.Store, svc *service.Service, effectiveCluster string) error {
|
||||
func (m *Manager) handleDomainChange(ctx context.Context, transaction store.Store, accountID string, svc *service.Service) error {
|
||||
if err := m.checkDomainAvailable(ctx, transaction, svc.Domain, svc.ID); err != nil {
|
||||
return err
|
||||
}
|
||||
if effectiveCluster != "" {
|
||||
svc.ProxyCluster = effectiveCluster
|
||||
|
||||
if m.clusterDeriver != nil {
|
||||
newCluster, err := m.clusterDeriver.DeriveClusterFromDomain(ctx, accountID, svc.Domain)
|
||||
if err != nil {
|
||||
log.WithError(err).Warnf("could not derive cluster from domain %s", svc.Domain)
|
||||
} else {
|
||||
svc.ProxyCluster = newCluster
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -381,14 +381,13 @@ func (s *Service) buildPathMappings() []*proto.PathMapping {
|
||||
}
|
||||
|
||||
// HTTP/HTTPS: build full URL
|
||||
hostNoBrackets := strings.TrimSuffix(strings.TrimPrefix(target.Host, "["), "]")
|
||||
targetURL := url.URL{
|
||||
Scheme: target.Protocol,
|
||||
Host: bracketIPv6Host(hostNoBrackets),
|
||||
Host: target.Host,
|
||||
Path: "/",
|
||||
}
|
||||
if target.Port > 0 && !isDefaultPort(target.Protocol, target.Port) {
|
||||
targetURL.Host = net.JoinHostPort(hostNoBrackets, strconv.FormatUint(uint64(target.Port), 10))
|
||||
targetURL.Host = net.JoinHostPort(targetURL.Host, strconv.FormatUint(uint64(target.Port), 10))
|
||||
}
|
||||
|
||||
path := "/"
|
||||
@@ -406,19 +405,6 @@ func (s *Service) buildPathMappings() []*proto.PathMapping {
|
||||
return pathMappings
|
||||
}
|
||||
|
||||
// bracketIPv6Host wraps host in square brackets when it is an IPv6 literal, as
|
||||
// required for the Host field of net/url.URL (RFC 3986 §3.2.2). v4-mapped IPv6
|
||||
// addresses are bracketed too since their textual form contains colons.
|
||||
func bracketIPv6Host(host string) string {
|
||||
if strings.HasPrefix(host, "[") {
|
||||
return host
|
||||
}
|
||||
if addr, err := netip.ParseAddr(host); err == nil && addr.Is6() {
|
||||
return "[" + host + "]"
|
||||
}
|
||||
return host
|
||||
}
|
||||
|
||||
func operationToProtoType(op Operation) proto.ProxyMappingUpdateType {
|
||||
switch op {
|
||||
case Create:
|
||||
|
||||
@@ -351,83 +351,6 @@ func TestToProtoMapping_PortInTargetURL(t *testing.T) {
|
||||
port: 80,
|
||||
wantTarget: "https://10.0.0.1:80/",
|
||||
},
|
||||
{
|
||||
name: "domain host without port is unchanged",
|
||||
protocol: "http",
|
||||
host: "example.com",
|
||||
port: 0,
|
||||
wantTarget: "http://example.com/",
|
||||
},
|
||||
{
|
||||
name: "domain host with non-default port is unchanged",
|
||||
protocol: "http",
|
||||
host: "example.com",
|
||||
port: 8080,
|
||||
wantTarget: "http://example.com:8080/",
|
||||
},
|
||||
{
|
||||
name: "ipv6 host without port is bracketed",
|
||||
protocol: "http",
|
||||
host: "fb00:cafe:1::3",
|
||||
port: 0,
|
||||
wantTarget: "http://[fb00:cafe:1::3]/",
|
||||
},
|
||||
{
|
||||
name: "ipv6 host with default port omits port and brackets host",
|
||||
protocol: "http",
|
||||
host: "fb00:cafe:1::3",
|
||||
port: 80,
|
||||
wantTarget: "http://[fb00:cafe:1::3]/",
|
||||
},
|
||||
{
|
||||
name: "ipv6 host with non-default port is bracketed",
|
||||
protocol: "http",
|
||||
host: "fb00:cafe:1::3",
|
||||
port: 8080,
|
||||
wantTarget: "http://[fb00:cafe:1::3]:8080/",
|
||||
},
|
||||
{
|
||||
name: "ipv6 loopback without port is bracketed",
|
||||
protocol: "http",
|
||||
host: "::1",
|
||||
port: 0,
|
||||
wantTarget: "http://[::1]/",
|
||||
},
|
||||
{
|
||||
name: "ipv6 host with 5-digit port is bracketed",
|
||||
protocol: "http",
|
||||
host: "fb00:cafe::1",
|
||||
port: 18080,
|
||||
wantTarget: "http://[fb00:cafe::1]:18080/",
|
||||
},
|
||||
{
|
||||
name: "pre-bracketed ipv6 without port stays single-bracketed",
|
||||
protocol: "http",
|
||||
host: "[fb00:cafe::1]",
|
||||
port: 0,
|
||||
wantTarget: "http://[fb00:cafe::1]/",
|
||||
},
|
||||
{
|
||||
name: "pre-bracketed ipv6 with port is not double-bracketed",
|
||||
protocol: "http",
|
||||
host: "[fb00:cafe::1]",
|
||||
port: 8080,
|
||||
wantTarget: "http://[fb00:cafe::1]:8080/",
|
||||
},
|
||||
{
|
||||
name: "v4-mapped ipv6 host without port is bracketed",
|
||||
protocol: "http",
|
||||
host: "::ffff:10.0.0.1",
|
||||
port: 0,
|
||||
wantTarget: "http://[::ffff:10.0.0.1]/",
|
||||
},
|
||||
{
|
||||
name: "full-form 8-group ipv6 without port is bracketed",
|
||||
protocol: "http",
|
||||
host: "fb00:cafe:1:0:0:0:0:3",
|
||||
port: 0,
|
||||
wantTarget: "http://[fb00:cafe:1:0:0:0:0:3]/",
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
|
||||
@@ -34,6 +34,8 @@ const (
|
||||
ManagementLegacyPort = 33073
|
||||
// DefaultSelfHostedDomain is the default domain used for self-hosted fresh installs.
|
||||
DefaultSelfHostedDomain = "netbird.selfhosted"
|
||||
|
||||
ContainerKeyBaseServer = "baseServer"
|
||||
)
|
||||
|
||||
type Server interface {
|
||||
@@ -91,7 +93,7 @@ type Config struct {
|
||||
|
||||
// NewServer initializes and configures a new Server instance
|
||||
func NewServer(cfg *Config) *BaseServer {
|
||||
return &BaseServer{
|
||||
s := &BaseServer{
|
||||
Config: cfg.NbConfig,
|
||||
container: make(map[string]any),
|
||||
dnsDomain: cfg.DNSDomain,
|
||||
@@ -104,6 +106,9 @@ func NewServer(cfg *Config) *BaseServer {
|
||||
mgmtMetricsPort: cfg.MgmtMetricsPort,
|
||||
autoResolveDomains: cfg.AutoResolveDomains,
|
||||
}
|
||||
s.container[ContainerKeyBaseServer] = s
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *BaseServer) AfterInit(fn func(s *BaseServer)) {
|
||||
|
||||
@@ -2487,18 +2487,6 @@ func (am *DefaultAccountManager) buildIPv6AllowedPeers(ctx context.Context, tran
|
||||
allowedPeers[peerID] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
// Embedded proxy peers sit outside regular group membership but must
|
||||
// participate in any v6-enabled overlay to reach v6-only peers.
|
||||
peers, err := transaction.GetAccountPeers(ctx, store.LockingStrengthNone, accountID, "", "")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get peers: %w", err)
|
||||
}
|
||||
for _, p := range peers {
|
||||
if p.ProxyMeta.Embedded {
|
||||
allowedPeers[p.ID] = struct{}{}
|
||||
}
|
||||
}
|
||||
return allowedPeers, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -762,19 +762,16 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, accountID, setupKe
|
||||
newPeer.IP = freeIP
|
||||
|
||||
if len(settings.IPv6EnabledGroups) > 0 && network.NetV6.IP != nil {
|
||||
// Embedded proxy peers are not group members but participate in any
|
||||
// IPv6-enabled overlay so reverse-proxy traffic reaches v6-only peers.
|
||||
allocate := peer.ProxyMeta.Embedded
|
||||
if !allocate {
|
||||
var allGroupID string
|
||||
if allGroup, err := am.Store.GetGroupByName(ctx, store.LockingStrengthNone, accountID, types.GroupAllName); err == nil {
|
||||
allGroupID = allGroup.ID
|
||||
} else {
|
||||
var allGroupID string
|
||||
if !peer.ProxyMeta.Embedded {
|
||||
allGroup, err := am.Store.GetGroupByName(ctx, store.LockingStrengthNone, accountID, "All")
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Debugf("get All group for IPv6 allocation: %v", err)
|
||||
} else {
|
||||
allGroupID = allGroup.ID
|
||||
}
|
||||
allocate = peerWillHaveIPv6(settings, peerAddConfig.GroupsToAdd, allGroupID)
|
||||
}
|
||||
if allocate {
|
||||
if peerWillHaveIPv6(settings, peerAddConfig.GroupsToAdd, allGroupID) {
|
||||
v6Prefix, err := netip.ParsePrefix(network.NetV6.String())
|
||||
if err != nil {
|
||||
return nil, nil, nil, fmt.Errorf("parse IPv6 prefix: %w", err)
|
||||
|
||||
@@ -8,7 +8,6 @@ import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
@@ -2795,27 +2794,12 @@ func NewSqliteStore(ctx context.Context, dataDir string, metrics telemetry.AppMe
|
||||
connStr = filepath.Join(dataDir, filePath)
|
||||
}
|
||||
|
||||
// Compose query parameters. User-provided ?_busy_timeout (or its mattn alias
|
||||
// ?_timeout) overrides our default; otherwise inject 30s so SQLite waits at
|
||||
// most that long on a lock instead of blocking the only Go-side connection.
|
||||
// mattn/go-sqlite3 applies PRAGMA from the DSN on every fresh connection, so
|
||||
// the value survives ConnMaxIdleTime/ConnMaxLifetime recycling. cache=shared
|
||||
// stays the default on non-Windows for the same reason as before.
|
||||
parsed, _ := url.ParseQuery(query)
|
||||
var defaults []string
|
||||
if parsed.Get("_busy_timeout") == "" && parsed.Get("_timeout") == "" {
|
||||
defaults = append(defaults, "_busy_timeout=30000")
|
||||
}
|
||||
if !hasQuery && runtime.GOOS != "windows" {
|
||||
// To avoid `The process cannot access the file because it is being used by another process` on Windows
|
||||
defaults = append(defaults, "cache=shared")
|
||||
}
|
||||
parts := defaults
|
||||
// Append query parameters: user-provided take precedence, otherwise default to cache=shared on non-Windows
|
||||
if hasQuery {
|
||||
parts = append(parts, query)
|
||||
}
|
||||
if len(parts) > 0 {
|
||||
connStr += "?" + strings.Join(parts, "&")
|
||||
connStr += "?" + query
|
||||
} else if runtime.GOOS != "windows" {
|
||||
// To avoid `The process cannot access the file because it is being used by another process` on Windows
|
||||
connStr += "?cache=shared"
|
||||
}
|
||||
|
||||
db, err := gorm.Open(sqlite.Open(connStr), getGormConfig())
|
||||
@@ -3418,7 +3402,7 @@ func (s *SqlStore) IncrementNetworkSerial(ctx context.Context, accountId string)
|
||||
}
|
||||
|
||||
func (s *SqlStore) ExecuteInTransaction(ctx context.Context, operation func(store Store) error) error {
|
||||
timeoutCtx, cancel := context.WithTimeout(ctx, s.transactionTimeout)
|
||||
timeoutCtx, cancel := context.WithTimeout(context.Background(), s.transactionTimeout)
|
||||
defer cancel()
|
||||
|
||||
startTime := time.Now()
|
||||
|
||||
@@ -4592,55 +4592,3 @@ func TestSqlStore_DeleteZoneDNSRecords(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 0, len(remainingRecords))
|
||||
}
|
||||
|
||||
// TestNewSqliteStore_BusyTimeoutApplied opens a fresh SQLite store and verifies
|
||||
// that the _busy_timeout DSN parameter took effect at the driver level. Without
|
||||
// this, lock contention on the single SQLite connection waits indefinitely on
|
||||
// the Go side and can be hidden behind the 5-minute transactionTimeout.
|
||||
func TestNewSqliteStore_BusyTimeoutApplied(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
store, err := NewSqliteStore(context.Background(), dir, nil, true)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
_ = store.Close(context.Background())
|
||||
})
|
||||
|
||||
sqlDB, err := store.db.DB()
|
||||
require.NoError(t, err)
|
||||
row := sqlDB.QueryRow("PRAGMA busy_timeout")
|
||||
var busyTimeout int
|
||||
require.NoError(t, row.Scan(&busyTimeout))
|
||||
assert.Equal(t, 30000, busyTimeout, "SQLite busy_timeout must be set via DSN so it survives connection recycling")
|
||||
}
|
||||
|
||||
// TestNewSqliteStore_BusyTimeoutRespectsUserOverride confirms that an operator
|
||||
// passing _busy_timeout or its mattn alias _timeout via NB_STORE_ENGINE_SQLITE_FILE
|
||||
// wins over our 30s default. This guards the DSN merge logic in NewSqliteStore.
|
||||
func TestNewSqliteStore_BusyTimeoutRespectsUserOverride(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
envFile string
|
||||
expected int
|
||||
}{
|
||||
{name: "explicit _busy_timeout wins", envFile: "store.db?_busy_timeout=5000", expected: 5000},
|
||||
{name: "alias _timeout wins", envFile: "store.db?_timeout=7000", expected: 7000},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Setenv("NB_STORE_ENGINE_SQLITE_FILE", tc.envFile)
|
||||
dir := t.TempDir()
|
||||
store, err := NewSqliteStore(context.Background(), dir, nil, true)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
_ = store.Close(context.Background())
|
||||
})
|
||||
|
||||
sqlDB, err := store.db.DB()
|
||||
require.NoError(t, err)
|
||||
row := sqlDB.QueryRow("PRAGMA busy_timeout")
|
||||
var busyTimeout int
|
||||
require.NoError(t, row.Scan(&busyTimeout))
|
||||
assert.Equal(t, tc.expected, busyTimeout)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -598,21 +598,28 @@ func (a *Account) GetPeerGroups(peerID string) LookupMap {
|
||||
return groupList
|
||||
}
|
||||
|
||||
// PeerIPv6Allowed reports whether the given peer participates in the IPv6 overlay.
|
||||
// PeerIPv6Allowed reports whether the given peer is in any of the account's IPv6 enabled groups.
|
||||
// Returns false if IPv6 is disabled or no groups are configured.
|
||||
func (a *Account) PeerIPv6Allowed(peerID string) bool {
|
||||
_, ok := a.peerIPv6AllowedSet()[peerID]
|
||||
return ok
|
||||
if len(a.Settings.IPv6EnabledGroups) == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
for _, groupID := range a.Settings.IPv6EnabledGroups {
|
||||
group, ok := a.Groups[groupID]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if slices.Contains(group.Peers, peerID) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// peerIPv6AllowedSet returns the set of peer IDs that participate in the IPv6 overlay:
|
||||
// members of any IPv6-enabled group, plus every embedded proxy peer (which sit outside
|
||||
// regular group membership but must reach v6-enabled peers).
|
||||
// peerIPv6AllowedSet returns a set of peer IDs that belong to any IPv6-enabled group.
|
||||
func (a *Account) peerIPv6AllowedSet() map[string]struct{} {
|
||||
result := make(map[string]struct{})
|
||||
if len(a.Settings.IPv6EnabledGroups) == 0 {
|
||||
return result
|
||||
}
|
||||
for _, groupID := range a.Settings.IPv6EnabledGroups {
|
||||
group, ok := a.Groups[groupID]
|
||||
if !ok {
|
||||
@@ -622,11 +629,6 @@ func (a *Account) peerIPv6AllowedSet() map[string]struct{} {
|
||||
result[peerID] = struct{}{}
|
||||
}
|
||||
}
|
||||
for id, p := range a.Peers {
|
||||
if p != nil && p.ProxyMeta.Embedded {
|
||||
result[id] = struct{}{}
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
|
||||
@@ -92,12 +92,9 @@ func (g *Group) HasPeers() bool {
|
||||
return len(g.Peers) > 0
|
||||
}
|
||||
|
||||
// GroupAllName is the reserved name of the default group that contains every peer in an account.
|
||||
const GroupAllName = "All"
|
||||
|
||||
// IsGroupAll checks if the group is a default "All" group.
|
||||
func (g *Group) IsGroupAll() bool {
|
||||
return g.Name == GroupAllName
|
||||
return g.Name == "All"
|
||||
}
|
||||
|
||||
// AddPeer adds peerID to Peers if not present, returning true if added.
|
||||
|
||||
@@ -232,33 +232,3 @@ func TestIPv6RecalculationOnGroupChange(t *testing.T) {
|
||||
assert.True(t, account.PeerIPv6Allowed("peer3"), "peer3 now in infra")
|
||||
})
|
||||
}
|
||||
|
||||
func TestPeerIPv6AllowedEmbeddedProxy(t *testing.T) {
|
||||
account := &Account{
|
||||
Peers: map[string]*nbpeer.Peer{
|
||||
"peer1": {ID: "peer1"},
|
||||
"proxy": {ID: "proxy", ProxyMeta: nbpeer.ProxyMeta{Embedded: true, Cluster: "netbird.test"}},
|
||||
},
|
||||
Groups: map[string]*Group{
|
||||
"group-devs": {ID: "group-devs", Peers: []string{"peer1"}},
|
||||
},
|
||||
Settings: &Settings{},
|
||||
}
|
||||
|
||||
t.Run("embedded proxy allowed when any v6 group exists, without group membership", func(t *testing.T) {
|
||||
account.Settings.IPv6EnabledGroups = []string{"group-devs"}
|
||||
assert.True(t, account.PeerIPv6Allowed("proxy"), "embedded proxy participates in v6 overlay")
|
||||
assert.True(t, account.PeerIPv6Allowed("peer1"), "regular peer in enabled group still allowed")
|
||||
})
|
||||
|
||||
t.Run("embedded proxy denied when no v6 group enabled", func(t *testing.T) {
|
||||
account.Settings.IPv6EnabledGroups = nil
|
||||
assert.False(t, account.PeerIPv6Allowed("proxy"), "v6 disabled account-wide denies embedded proxies too")
|
||||
})
|
||||
|
||||
t.Run("non-embedded peer outside any enabled group is not pulled in", func(t *testing.T) {
|
||||
account.Settings.IPv6EnabledGroups = []string{"group-devs"}
|
||||
account.Peers["lonely"] = &nbpeer.Peer{ID: "lonely"}
|
||||
assert.False(t, account.PeerIPv6Allowed("lonely"), "embedded-proxy bypass must not leak to regular peers")
|
||||
})
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user