mirror of
https://github.com/netbirdio/netbird.git
synced 2026-05-24 01:29:56 +00:00
Compare commits
6 Commits
wasm-webso
...
v0.71.0
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
07e5450117 | ||
|
|
3f914090cb | ||
|
|
ea9fab4396 | ||
|
|
77b479286e | ||
|
|
ab2a8794e7 | ||
|
|
9126a192ca |
4
.github/workflows/wasm-build-validation.yml
vendored
4
.github/workflows/wasm-build-validation.yml
vendored
@@ -61,8 +61,8 @@ jobs:
|
||||
|
||||
echo "Size: ${SIZE} bytes (${SIZE_MB} MB)"
|
||||
|
||||
if [ ${SIZE} -gt 59768832 ]; then
|
||||
echo "Wasm binary size (${SIZE_MB}MB) exceeds 57MB limit!"
|
||||
if [ ${SIZE} -gt 58720256 ]; then
|
||||
echo "Wasm binary size (${SIZE_MB}MB) exceeds 56MB limit!"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
|
||||
@@ -43,16 +43,16 @@ func init() {
|
||||
ipsFilterMap = make(map[string]struct{})
|
||||
prefixNamesFilterMap = make(map[string]struct{})
|
||||
statusCmd.PersistentFlags().BoolVarP(&detailFlag, "detail", "d", false, "display detailed status information in human-readable format")
|
||||
statusCmd.PersistentFlags().BoolVar(&jsonFlag, "json", false, "display detailed status information in json format")
|
||||
statusCmd.PersistentFlags().BoolVar(&yamlFlag, "yaml", false, "display detailed status information in yaml format")
|
||||
statusCmd.PersistentFlags().BoolVar(&ipv4Flag, "ipv4", false, "display only NetBird IPv4 of this peer, e.g., --ipv4 will output 100.64.0.33")
|
||||
statusCmd.PersistentFlags().BoolVar(&ipv6Flag, "ipv6", false, "display only NetBird IPv6 of this peer")
|
||||
statusCmd.PersistentFlags().BoolVarP(&jsonFlag, "json", "j", false, "display detailed status information in json format")
|
||||
statusCmd.PersistentFlags().BoolVarP(&yamlFlag, "yaml", "y", false, "display detailed status information in yaml format")
|
||||
statusCmd.PersistentFlags().BoolVarP(&ipv4Flag, "ipv4", "4", false, "display only NetBird IPv4 of this peer, e.g., --ipv4 will output 100.64.0.33")
|
||||
statusCmd.PersistentFlags().BoolVarP(&ipv6Flag, "ipv6", "6", false, "display only NetBird IPv6 of this peer")
|
||||
statusCmd.MarkFlagsMutuallyExclusive("detail", "json", "yaml", "ipv4", "ipv6")
|
||||
statusCmd.PersistentFlags().StringSliceVar(&ipsFilter, "filter-by-ips", []string{}, "filters the detailed output by a list of one or more IPs (v4 or v6), e.g., --filter-by-ips 100.64.0.100,fd00::1")
|
||||
statusCmd.PersistentFlags().StringSliceVar(&prefixNamesFilter, "filter-by-names", []string{}, "filters the detailed output by a list of one or more peer FQDN or hostnames, e.g., --filter-by-names peer-a,peer-b.netbird.cloud")
|
||||
statusCmd.PersistentFlags().StringVar(&statusFilter, "filter-by-status", "", "filters the detailed output by connection status(idle|connecting|connected), e.g., --filter-by-status connected")
|
||||
statusCmd.PersistentFlags().StringVar(&connectionTypeFilter, "filter-by-connection-type", "", "filters the detailed output by connection type (P2P|Relayed), e.g., --filter-by-connection-type P2P")
|
||||
statusCmd.PersistentFlags().StringVar(&checkFlag, "check", "", "run a health check and exit with code 0 on success, 1 on failure (live|ready|startup)")
|
||||
statusCmd.PersistentFlags().StringSliceVarP(&ipsFilter, "filter-by-ips", "I", []string{}, "filters the detailed output by a list of one or more IPs (v4 or v6), e.g., --filter-by-ips 100.64.0.100,fd00::1")
|
||||
statusCmd.PersistentFlags().StringSliceVarP(&prefixNamesFilter, "filter-by-names", "N", []string{}, "filters the detailed output by a list of one or more peer FQDN or hostnames, e.g., --filter-by-names peer-a,peer-b.netbird.cloud")
|
||||
statusCmd.PersistentFlags().StringVarP(&statusFilter, "filter-by-status", "S", "", "filters the detailed output by connection status(idle|connecting|connected), e.g., --filter-by-status connected")
|
||||
statusCmd.PersistentFlags().StringVarP(&connectionTypeFilter, "filter-by-connection-type", "T", "", "filters the detailed output by connection type (P2P|Relayed), e.g., --filter-by-connection-type P2P")
|
||||
statusCmd.PersistentFlags().StringVarP(&checkFlag, "check", "C", "", "run a health check and exit with code 0 on success, 1 on failure (live|ready|startup)")
|
||||
}
|
||||
|
||||
func statusFunc(cmd *cobra.Command, args []string) error {
|
||||
|
||||
@@ -336,7 +336,7 @@ func (c *Client) ListenTCP(address string) (net.Listener, error) {
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("split host port: %w", err)
|
||||
}
|
||||
listenAddr := fmt.Sprintf("%s:%s", addr, port)
|
||||
listenAddr := net.JoinHostPort(addr.String(), port)
|
||||
|
||||
tcpAddr, err := net.ResolveTCPAddr("tcp", listenAddr)
|
||||
if err != nil {
|
||||
@@ -357,7 +357,7 @@ func (c *Client) ListenUDP(address string) (net.PacketConn, error) {
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("split host port: %w", err)
|
||||
}
|
||||
listenAddr := fmt.Sprintf("%s:%s", addr, port)
|
||||
listenAddr := net.JoinHostPort(addr.String(), port)
|
||||
|
||||
udpAddr, err := net.ResolveUDPAddr("udp", listenAddr)
|
||||
if err != nil {
|
||||
|
||||
@@ -45,8 +45,11 @@ netbird.out: Most recent, anonymized stdout log file of the NetBird client.
|
||||
routes.txt: Detailed system routing table in tabular format including destination, gateway, interface, metrics, and protocol information, if --system-info flag was provided.
|
||||
interfaces.txt: Anonymized network interface information, if --system-info flag was provided.
|
||||
ip_rules.txt: Detailed IP routing rules in tabular format including priority, source, destination, interfaces, table, and action information (Linux only), if --system-info flag was provided.
|
||||
iptables.txt: Anonymized iptables rules with packet counters, if --system-info flag was provided.
|
||||
nftables.txt: Anonymized nftables rules with packet counters, if --system-info flag was provided.
|
||||
iptables.txt: Anonymized iptables (IPv4) rules with packet counters, if --system-info flag was provided.
|
||||
ip6tables.txt: Anonymized ip6tables (IPv6) rules with packet counters, if --system-info flag was provided.
|
||||
ipset.txt: Anonymized ipset list output, if --system-info flag was provided.
|
||||
nftables.txt: Anonymized nftables rules with packet counters across all families (ip, ip6, inet, etc.), if --system-info flag was provided.
|
||||
sysctls.txt: Forwarding, reverse-path filter, source-validation, and conntrack accounting sysctl values that the NetBird client may read or modify, if --system-info flag was provided (Linux only).
|
||||
resolv.conf: DNS resolver configuration from /etc/resolv.conf (Unix systems only), if --system-info flag was provided.
|
||||
scutil_dns.txt: DNS configuration from scutil --dns (macOS only), if --system-info flag was provided.
|
||||
resolved_domains.txt: Anonymized resolved domain IP addresses from the status recorder.
|
||||
@@ -165,22 +168,33 @@ The config.txt file contains anonymized configuration information of the NetBird
|
||||
Other non-sensitive configuration options are included without anonymization.
|
||||
|
||||
Firewall Rules (Linux only)
|
||||
The bundle includes two separate firewall rule files:
|
||||
The bundle includes the following firewall-related files:
|
||||
|
||||
iptables.txt:
|
||||
- Complete iptables ruleset with packet counters using 'iptables -v -n -L'
|
||||
- IPv4 iptables ruleset with packet counters using 'iptables-save' and 'iptables -v -n -L'
|
||||
- Includes all tables (filter, nat, mangle, raw, security)
|
||||
- Shows packet and byte counters for each rule
|
||||
- All IP addresses are anonymized
|
||||
- Chain names, table names, and other non-sensitive information remain unchanged
|
||||
|
||||
ip6tables.txt:
|
||||
- IPv6 ip6tables ruleset with packet counters using 'ip6tables-save' and 'ip6tables -v -n -L'
|
||||
- Same table coverage and anonymization as iptables.txt
|
||||
- Omitted when ip6tables is not installed or no IPv6 rules are present
|
||||
|
||||
ipset.txt:
|
||||
- Output of 'ipset list' (family-agnostic)
|
||||
- IP addresses are anonymized; set names and types remain unchanged
|
||||
|
||||
nftables.txt:
|
||||
- Complete nftables ruleset obtained via 'nft -a list ruleset'
|
||||
- Complete nftables ruleset across all families (ip, ip6, inet, arp, bridge, netdev) via 'nft -a list ruleset'
|
||||
- Includes rule handle numbers and packet counters
|
||||
- All tables, chains, and rules are included
|
||||
- Shows packet and byte counters for each rule
|
||||
- All IP addresses are anonymized
|
||||
- Chain names, table names, and other non-sensitive information remain unchanged
|
||||
- All IP addresses are anonymized; chain/table names remain unchanged
|
||||
|
||||
sysctls.txt:
|
||||
- Forwarding (IPv4 + IPv6, global and per-interface), reverse-path filter, source-validation, conntrack accounting, and TCP-related sysctls that netbird may read or modify
|
||||
- Per-interface keys are enumerated from /proc/sys/net/ipv{4,6}/conf
|
||||
- Interface names anonymized when --anonymize is set
|
||||
|
||||
IP Rules (Linux only)
|
||||
The ip_rules.txt file contains detailed IP routing rule information:
|
||||
@@ -412,6 +426,10 @@ func (g *BundleGenerator) addSystemInfo() {
|
||||
log.Errorf("failed to add firewall rules to debug bundle: %v", err)
|
||||
}
|
||||
|
||||
if err := g.addSysctls(); err != nil {
|
||||
log.Errorf("failed to add sysctls to debug bundle: %v", err)
|
||||
}
|
||||
|
||||
if err := g.addDNSInfo(); err != nil {
|
||||
log.Errorf("failed to add DNS info to debug bundle: %v", err)
|
||||
}
|
||||
|
||||
@@ -124,15 +124,18 @@ func getSystemdLogs(serviceName string) (string, error) {
|
||||
// addFirewallRules collects and adds firewall rules to the archive
|
||||
func (g *BundleGenerator) addFirewallRules() error {
|
||||
log.Info("Collecting firewall rules")
|
||||
iptablesRules, err := collectIPTablesRules()
|
||||
g.addIPTablesRulesToBundle("iptables-save", "iptables", "iptables.txt")
|
||||
g.addIPTablesRulesToBundle("ip6tables-save", "ip6tables", "ip6tables.txt")
|
||||
|
||||
ipsetOutput, err := collectIPSets()
|
||||
if err != nil {
|
||||
log.Warnf("Failed to collect iptables rules: %v", err)
|
||||
log.Warnf("Failed to collect ipset information: %v", err)
|
||||
} else {
|
||||
if g.anonymize {
|
||||
iptablesRules = g.anonymizer.AnonymizeString(iptablesRules)
|
||||
ipsetOutput = g.anonymizer.AnonymizeString(ipsetOutput)
|
||||
}
|
||||
if err := g.addFileToZip(strings.NewReader(iptablesRules), "iptables.txt"); err != nil {
|
||||
log.Warnf("Failed to add iptables rules to bundle: %v", err)
|
||||
if err := g.addFileToZip(strings.NewReader(ipsetOutput), "ipset.txt"); err != nil {
|
||||
log.Warnf("Failed to add ipset output to bundle: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -151,44 +154,65 @@ func (g *BundleGenerator) addFirewallRules() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// collectIPTablesRules collects rules using both iptables-save and verbose listing
|
||||
func collectIPTablesRules() (string, error) {
|
||||
var builder strings.Builder
|
||||
|
||||
saveOutput, err := collectIPTablesSave()
|
||||
// addIPTablesRulesToBundle collects iptables/ip6tables rules and writes them to the bundle.
|
||||
func (g *BundleGenerator) addIPTablesRulesToBundle(saveBin, listBin, filename string) {
|
||||
rules, err := collectIPTablesRules(saveBin, listBin)
|
||||
if err != nil {
|
||||
log.Warnf("Failed to collect iptables rules using iptables-save: %v", err)
|
||||
} else {
|
||||
builder.WriteString("=== iptables-save output ===\n")
|
||||
log.Warnf("Failed to collect %s rules: %v", listBin, err)
|
||||
return
|
||||
}
|
||||
if g.anonymize {
|
||||
rules = g.anonymizer.AnonymizeString(rules)
|
||||
}
|
||||
if err := g.addFileToZip(strings.NewReader(rules), filename); err != nil {
|
||||
log.Warnf("Failed to add %s rules to bundle: %v", listBin, err)
|
||||
}
|
||||
}
|
||||
|
||||
// collectIPTablesRules collects rules using both <saveBin> and verbose listing via <listBin>.
|
||||
// Returns an error when neither command produced any output (e.g. the binary is missing),
|
||||
// so the caller can skip writing an empty file.
|
||||
func collectIPTablesRules(saveBin, listBin string) (string, error) {
|
||||
var builder strings.Builder
|
||||
var collected bool
|
||||
var firstErr error
|
||||
|
||||
saveOutput, err := runCommand(saveBin)
|
||||
switch {
|
||||
case err != nil:
|
||||
firstErr = err
|
||||
log.Warnf("Failed to collect %s output: %v", saveBin, err)
|
||||
case strings.TrimSpace(saveOutput) == "":
|
||||
log.Debugf("%s produced no output, skipping", saveBin)
|
||||
default:
|
||||
builder.WriteString(fmt.Sprintf("=== %s output ===\n", saveBin))
|
||||
builder.WriteString(saveOutput)
|
||||
builder.WriteString("\n")
|
||||
collected = true
|
||||
}
|
||||
|
||||
ipsetOutput, err := collectIPSets()
|
||||
if err != nil {
|
||||
log.Warnf("Failed to collect ipset information: %v", err)
|
||||
} else {
|
||||
builder.WriteString("=== ipset list output ===\n")
|
||||
builder.WriteString(ipsetOutput)
|
||||
builder.WriteString("\n")
|
||||
}
|
||||
|
||||
builder.WriteString("=== iptables -v -n -L output ===\n")
|
||||
listHeader := fmt.Sprintf("=== %s -v -n -L output ===\n", listBin)
|
||||
builder.WriteString(listHeader)
|
||||
|
||||
tables := []string{"filter", "nat", "mangle", "raw", "security"}
|
||||
|
||||
for _, table := range tables {
|
||||
builder.WriteString(fmt.Sprintf("*%s\n", table))
|
||||
|
||||
stats, err := getTableStatistics(table)
|
||||
stats, err := runCommand(listBin, "-v", "-n", "-L", "-t", table)
|
||||
if err != nil {
|
||||
log.Warnf("Failed to get statistics for table %s: %v", table, err)
|
||||
if firstErr == nil {
|
||||
firstErr = err
|
||||
}
|
||||
log.Warnf("Failed to get %s statistics for table %s: %v", listBin, table, err)
|
||||
continue
|
||||
}
|
||||
builder.WriteString(fmt.Sprintf("*%s\n", table))
|
||||
builder.WriteString(stats)
|
||||
builder.WriteString("\n")
|
||||
collected = true
|
||||
}
|
||||
|
||||
if !collected {
|
||||
return "", fmt.Errorf("collect %s rules: %w", listBin, firstErr)
|
||||
}
|
||||
return builder.String(), nil
|
||||
}
|
||||
|
||||
@@ -214,34 +238,15 @@ func collectIPSets() (string, error) {
|
||||
return ipsets, nil
|
||||
}
|
||||
|
||||
// collectIPTablesSave uses iptables-save to get rule definitions
|
||||
func collectIPTablesSave() (string, error) {
|
||||
cmd := exec.Command("iptables-save")
|
||||
// runCommand executes a command and returns its stdout, wrapping stderr in the error on failure.
|
||||
func runCommand(name string, args ...string) (string, error) {
|
||||
cmd := exec.Command(name, args...)
|
||||
var stdout, stderr bytes.Buffer
|
||||
cmd.Stdout = &stdout
|
||||
cmd.Stderr = &stderr
|
||||
|
||||
if err := cmd.Run(); err != nil {
|
||||
return "", fmt.Errorf("execute iptables-save: %w (stderr: %s)", err, stderr.String())
|
||||
}
|
||||
|
||||
rules := stdout.String()
|
||||
if strings.TrimSpace(rules) == "" {
|
||||
return "", fmt.Errorf("no iptables rules found")
|
||||
}
|
||||
|
||||
return rules, nil
|
||||
}
|
||||
|
||||
// getTableStatistics gets verbose statistics for an entire table using iptables command
|
||||
func getTableStatistics(table string) (string, error) {
|
||||
cmd := exec.Command("iptables", "-v", "-n", "-L", "-t", table)
|
||||
var stdout, stderr bytes.Buffer
|
||||
cmd.Stdout = &stdout
|
||||
cmd.Stderr = &stderr
|
||||
|
||||
if err := cmd.Run(); err != nil {
|
||||
return "", fmt.Errorf("execute iptables -v -n -L: %w (stderr: %s)", err, stderr.String())
|
||||
return "", fmt.Errorf("execute %s: %w (stderr: %s)", name, err, stderr.String())
|
||||
}
|
||||
|
||||
return stdout.String(), nil
|
||||
@@ -804,3 +809,91 @@ func formatSetKeyType(keyType nftables.SetDatatype) string {
|
||||
return fmt.Sprintf("type-%v", keyType)
|
||||
}
|
||||
}
|
||||
|
||||
// addSysctls collects forwarding and netbird-managed sysctl values and writes them to the bundle.
|
||||
func (g *BundleGenerator) addSysctls() error {
|
||||
log.Info("Collecting sysctls")
|
||||
content := collectSysctls()
|
||||
if g.anonymize {
|
||||
content = g.anonymizer.AnonymizeString(content)
|
||||
}
|
||||
if err := g.addFileToZip(strings.NewReader(content), "sysctls.txt"); err != nil {
|
||||
return fmt.Errorf("add sysctls to bundle: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// collectSysctls reads every sysctl that the netbird client may modify, plus
|
||||
// global IPv4/IPv6 forwarding, and returns a formatted dump grouped by topic.
|
||||
// Per-interface values are enumerated by listing /proc/sys/net/ipv{4,6}/conf.
|
||||
func collectSysctls() string {
|
||||
var builder strings.Builder
|
||||
|
||||
writeSysctlGroup(&builder, "forwarding", []string{
|
||||
"net.ipv4.ip_forward",
|
||||
"net.ipv6.conf.all.forwarding",
|
||||
"net.ipv6.conf.default.forwarding",
|
||||
})
|
||||
writeSysctlGroup(&builder, "ipv4 per-interface forwarding", listInterfaceSysctls("ipv4", "forwarding"))
|
||||
writeSysctlGroup(&builder, "ipv6 per-interface forwarding", listInterfaceSysctls("ipv6", "forwarding"))
|
||||
writeSysctlGroup(&builder, "rp_filter", append(
|
||||
[]string{"net.ipv4.conf.all.rp_filter", "net.ipv4.conf.default.rp_filter"},
|
||||
listInterfaceSysctls("ipv4", "rp_filter")...,
|
||||
))
|
||||
writeSysctlGroup(&builder, "src_valid_mark", append(
|
||||
[]string{"net.ipv4.conf.all.src_valid_mark", "net.ipv4.conf.default.src_valid_mark"},
|
||||
listInterfaceSysctls("ipv4", "src_valid_mark")...,
|
||||
))
|
||||
writeSysctlGroup(&builder, "conntrack", []string{
|
||||
"net.netfilter.nf_conntrack_acct",
|
||||
"net.netfilter.nf_conntrack_tcp_loose",
|
||||
})
|
||||
writeSysctlGroup(&builder, "tcp", []string{
|
||||
"net.ipv4.tcp_tw_reuse",
|
||||
})
|
||||
|
||||
return builder.String()
|
||||
}
|
||||
|
||||
func writeSysctlGroup(builder *strings.Builder, title string, keys []string) {
|
||||
builder.WriteString(fmt.Sprintf("=== %s ===\n", title))
|
||||
for _, key := range keys {
|
||||
value, err := readSysctl(key)
|
||||
if err != nil {
|
||||
builder.WriteString(fmt.Sprintf("%s = <error: %v>\n", key, err))
|
||||
continue
|
||||
}
|
||||
builder.WriteString(fmt.Sprintf("%s = %s\n", key, value))
|
||||
}
|
||||
builder.WriteString("\n")
|
||||
}
|
||||
|
||||
// listInterfaceSysctls returns net.ipvX.conf.<iface>.<leaf> keys for every
|
||||
// interface present in /proc/sys/net/ipvX/conf, skipping "all" and "default"
|
||||
// (callers add those explicitly so they appear first).
|
||||
func listInterfaceSysctls(family, leaf string) []string {
|
||||
dir := fmt.Sprintf("/proc/sys/net/%s/conf", family)
|
||||
entries, err := os.ReadDir(dir)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
var keys []string
|
||||
for _, e := range entries {
|
||||
name := e.Name()
|
||||
if name == "all" || name == "default" {
|
||||
continue
|
||||
}
|
||||
keys = append(keys, fmt.Sprintf("net.%s.conf.%s.%s", family, name, leaf))
|
||||
}
|
||||
sort.Strings(keys)
|
||||
return keys
|
||||
}
|
||||
|
||||
func readSysctl(key string) (string, error) {
|
||||
path := fmt.Sprintf("/proc/sys/%s", strings.ReplaceAll(key, ".", "/"))
|
||||
value, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return strings.TrimSpace(string(value)), nil
|
||||
}
|
||||
|
||||
@@ -17,3 +17,8 @@ func (g *BundleGenerator) addIPRules() error {
|
||||
// IP rules are only supported on Linux
|
||||
return nil
|
||||
}
|
||||
|
||||
func (g *BundleGenerator) addSysctls() error {
|
||||
// Sysctl collection is only supported on Linux
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -252,6 +252,10 @@ func (m *Manager) writeSSHConfig(sshConfig string) error {
|
||||
return fmt.Errorf("write SSH config file %s: %w", tmpPath, err)
|
||||
}
|
||||
|
||||
if err := os.Chmod(tmpPath, 0644); err != nil {
|
||||
return fmt.Errorf("chmod SSH config file %s: %w", tmpPath, err)
|
||||
}
|
||||
|
||||
if err := os.Rename(tmpPath, sshConfigPath); err != nil {
|
||||
return fmt.Errorf("rename SSH config %s -> %s: %w", tmpPath, sshConfigPath, err)
|
||||
}
|
||||
|
||||
@@ -21,7 +21,6 @@ import (
|
||||
"github.com/netbirdio/netbird/client/wasm/internal/http"
|
||||
"github.com/netbirdio/netbird/client/wasm/internal/rdp"
|
||||
"github.com/netbirdio/netbird/client/wasm/internal/ssh"
|
||||
nbwebsocket "github.com/netbirdio/netbird/client/wasm/internal/websocket"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
@@ -31,7 +30,6 @@ const (
|
||||
pingTimeout = 10 * time.Second
|
||||
defaultLogLevel = "warn"
|
||||
defaultSSHDetectionTimeout = 20 * time.Second
|
||||
dialWebSocketTimeout = 30 * time.Second
|
||||
|
||||
icmpEchoRequest = 8
|
||||
icmpCodeEcho = 0
|
||||
@@ -679,7 +677,6 @@ func createClientObject(client *netbird.Client) js.Value {
|
||||
obj["createSSHConnection"] = createSSHMethod(client)
|
||||
obj["proxyRequest"] = createProxyRequestMethod(client)
|
||||
obj["createRDPProxy"] = createRDPProxyMethod(client)
|
||||
obj["dialWebSocket"] = createDialWebSocketMethod(client)
|
||||
obj["status"] = createStatusMethod(client)
|
||||
obj["statusSummary"] = createStatusSummaryMethod(client)
|
||||
obj["statusDetail"] = createStatusDetailMethod(client)
|
||||
@@ -694,74 +691,6 @@ func createClientObject(client *netbird.Client) js.Value {
|
||||
return js.ValueOf(obj)
|
||||
}
|
||||
|
||||
func createDialWebSocketMethod(client *netbird.Client) js.Func {
|
||||
return js.FuncOf(func(_ js.Value, args []js.Value) any {
|
||||
url, protocols, timeout, errVal := parseDialWebSocketArgs(args)
|
||||
if !errVal.IsUndefined() {
|
||||
return errVal
|
||||
}
|
||||
|
||||
return createPromise(func(resolve, reject js.Value) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), timeout)
|
||||
defer cancel()
|
||||
|
||||
conn, err := nbwebsocket.Dial(ctx, client, url, protocols)
|
||||
if err != nil {
|
||||
reject.Invoke(js.ValueOf(fmt.Sprintf("dial websocket: %v", err)))
|
||||
return
|
||||
}
|
||||
|
||||
resolve.Invoke(nbwebsocket.NewJSInterface(conn))
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func parseDialWebSocketArgs(args []js.Value) (url string, protocols []string, timeout time.Duration, errVal js.Value) {
|
||||
if len(args) < 1 || args[0].Type() != js.TypeString {
|
||||
return "", nil, 0, js.ValueOf("error: dialWebSocket requires a URL string argument")
|
||||
}
|
||||
url = args[0].String()
|
||||
|
||||
if len(args) >= 2 && !args[1].IsNull() && !args[1].IsUndefined() {
|
||||
arr, err := jsStringArray(args[1])
|
||||
if err != nil {
|
||||
return "", nil, 0, js.ValueOf(fmt.Sprintf("error: protocols: %v", err))
|
||||
}
|
||||
protocols = arr
|
||||
}
|
||||
|
||||
timeout = dialWebSocketTimeout
|
||||
if len(args) >= 3 && !args[2].IsNull() && !args[2].IsUndefined() {
|
||||
if args[2].Type() != js.TypeNumber {
|
||||
return "", nil, 0, js.ValueOf("error: timeoutMs must be a number")
|
||||
}
|
||||
timeoutMs := args[2].Int()
|
||||
if timeoutMs <= 0 {
|
||||
return "", nil, 0, js.ValueOf("error: timeout must be positive")
|
||||
}
|
||||
timeout = time.Duration(timeoutMs) * time.Millisecond
|
||||
}
|
||||
|
||||
return url, protocols, timeout, js.Undefined()
|
||||
}
|
||||
|
||||
// jsStringArray converts a JS array of strings to a Go []string.
|
||||
func jsStringArray(v js.Value) ([]string, error) {
|
||||
if !v.InstanceOf(js.Global().Get("Array")) {
|
||||
return nil, fmt.Errorf("expected array")
|
||||
}
|
||||
n := v.Length()
|
||||
out := make([]string, n)
|
||||
for i := 0; i < n; i++ {
|
||||
el := v.Index(i)
|
||||
if el.Type() != js.TypeString {
|
||||
return nil, fmt.Errorf("element %d is not a string", i)
|
||||
}
|
||||
out[i] = el.String()
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// netBirdClientConstructor acts as a JavaScript constructor function
|
||||
func netBirdClientConstructor(_ js.Value, args []js.Value) any {
|
||||
return js.Global().Get("Promise").New(js.FuncOf(func(_ js.Value, promiseArgs []js.Value) any {
|
||||
|
||||
@@ -1,304 +0,0 @@
|
||||
//go:build js
|
||||
|
||||
package websocket
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
"syscall/js"
|
||||
|
||||
"github.com/gobwas/ws"
|
||||
"github.com/gobwas/ws/wsutil"
|
||||
netbird "github.com/netbirdio/netbird/client/embed"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
type closeError struct {
|
||||
code uint16
|
||||
reason string
|
||||
}
|
||||
|
||||
func (e *closeError) Error() string {
|
||||
return fmt.Sprintf("websocket closed: %d %s", e.code, e.reason)
|
||||
}
|
||||
|
||||
// bufferedConn fronts a net.Conn with a reader that serves any bytes buffered
|
||||
// during the WebSocket handshake before falling through to the raw conn.
|
||||
type bufferedConn struct {
|
||||
net.Conn
|
||||
r io.Reader
|
||||
}
|
||||
|
||||
func (c *bufferedConn) Read(p []byte) (int, error) { return c.r.Read(p) }
|
||||
|
||||
// Conn wraps a WebSocket connection over a NetBird TCP connection.
|
||||
type Conn struct {
|
||||
conn net.Conn
|
||||
mu sync.Mutex
|
||||
closed chan struct{}
|
||||
closeOnce sync.Once
|
||||
closeErr error
|
||||
}
|
||||
|
||||
// Dial establishes a WebSocket connection to the given URL through the NetBird network.
|
||||
// Optional protocols are sent via the Sec-WebSocket-Protocol header.
|
||||
func Dial(ctx context.Context, client *netbird.Client, rawURL string, protocols []string) (*Conn, error) {
|
||||
d := ws.Dialer{
|
||||
NetDial: client.Dial,
|
||||
Protocols: protocols,
|
||||
}
|
||||
|
||||
conn, br, _, err := d.Dial(ctx, rawURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("websocket dial: %w", err)
|
||||
}
|
||||
|
||||
// br is non-nil when the server pushed frames alongside the handshake
|
||||
// response; those bytes live in the bufio.Reader and must be drained
|
||||
// before reading from conn, otherwise we'd skip the first frames.
|
||||
if br != nil {
|
||||
if br.Buffered() > 0 {
|
||||
conn = &bufferedConn{Conn: conn, r: io.MultiReader(br, conn)}
|
||||
} else {
|
||||
ws.PutReader(br)
|
||||
}
|
||||
}
|
||||
|
||||
return &Conn{
|
||||
conn: conn,
|
||||
closed: make(chan struct{}),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ReadMessage reads the next WebSocket message, handling control frames automatically.
|
||||
func (c *Conn) ReadMessage() (ws.OpCode, []byte, error) {
|
||||
for {
|
||||
msgs, err := wsutil.ReadServerMessage(c.conn, nil)
|
||||
if err != nil {
|
||||
return 0, nil, err
|
||||
}
|
||||
|
||||
for _, msg := range msgs {
|
||||
if msg.OpCode.IsControl() {
|
||||
if err := c.handleControl(msg); err != nil {
|
||||
return 0, nil, err
|
||||
}
|
||||
continue
|
||||
}
|
||||
return msg.OpCode, msg.Payload, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Conn) handleControl(msg wsutil.Message) error {
|
||||
switch msg.OpCode {
|
||||
case ws.OpPing:
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
return wsutil.WriteClientMessage(c.conn, ws.OpPong, msg.Payload)
|
||||
case ws.OpClose:
|
||||
code, reason := parseClosePayload(msg.Payload)
|
||||
return &closeError{code: code, reason: reason}
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// WriteText sends a text WebSocket message.
|
||||
func (c *Conn) WriteText(data []byte) error {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
return wsutil.WriteClientMessage(c.conn, ws.OpText, data)
|
||||
}
|
||||
|
||||
// WriteBinary sends a binary WebSocket message.
|
||||
func (c *Conn) WriteBinary(data []byte) error {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
return wsutil.WriteClientMessage(c.conn, ws.OpBinary, data)
|
||||
}
|
||||
|
||||
// Close sends a close frame with StatusNormalClosure and closes the underlying connection.
|
||||
func (c *Conn) Close() error {
|
||||
return c.closeWith(ws.StatusNormalClosure, "")
|
||||
}
|
||||
|
||||
// closeWith sends a close frame with the given code/reason and closes the underlying connection.
|
||||
// Used to echo the server's code when responding to a server-initiated close per RFC 6455 §5.5.1.
|
||||
func (c *Conn) closeWith(code ws.StatusCode, reason string) error {
|
||||
var first bool
|
||||
c.closeOnce.Do(func() {
|
||||
first = true
|
||||
close(c.closed)
|
||||
|
||||
c.mu.Lock()
|
||||
_ = wsutil.WriteClientMessage(c.conn, ws.OpClose, ws.NewCloseFrameBody(code, reason))
|
||||
c.mu.Unlock()
|
||||
|
||||
c.closeErr = c.conn.Close()
|
||||
})
|
||||
|
||||
if !first {
|
||||
return net.ErrClosed
|
||||
}
|
||||
return c.closeErr
|
||||
}
|
||||
|
||||
// NewJSInterface creates a JavaScript object wrapping the WebSocket connection.
|
||||
// It exposes: send(string|Uint8Array), close(), and callback properties
|
||||
// onmessage, onclose, onerror.
|
||||
//
|
||||
// Callback properties may be set from the JS thread while the read loop
|
||||
// goroutine reads them. In WASM this is safe because Go and JS share a
|
||||
// single thread, but the design would need synchronization on
|
||||
// multi-threaded runtimes.
|
||||
func NewJSInterface(conn *Conn) js.Value {
|
||||
obj := js.Global().Get("Object").Call("create", js.Null())
|
||||
|
||||
sendFunc := js.FuncOf(func(_ js.Value, args []js.Value) any {
|
||||
if len(args) < 1 {
|
||||
log.Errorf("websocket send requires a data argument")
|
||||
return js.ValueOf(false)
|
||||
}
|
||||
|
||||
data := args[0]
|
||||
switch data.Type() {
|
||||
case js.TypeString:
|
||||
if err := conn.WriteText([]byte(data.String())); err != nil {
|
||||
log.Errorf("failed to send websocket text: %v", err)
|
||||
return js.ValueOf(false)
|
||||
}
|
||||
default:
|
||||
buf, err := jsToBytes(data)
|
||||
if err != nil {
|
||||
log.Errorf("failed to convert js value to bytes: %v", err)
|
||||
return js.ValueOf(false)
|
||||
}
|
||||
if err := conn.WriteBinary(buf); err != nil {
|
||||
log.Errorf("failed to send websocket binary: %v", err)
|
||||
return js.ValueOf(false)
|
||||
}
|
||||
}
|
||||
return js.ValueOf(true)
|
||||
})
|
||||
obj.Set("send", sendFunc)
|
||||
|
||||
closeFunc := js.FuncOf(func(_ js.Value, _ []js.Value) any {
|
||||
if err := conn.Close(); err != nil {
|
||||
log.Debugf("failed to close websocket: %v", err)
|
||||
}
|
||||
return js.Undefined()
|
||||
})
|
||||
obj.Set("close", closeFunc)
|
||||
|
||||
go func() {
|
||||
defer func() {
|
||||
if err := conn.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
|
||||
log.Debugf("close websocket on readLoop exit: %v", err)
|
||||
}
|
||||
}()
|
||||
readLoop(conn, obj)
|
||||
// Undefining before Release turns post-close JS calls into TypeError
|
||||
// instead of a silent "call to released function".
|
||||
obj.Set("send", js.Undefined())
|
||||
obj.Set("close", js.Undefined())
|
||||
sendFunc.Release()
|
||||
closeFunc.Release()
|
||||
}()
|
||||
|
||||
return obj
|
||||
}
|
||||
|
||||
func jsToBytes(data js.Value) ([]byte, error) {
|
||||
var uint8Array js.Value
|
||||
switch {
|
||||
case data.InstanceOf(js.Global().Get("Uint8Array")):
|
||||
uint8Array = data
|
||||
case data.InstanceOf(js.Global().Get("ArrayBuffer")):
|
||||
uint8Array = js.Global().Get("Uint8Array").New(data)
|
||||
default:
|
||||
return nil, fmt.Errorf("send: unsupported data type, use string, Uint8Array, or ArrayBuffer")
|
||||
}
|
||||
|
||||
buf := make([]byte, uint8Array.Get("length").Int())
|
||||
js.CopyBytesToGo(buf, uint8Array)
|
||||
return buf, nil
|
||||
}
|
||||
|
||||
func readLoop(conn *Conn, obj js.Value) {
|
||||
var ce *closeError
|
||||
defer func() { invokeOnClose(obj, ce) }()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-conn.closed:
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
op, payload, err := conn.ReadMessage()
|
||||
if err != nil {
|
||||
ce = handleReadError(conn, obj, err)
|
||||
return
|
||||
}
|
||||
|
||||
dispatchMessage(obj, op, payload)
|
||||
}
|
||||
}
|
||||
|
||||
func handleReadError(conn *Conn, obj js.Value, err error) *closeError {
|
||||
var ce *closeError
|
||||
if errors.As(err, &ce) {
|
||||
if cerr := conn.closeWith(ws.StatusCode(ce.code), ce.reason); cerr != nil {
|
||||
log.Debugf("failed to close websocket after server close frame: %v", cerr)
|
||||
}
|
||||
return ce
|
||||
}
|
||||
if errors.Is(err, io.EOF) || errors.Is(err, net.ErrClosed) {
|
||||
return nil
|
||||
}
|
||||
if onerror := obj.Get("onerror"); onerror.Truthy() {
|
||||
onerror.Invoke(js.ValueOf(err.Error()))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func invokeOnClose(obj js.Value, ce *closeError) {
|
||||
onclose := obj.Get("onclose")
|
||||
if !onclose.Truthy() {
|
||||
return
|
||||
}
|
||||
if ce != nil {
|
||||
onclose.Invoke(js.ValueOf(int(ce.code)), js.ValueOf(ce.reason))
|
||||
return
|
||||
}
|
||||
onclose.Invoke()
|
||||
}
|
||||
|
||||
func dispatchMessage(obj js.Value, op ws.OpCode, payload []byte) {
|
||||
onmessage := obj.Get("onmessage")
|
||||
if !onmessage.Truthy() {
|
||||
return
|
||||
}
|
||||
switch op {
|
||||
case ws.OpText:
|
||||
onmessage.Invoke(js.ValueOf(string(payload)))
|
||||
case ws.OpBinary:
|
||||
uint8Array := js.Global().Get("Uint8Array").New(len(payload))
|
||||
js.CopyBytesToJS(uint8Array, payload)
|
||||
onmessage.Invoke(uint8Array)
|
||||
}
|
||||
}
|
||||
|
||||
func parseClosePayload(payload []byte) (uint16, string) {
|
||||
if len(payload) < 2 {
|
||||
return 1005, "" // RFC 6455: No Status Rcvd
|
||||
}
|
||||
code := binary.BigEndian.Uint16(payload[:2])
|
||||
return code, string(payload[2:])
|
||||
}
|
||||
3
go.mod
3
go.mod
@@ -54,7 +54,6 @@ require (
|
||||
github.com/fsnotify/fsnotify v1.9.0
|
||||
github.com/gliderlabs/ssh v0.3.8
|
||||
github.com/go-jose/go-jose/v4 v4.1.4
|
||||
github.com/gobwas/ws v1.4.0
|
||||
github.com/godbus/dbus/v5 v5.1.0
|
||||
github.com/golang-jwt/jwt/v5 v5.3.1
|
||||
github.com/golang/mock v1.6.0
|
||||
@@ -212,8 +211,6 @@ require (
|
||||
github.com/go-viper/mapstructure/v2 v2.5.0 // indirect
|
||||
github.com/go-webauthn/webauthn v0.16.4 // indirect
|
||||
github.com/go-webauthn/x v0.2.3 // indirect
|
||||
github.com/gobwas/httphead v0.1.0 // indirect
|
||||
github.com/gobwas/pool v0.2.1 // indirect
|
||||
github.com/goccy/go-yaml v1.18.0 // indirect
|
||||
github.com/gogo/protobuf v1.3.2 // indirect
|
||||
github.com/golang-jwt/jwt/v4 v4.5.2 // indirect
|
||||
|
||||
7
go.sum
7
go.sum
@@ -247,12 +247,6 @@ github.com/go-webauthn/webauthn v0.16.4 h1:R9jqR/cYZa7hRquFF7Za/8qoH/K/TIs1/Q/4C
|
||||
github.com/go-webauthn/webauthn v0.16.4/go.mod h1:SU2ljAgToTV/YLPI0C05QS4qn+e04WpB5g1RMfcZfS4=
|
||||
github.com/go-webauthn/x v0.2.3 h1:8oArS+Rc1SWFLXhE17KZNx258Z4kUSyaDgsSncCO5RA=
|
||||
github.com/go-webauthn/x v0.2.3/go.mod h1:tM04GF3V6VYq79AZMl7vbj4q6pz9r7L2criWRzbWhPk=
|
||||
github.com/gobwas/httphead v0.1.0 h1:exrUm0f4YX0L7EBwZHuCF4GDp8aJfVeBrlLQrs6NqWU=
|
||||
github.com/gobwas/httphead v0.1.0/go.mod h1:O/RXo79gxV8G+RqlR/otEwx4Q36zl9rqC5u12GKvMCM=
|
||||
github.com/gobwas/pool v0.2.1 h1:xfeeEhW7pwmX8nuLVlqbzVc7udMDrwetjEv+TZIz1og=
|
||||
github.com/gobwas/pool v0.2.1/go.mod h1:q8bcK0KcYlCgd9e7WYLm9LpyS+YeLd8JVDW6WezmKEw=
|
||||
github.com/gobwas/ws v1.4.0 h1:CTaoG1tojrh4ucGPcoJFiAQUAsEWekEWvLy7GsVNqGs=
|
||||
github.com/gobwas/ws v1.4.0/go.mod h1:G3gNqMNtPppf5XUz7O4shetPpcZ1VJ7zt18dlUeakrc=
|
||||
github.com/goccy/go-yaml v1.18.0 h1:8W7wMFS12Pcas7KU+VVkaiCng+kG8QiFeFwzFb+rwuw=
|
||||
github.com/goccy/go-yaml v1.18.0/go.mod h1:XBurs7gK8ATbW4ZPGKgcbrY1Br56PdM69F7LkFRi1kA=
|
||||
github.com/godbus/dbus/v5 v5.1.0 h1:4KLkAxT3aOY8Li4FRJe/KvhoNFFxo0m6fNuFUO8QJUk=
|
||||
@@ -846,7 +840,6 @@ golang.org/x/sys v0.0.0-20220908164124-27713097b956/go.mod h1:oPkhp1MJrh7nUepCBc
|
||||
golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.7.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
|
||||
@@ -304,10 +304,27 @@ func (m Manager) getClusterAllowList(ctx context.Context, accountID string) ([]s
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get BYOP cluster addresses: %w", err)
|
||||
}
|
||||
if len(byopAddresses) > 0 {
|
||||
return byopAddresses, nil
|
||||
publicAddresses, err := m.proxyManager.GetActiveClusterAddresses(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get public cluster addresses: %w", err)
|
||||
}
|
||||
return m.proxyManager.GetActiveClusterAddresses(ctx)
|
||||
seen := make(map[string]struct{}, len(byopAddresses)+len(publicAddresses))
|
||||
merged := make([]string, 0, len(byopAddresses)+len(publicAddresses))
|
||||
for _, addr := range byopAddresses {
|
||||
if _, ok := seen[addr]; ok {
|
||||
continue
|
||||
}
|
||||
seen[addr] = struct{}{}
|
||||
merged = append(merged, addr)
|
||||
}
|
||||
for _, addr := range publicAddresses {
|
||||
if _, ok := seen[addr]; ok {
|
||||
continue
|
||||
}
|
||||
seen[addr] = struct{}{}
|
||||
merged = append(merged, addr)
|
||||
}
|
||||
return merged, nil
|
||||
}
|
||||
|
||||
func extractClusterFromCustomDomains(serviceDomain string, customDomains []*domain.Domain) (string, bool) {
|
||||
|
||||
@@ -40,22 +40,37 @@ func (m *mockProxyManager) ClusterSupportsCrowdSec(_ context.Context, _ string)
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestGetClusterAllowList_BYOPProxy(t *testing.T) {
|
||||
func TestGetClusterAllowList_BYOPMergedWithPublic(t *testing.T) {
|
||||
pm := &mockProxyManager{
|
||||
getActiveClusterAddressesForAccountFunc: func(_ context.Context, accID string) ([]string, error) {
|
||||
assert.Equal(t, "acc-123", accID)
|
||||
return []string{"byop.example.com"}, nil
|
||||
},
|
||||
getActiveClusterAddressesFunc: func(_ context.Context) ([]string, error) {
|
||||
t.Fatal("should not call GetActiveClusterAddresses when BYOP addresses exist")
|
||||
return nil, nil
|
||||
return []string{"eu.proxy.netbird.io"}, nil
|
||||
},
|
||||
}
|
||||
|
||||
mgr := Manager{proxyManager: pm}
|
||||
result, err := mgr.getClusterAllowList(context.Background(), "acc-123")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, []string{"byop.example.com"}, result)
|
||||
assert.Equal(t, []string{"byop.example.com", "eu.proxy.netbird.io"}, result)
|
||||
}
|
||||
|
||||
func TestGetClusterAllowList_DeduplicatesBYOPAndPublic(t *testing.T) {
|
||||
pm := &mockProxyManager{
|
||||
getActiveClusterAddressesForAccountFunc: func(_ context.Context, _ string) ([]string, error) {
|
||||
return []string{"shared.example.com", "byop.example.com"}, nil
|
||||
},
|
||||
getActiveClusterAddressesFunc: func(_ context.Context) ([]string, error) {
|
||||
return []string{"shared.example.com", "eu.proxy.netbird.io"}, nil
|
||||
},
|
||||
}
|
||||
|
||||
mgr := Manager{proxyManager: pm}
|
||||
result, err := mgr.getClusterAllowList(context.Background(), "acc-123")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, []string{"shared.example.com", "byop.example.com", "eu.proxy.netbird.io"}, result)
|
||||
}
|
||||
|
||||
func TestGetClusterAllowList_NoBYOP_FallbackToShared(t *testing.T) {
|
||||
@@ -79,10 +94,6 @@ func TestGetClusterAllowList_BYOPError_ReturnsError(t *testing.T) {
|
||||
getActiveClusterAddressesForAccountFunc: func(_ context.Context, _ string) ([]string, error) {
|
||||
return nil, errors.New("db error")
|
||||
},
|
||||
getActiveClusterAddressesFunc: func(_ context.Context) ([]string, error) {
|
||||
t.Fatal("should not call GetActiveClusterAddresses when BYOP lookup fails")
|
||||
return nil, nil
|
||||
},
|
||||
}
|
||||
|
||||
mgr := Manager{proxyManager: pm}
|
||||
@@ -92,6 +103,23 @@ func TestGetClusterAllowList_BYOPError_ReturnsError(t *testing.T) {
|
||||
assert.Contains(t, err.Error(), "BYOP cluster addresses")
|
||||
}
|
||||
|
||||
func TestGetClusterAllowList_PublicError_ReturnsError(t *testing.T) {
|
||||
pm := &mockProxyManager{
|
||||
getActiveClusterAddressesForAccountFunc: func(_ context.Context, _ string) ([]string, error) {
|
||||
return []string{"byop.example.com"}, nil
|
||||
},
|
||||
getActiveClusterAddressesFunc: func(_ context.Context) ([]string, error) {
|
||||
return nil, errors.New("db error")
|
||||
},
|
||||
}
|
||||
|
||||
mgr := Manager{proxyManager: pm}
|
||||
result, err := mgr.getClusterAllowList(context.Background(), "acc-123")
|
||||
require.Error(t, err)
|
||||
assert.Nil(t, result)
|
||||
assert.Contains(t, err.Error(), "public cluster addresses")
|
||||
}
|
||||
|
||||
func TestGetClusterAllowList_BYOPEmptySlice_FallbackToShared(t *testing.T) {
|
||||
pm := &mockProxyManager{
|
||||
getActiveClusterAddressesForAccountFunc: func(_ context.Context, _ string) ([]string, error) {
|
||||
@@ -108,3 +136,19 @@ func TestGetClusterAllowList_BYOPEmptySlice_FallbackToShared(t *testing.T) {
|
||||
assert.Equal(t, []string{"eu.proxy.netbird.io"}, result)
|
||||
}
|
||||
|
||||
func TestGetClusterAllowList_PublicEmpty_BYOPOnly(t *testing.T) {
|
||||
pm := &mockProxyManager{
|
||||
getActiveClusterAddressesForAccountFunc: func(_ context.Context, _ string) ([]string, error) {
|
||||
return []string{"byop.example.com"}, nil
|
||||
},
|
||||
getActiveClusterAddressesFunc: func(_ context.Context) ([]string, error) {
|
||||
return nil, nil
|
||||
},
|
||||
}
|
||||
|
||||
mgr := Manager{proxyManager: pm}
|
||||
result, err := mgr.getClusterAllowList(context.Background(), "acc-123")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, []string{"byop.example.com"}, result)
|
||||
}
|
||||
|
||||
|
||||
@@ -306,6 +306,10 @@ func (m *Manager) validateSubdomainRequirement(ctx context.Context, domain, clus
|
||||
func (m *Manager) persistNewService(ctx context.Context, accountID string, svc *service.Service) error {
|
||||
customPorts := m.clusterCustomPorts(ctx, svc)
|
||||
|
||||
if err := validateTargetReferences(ctx, m.store, accountID, svc.Targets); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
if svc.Domain != "" {
|
||||
if err := m.checkDomainAvailable(ctx, transaction, svc.Domain, ""); err != nil {
|
||||
@@ -321,10 +325,6 @@ func (m *Manager) persistNewService(ctx context.Context, accountID string, svc *
|
||||
return err
|
||||
}
|
||||
|
||||
if err := validateTargetReferences(ctx, transaction, accountID, svc.Targets); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := transaction.CreateService(ctx, svc); err != nil {
|
||||
return fmt.Errorf("create service: %w", err)
|
||||
}
|
||||
@@ -435,6 +435,10 @@ func (m *Manager) assignPort(ctx context.Context, tx store.Store, cluster string
|
||||
func (m *Manager) persistNewEphemeralService(ctx context.Context, accountID, peerID string, svc *service.Service) error {
|
||||
customPorts := m.clusterCustomPorts(ctx, svc)
|
||||
|
||||
if err := validateTargetReferences(ctx, m.store, accountID, svc.Targets); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
if err := m.validateEphemeralPreconditions(ctx, transaction, accountID, peerID, svc); err != nil {
|
||||
return err
|
||||
@@ -448,10 +452,6 @@ func (m *Manager) persistNewEphemeralService(ctx context.Context, accountID, pee
|
||||
return err
|
||||
}
|
||||
|
||||
if err := validateTargetReferences(ctx, transaction, accountID, svc.Targets); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := transaction.CreateService(ctx, svc); err != nil {
|
||||
return fmt.Errorf("create service: %w", err)
|
||||
}
|
||||
@@ -552,10 +552,22 @@ func (m *Manager) persistServiceUpdate(ctx context.Context, accountID string, se
|
||||
svcForCaps.ProxyCluster = effectiveCluster
|
||||
customPorts := m.clusterCustomPorts(ctx, &svcForCaps)
|
||||
|
||||
if err := validateTargetReferences(ctx, m.store, accountID, service.Targets); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Validate subdomain requirement *before* the transaction: the underlying
|
||||
// capability lookup talks to the main DB pool, and SQLite's single-connection
|
||||
// pool would self-deadlock if this ran while the tx already held the only
|
||||
// connection.
|
||||
if err := m.validateSubdomainRequirement(ctx, service.Domain, effectiveCluster); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var updateInfo serviceUpdateInfo
|
||||
|
||||
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
return m.executeServiceUpdate(ctx, transaction, accountID, service, &updateInfo, customPorts)
|
||||
return m.executeServiceUpdate(ctx, transaction, accountID, service, &updateInfo, customPorts, effectiveCluster)
|
||||
})
|
||||
|
||||
return &updateInfo, err
|
||||
@@ -585,7 +597,7 @@ func (m *Manager) resolveEffectiveCluster(ctx context.Context, accountID string,
|
||||
return existing.ProxyCluster, nil
|
||||
}
|
||||
|
||||
func (m *Manager) executeServiceUpdate(ctx context.Context, transaction store.Store, accountID string, service *service.Service, updateInfo *serviceUpdateInfo, customPorts *bool) error {
|
||||
func (m *Manager) executeServiceUpdate(ctx context.Context, transaction store.Store, accountID string, service *service.Service, updateInfo *serviceUpdateInfo, customPorts *bool, effectiveCluster string) error {
|
||||
existingService, err := transaction.GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, service.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -603,17 +615,13 @@ func (m *Manager) executeServiceUpdate(ctx context.Context, transaction store.St
|
||||
updateInfo.domainChanged = existingService.Domain != service.Domain
|
||||
|
||||
if updateInfo.domainChanged {
|
||||
if err := m.handleDomainChange(ctx, transaction, accountID, service); err != nil {
|
||||
if err := m.handleDomainChange(ctx, transaction, service, effectiveCluster); err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
service.ProxyCluster = existingService.ProxyCluster
|
||||
}
|
||||
|
||||
if err := m.validateSubdomainRequirement(ctx, service.Domain, service.ProxyCluster); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
m.preserveExistingAuthSecrets(service, existingService)
|
||||
if err := validateHeaderAuthValues(service.Auth.HeaderAuths); err != nil {
|
||||
return err
|
||||
@@ -628,9 +636,6 @@ func (m *Manager) executeServiceUpdate(ctx context.Context, transaction store.St
|
||||
if err := m.checkPortConflict(ctx, transaction, service); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := validateTargetReferences(ctx, transaction, accountID, service.Targets); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := transaction.UpdateService(ctx, service); err != nil {
|
||||
return fmt.Errorf("update service: %w", err)
|
||||
}
|
||||
@@ -638,20 +643,18 @@ func (m *Manager) executeServiceUpdate(ctx context.Context, transaction store.St
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) handleDomainChange(ctx context.Context, transaction store.Store, accountID string, svc *service.Service) error {
|
||||
// handleDomainChange validates the new domain is free inside the transaction
|
||||
// and applies the pre-resolved cluster (computed outside the tx by
|
||||
// resolveEffectiveCluster). It must NOT call clusterDeriver here: that talks
|
||||
// to the main DB pool and would self-deadlock under SQLite (max_open_conns=1)
|
||||
// because the transaction already holds the only connection.
|
||||
func (m *Manager) handleDomainChange(ctx context.Context, transaction store.Store, svc *service.Service, effectiveCluster string) error {
|
||||
if err := m.checkDomainAvailable(ctx, transaction, svc.Domain, svc.ID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if m.clusterDeriver != nil {
|
||||
newCluster, err := m.clusterDeriver.DeriveClusterFromDomain(ctx, accountID, svc.Domain)
|
||||
if err != nil {
|
||||
log.WithError(err).Warnf("could not derive cluster from domain %s", svc.Domain)
|
||||
} else {
|
||||
svc.ProxyCluster = newCluster
|
||||
}
|
||||
if effectiveCluster != "" {
|
||||
svc.ProxyCluster = effectiveCluster
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -381,13 +381,14 @@ func (s *Service) buildPathMappings() []*proto.PathMapping {
|
||||
}
|
||||
|
||||
// HTTP/HTTPS: build full URL
|
||||
hostNoBrackets := strings.TrimSuffix(strings.TrimPrefix(target.Host, "["), "]")
|
||||
targetURL := url.URL{
|
||||
Scheme: target.Protocol,
|
||||
Host: target.Host,
|
||||
Host: bracketIPv6Host(hostNoBrackets),
|
||||
Path: "/",
|
||||
}
|
||||
if target.Port > 0 && !isDefaultPort(target.Protocol, target.Port) {
|
||||
targetURL.Host = net.JoinHostPort(targetURL.Host, strconv.FormatUint(uint64(target.Port), 10))
|
||||
targetURL.Host = net.JoinHostPort(hostNoBrackets, strconv.FormatUint(uint64(target.Port), 10))
|
||||
}
|
||||
|
||||
path := "/"
|
||||
@@ -405,6 +406,19 @@ func (s *Service) buildPathMappings() []*proto.PathMapping {
|
||||
return pathMappings
|
||||
}
|
||||
|
||||
// bracketIPv6Host wraps host in square brackets when it is an IPv6 literal, as
|
||||
// required for the Host field of net/url.URL (RFC 3986 §3.2.2). v4-mapped IPv6
|
||||
// addresses are bracketed too since their textual form contains colons.
|
||||
func bracketIPv6Host(host string) string {
|
||||
if strings.HasPrefix(host, "[") {
|
||||
return host
|
||||
}
|
||||
if addr, err := netip.ParseAddr(host); err == nil && addr.Is6() {
|
||||
return "[" + host + "]"
|
||||
}
|
||||
return host
|
||||
}
|
||||
|
||||
func operationToProtoType(op Operation) proto.ProxyMappingUpdateType {
|
||||
switch op {
|
||||
case Create:
|
||||
|
||||
@@ -351,6 +351,83 @@ func TestToProtoMapping_PortInTargetURL(t *testing.T) {
|
||||
port: 80,
|
||||
wantTarget: "https://10.0.0.1:80/",
|
||||
},
|
||||
{
|
||||
name: "domain host without port is unchanged",
|
||||
protocol: "http",
|
||||
host: "example.com",
|
||||
port: 0,
|
||||
wantTarget: "http://example.com/",
|
||||
},
|
||||
{
|
||||
name: "domain host with non-default port is unchanged",
|
||||
protocol: "http",
|
||||
host: "example.com",
|
||||
port: 8080,
|
||||
wantTarget: "http://example.com:8080/",
|
||||
},
|
||||
{
|
||||
name: "ipv6 host without port is bracketed",
|
||||
protocol: "http",
|
||||
host: "fb00:cafe:1::3",
|
||||
port: 0,
|
||||
wantTarget: "http://[fb00:cafe:1::3]/",
|
||||
},
|
||||
{
|
||||
name: "ipv6 host with default port omits port and brackets host",
|
||||
protocol: "http",
|
||||
host: "fb00:cafe:1::3",
|
||||
port: 80,
|
||||
wantTarget: "http://[fb00:cafe:1::3]/",
|
||||
},
|
||||
{
|
||||
name: "ipv6 host with non-default port is bracketed",
|
||||
protocol: "http",
|
||||
host: "fb00:cafe:1::3",
|
||||
port: 8080,
|
||||
wantTarget: "http://[fb00:cafe:1::3]:8080/",
|
||||
},
|
||||
{
|
||||
name: "ipv6 loopback without port is bracketed",
|
||||
protocol: "http",
|
||||
host: "::1",
|
||||
port: 0,
|
||||
wantTarget: "http://[::1]/",
|
||||
},
|
||||
{
|
||||
name: "ipv6 host with 5-digit port is bracketed",
|
||||
protocol: "http",
|
||||
host: "fb00:cafe::1",
|
||||
port: 18080,
|
||||
wantTarget: "http://[fb00:cafe::1]:18080/",
|
||||
},
|
||||
{
|
||||
name: "pre-bracketed ipv6 without port stays single-bracketed",
|
||||
protocol: "http",
|
||||
host: "[fb00:cafe::1]",
|
||||
port: 0,
|
||||
wantTarget: "http://[fb00:cafe::1]/",
|
||||
},
|
||||
{
|
||||
name: "pre-bracketed ipv6 with port is not double-bracketed",
|
||||
protocol: "http",
|
||||
host: "[fb00:cafe::1]",
|
||||
port: 8080,
|
||||
wantTarget: "http://[fb00:cafe::1]:8080/",
|
||||
},
|
||||
{
|
||||
name: "v4-mapped ipv6 host without port is bracketed",
|
||||
protocol: "http",
|
||||
host: "::ffff:10.0.0.1",
|
||||
port: 0,
|
||||
wantTarget: "http://[::ffff:10.0.0.1]/",
|
||||
},
|
||||
{
|
||||
name: "full-form 8-group ipv6 without port is bracketed",
|
||||
protocol: "http",
|
||||
host: "fb00:cafe:1:0:0:0:0:3",
|
||||
port: 0,
|
||||
wantTarget: "http://[fb00:cafe:1:0:0:0:0:3]/",
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
|
||||
@@ -2487,6 +2487,18 @@ func (am *DefaultAccountManager) buildIPv6AllowedPeers(ctx context.Context, tran
|
||||
allowedPeers[peerID] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
// Embedded proxy peers sit outside regular group membership but must
|
||||
// participate in any v6-enabled overlay to reach v6-only peers.
|
||||
peers, err := transaction.GetAccountPeers(ctx, store.LockingStrengthNone, accountID, "", "")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get peers: %w", err)
|
||||
}
|
||||
for _, p := range peers {
|
||||
if p.ProxyMeta.Embedded {
|
||||
allowedPeers[p.ID] = struct{}{}
|
||||
}
|
||||
}
|
||||
return allowedPeers, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -762,16 +762,19 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, accountID, setupKe
|
||||
newPeer.IP = freeIP
|
||||
|
||||
if len(settings.IPv6EnabledGroups) > 0 && network.NetV6.IP != nil {
|
||||
var allGroupID string
|
||||
if !peer.ProxyMeta.Embedded {
|
||||
allGroup, err := am.Store.GetGroupByName(ctx, store.LockingStrengthNone, accountID, "All")
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Debugf("get All group for IPv6 allocation: %v", err)
|
||||
} else {
|
||||
// Embedded proxy peers are not group members but participate in any
|
||||
// IPv6-enabled overlay so reverse-proxy traffic reaches v6-only peers.
|
||||
allocate := peer.ProxyMeta.Embedded
|
||||
if !allocate {
|
||||
var allGroupID string
|
||||
if allGroup, err := am.Store.GetGroupByName(ctx, store.LockingStrengthNone, accountID, types.GroupAllName); err == nil {
|
||||
allGroupID = allGroup.ID
|
||||
} else {
|
||||
log.WithContext(ctx).Debugf("get All group for IPv6 allocation: %v", err)
|
||||
}
|
||||
allocate = peerWillHaveIPv6(settings, peerAddConfig.GroupsToAdd, allGroupID)
|
||||
}
|
||||
if peerWillHaveIPv6(settings, peerAddConfig.GroupsToAdd, allGroupID) {
|
||||
if allocate {
|
||||
v6Prefix, err := netip.ParsePrefix(network.NetV6.String())
|
||||
if err != nil {
|
||||
return nil, nil, nil, fmt.Errorf("parse IPv6 prefix: %w", err)
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
@@ -2794,12 +2795,27 @@ func NewSqliteStore(ctx context.Context, dataDir string, metrics telemetry.AppMe
|
||||
connStr = filepath.Join(dataDir, filePath)
|
||||
}
|
||||
|
||||
// Append query parameters: user-provided take precedence, otherwise default to cache=shared on non-Windows
|
||||
if hasQuery {
|
||||
connStr += "?" + query
|
||||
} else if runtime.GOOS != "windows" {
|
||||
// Compose query parameters. User-provided ?_busy_timeout (or its mattn alias
|
||||
// ?_timeout) overrides our default; otherwise inject 30s so SQLite waits at
|
||||
// most that long on a lock instead of blocking the only Go-side connection.
|
||||
// mattn/go-sqlite3 applies PRAGMA from the DSN on every fresh connection, so
|
||||
// the value survives ConnMaxIdleTime/ConnMaxLifetime recycling. cache=shared
|
||||
// stays the default on non-Windows for the same reason as before.
|
||||
parsed, _ := url.ParseQuery(query)
|
||||
var defaults []string
|
||||
if parsed.Get("_busy_timeout") == "" && parsed.Get("_timeout") == "" {
|
||||
defaults = append(defaults, "_busy_timeout=30000")
|
||||
}
|
||||
if !hasQuery && runtime.GOOS != "windows" {
|
||||
// To avoid `The process cannot access the file because it is being used by another process` on Windows
|
||||
connStr += "?cache=shared"
|
||||
defaults = append(defaults, "cache=shared")
|
||||
}
|
||||
parts := defaults
|
||||
if hasQuery {
|
||||
parts = append(parts, query)
|
||||
}
|
||||
if len(parts) > 0 {
|
||||
connStr += "?" + strings.Join(parts, "&")
|
||||
}
|
||||
|
||||
db, err := gorm.Open(sqlite.Open(connStr), getGormConfig())
|
||||
@@ -3402,7 +3418,7 @@ func (s *SqlStore) IncrementNetworkSerial(ctx context.Context, accountId string)
|
||||
}
|
||||
|
||||
func (s *SqlStore) ExecuteInTransaction(ctx context.Context, operation func(store Store) error) error {
|
||||
timeoutCtx, cancel := context.WithTimeout(context.Background(), s.transactionTimeout)
|
||||
timeoutCtx, cancel := context.WithTimeout(ctx, s.transactionTimeout)
|
||||
defer cancel()
|
||||
|
||||
startTime := time.Now()
|
||||
|
||||
@@ -4592,3 +4592,55 @@ func TestSqlStore_DeleteZoneDNSRecords(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 0, len(remainingRecords))
|
||||
}
|
||||
|
||||
// TestNewSqliteStore_BusyTimeoutApplied opens a fresh SQLite store and verifies
|
||||
// that the _busy_timeout DSN parameter took effect at the driver level. Without
|
||||
// this, lock contention on the single SQLite connection waits indefinitely on
|
||||
// the Go side and can be hidden behind the 5-minute transactionTimeout.
|
||||
func TestNewSqliteStore_BusyTimeoutApplied(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
store, err := NewSqliteStore(context.Background(), dir, nil, true)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
_ = store.Close(context.Background())
|
||||
})
|
||||
|
||||
sqlDB, err := store.db.DB()
|
||||
require.NoError(t, err)
|
||||
row := sqlDB.QueryRow("PRAGMA busy_timeout")
|
||||
var busyTimeout int
|
||||
require.NoError(t, row.Scan(&busyTimeout))
|
||||
assert.Equal(t, 30000, busyTimeout, "SQLite busy_timeout must be set via DSN so it survives connection recycling")
|
||||
}
|
||||
|
||||
// TestNewSqliteStore_BusyTimeoutRespectsUserOverride confirms that an operator
|
||||
// passing _busy_timeout or its mattn alias _timeout via NB_STORE_ENGINE_SQLITE_FILE
|
||||
// wins over our 30s default. This guards the DSN merge logic in NewSqliteStore.
|
||||
func TestNewSqliteStore_BusyTimeoutRespectsUserOverride(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
envFile string
|
||||
expected int
|
||||
}{
|
||||
{name: "explicit _busy_timeout wins", envFile: "store.db?_busy_timeout=5000", expected: 5000},
|
||||
{name: "alias _timeout wins", envFile: "store.db?_timeout=7000", expected: 7000},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Setenv("NB_STORE_ENGINE_SQLITE_FILE", tc.envFile)
|
||||
dir := t.TempDir()
|
||||
store, err := NewSqliteStore(context.Background(), dir, nil, true)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
_ = store.Close(context.Background())
|
||||
})
|
||||
|
||||
sqlDB, err := store.db.DB()
|
||||
require.NoError(t, err)
|
||||
row := sqlDB.QueryRow("PRAGMA busy_timeout")
|
||||
var busyTimeout int
|
||||
require.NoError(t, row.Scan(&busyTimeout))
|
||||
assert.Equal(t, tc.expected, busyTimeout)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -598,28 +598,21 @@ func (a *Account) GetPeerGroups(peerID string) LookupMap {
|
||||
return groupList
|
||||
}
|
||||
|
||||
// PeerIPv6Allowed reports whether the given peer is in any of the account's IPv6 enabled groups.
|
||||
// PeerIPv6Allowed reports whether the given peer participates in the IPv6 overlay.
|
||||
// Returns false if IPv6 is disabled or no groups are configured.
|
||||
func (a *Account) PeerIPv6Allowed(peerID string) bool {
|
||||
if len(a.Settings.IPv6EnabledGroups) == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
for _, groupID := range a.Settings.IPv6EnabledGroups {
|
||||
group, ok := a.Groups[groupID]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if slices.Contains(group.Peers, peerID) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
_, ok := a.peerIPv6AllowedSet()[peerID]
|
||||
return ok
|
||||
}
|
||||
|
||||
// peerIPv6AllowedSet returns a set of peer IDs that belong to any IPv6-enabled group.
|
||||
// peerIPv6AllowedSet returns the set of peer IDs that participate in the IPv6 overlay:
|
||||
// members of any IPv6-enabled group, plus every embedded proxy peer (which sit outside
|
||||
// regular group membership but must reach v6-enabled peers).
|
||||
func (a *Account) peerIPv6AllowedSet() map[string]struct{} {
|
||||
result := make(map[string]struct{})
|
||||
if len(a.Settings.IPv6EnabledGroups) == 0 {
|
||||
return result
|
||||
}
|
||||
for _, groupID := range a.Settings.IPv6EnabledGroups {
|
||||
group, ok := a.Groups[groupID]
|
||||
if !ok {
|
||||
@@ -629,6 +622,11 @@ func (a *Account) peerIPv6AllowedSet() map[string]struct{} {
|
||||
result[peerID] = struct{}{}
|
||||
}
|
||||
}
|
||||
for id, p := range a.Peers {
|
||||
if p != nil && p.ProxyMeta.Embedded {
|
||||
result[id] = struct{}{}
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
|
||||
@@ -92,9 +92,12 @@ func (g *Group) HasPeers() bool {
|
||||
return len(g.Peers) > 0
|
||||
}
|
||||
|
||||
// GroupAllName is the reserved name of the default group that contains every peer in an account.
|
||||
const GroupAllName = "All"
|
||||
|
||||
// IsGroupAll checks if the group is a default "All" group.
|
||||
func (g *Group) IsGroupAll() bool {
|
||||
return g.Name == "All"
|
||||
return g.Name == GroupAllName
|
||||
}
|
||||
|
||||
// AddPeer adds peerID to Peers if not present, returning true if added.
|
||||
|
||||
@@ -232,3 +232,33 @@ func TestIPv6RecalculationOnGroupChange(t *testing.T) {
|
||||
assert.True(t, account.PeerIPv6Allowed("peer3"), "peer3 now in infra")
|
||||
})
|
||||
}
|
||||
|
||||
func TestPeerIPv6AllowedEmbeddedProxy(t *testing.T) {
|
||||
account := &Account{
|
||||
Peers: map[string]*nbpeer.Peer{
|
||||
"peer1": {ID: "peer1"},
|
||||
"proxy": {ID: "proxy", ProxyMeta: nbpeer.ProxyMeta{Embedded: true, Cluster: "netbird.test"}},
|
||||
},
|
||||
Groups: map[string]*Group{
|
||||
"group-devs": {ID: "group-devs", Peers: []string{"peer1"}},
|
||||
},
|
||||
Settings: &Settings{},
|
||||
}
|
||||
|
||||
t.Run("embedded proxy allowed when any v6 group exists, without group membership", func(t *testing.T) {
|
||||
account.Settings.IPv6EnabledGroups = []string{"group-devs"}
|
||||
assert.True(t, account.PeerIPv6Allowed("proxy"), "embedded proxy participates in v6 overlay")
|
||||
assert.True(t, account.PeerIPv6Allowed("peer1"), "regular peer in enabled group still allowed")
|
||||
})
|
||||
|
||||
t.Run("embedded proxy denied when no v6 group enabled", func(t *testing.T) {
|
||||
account.Settings.IPv6EnabledGroups = nil
|
||||
assert.False(t, account.PeerIPv6Allowed("proxy"), "v6 disabled account-wide denies embedded proxies too")
|
||||
})
|
||||
|
||||
t.Run("non-embedded peer outside any enabled group is not pulled in", func(t *testing.T) {
|
||||
account.Settings.IPv6EnabledGroups = []string{"group-devs"}
|
||||
account.Peers["lonely"] = &nbpeer.Peer{ID: "lonely"}
|
||||
assert.False(t, account.PeerIPv6Allowed("lonely"), "embedded-proxy bypass must not leak to regular peers")
|
||||
})
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user