Compare commits

...

18 Commits

Author SHA1 Message Date
crn4
06da6dce38 removed even setup keys and onboarding from light version of getaccount 2025-12-17 18:24:58 +01:00
crn4
7484c777e9 implement separate method for getaccount without users 2025-12-17 17:49:17 +01:00
Zoltan Papp
537151e0f3 Remove redundant lock in peer update logic to avoid deadlock with exported functions (#4953) 2025-12-17 13:55:33 +01:00
Zoltan Papp
a9c28ef723 Add stack trace for bundle (#4957) 2025-12-17 13:49:02 +01:00
Pascal Fischer
c29bb1a289 [management] use xid as request id for logging (#4955) 2025-12-16 14:02:37 +01:00
Zoltan Papp
447cd287f5 [ci] Add local lint setup with pre-push hook to catch issues early (#4925)
* Add local lint setup with pre-push hook to catch issues early

Developers can now catch lint issues before pushing, reducing CI failures
and iteration time. The setup uses golangci-lint locally with the same
configuration as CI.

Setup:
- Run `make setup-hooks` once after cloning
- Pre-push hook automatically lints changed files (~90s)
- Use `make lint` to manually check changed files
- Use `make lint-all` to run full CI-equivalent lint

The Makefile auto-installs golangci-lint to ./bin/ using go install to
match the Go version in go.mod, avoiding version compatibility issues.

---------

Co-authored-by: mlsmaycon <mlsmaycon@gmail.com>
2025-12-15 10:34:48 +01:00
Zoltan Papp
5748bdd64e Add health-check agent recognition to avoid error logs (#4917)
Health-check connections now send a properly formatted auth message
with a well-known peer ID instead of immediately closing. The server
recognizes this peer ID and handles the connection gracefully with a
debug log instead of error logs.
2025-12-15 10:28:25 +01:00
Diego Romar
08f31fbcb3 [iOS] Add force relay connection on iOS (#4928)
* [ios] Add a bogus test to check iOS behavior when setting environment variables

* [ios] Revert "Add a bogus test to check iOS behavior when setting environment variables"

This reverts commit 90ca01105a6b0f4471aac07a63fc95e5d4eaef9b.

* [ios] Add EnvList struct to export and import environment variables

* [ios] Add envList parameter to the iOS Client Run method

* [ios] Add some debug logging to exportEnvVarList

* Add "//go:build ios" to client/ios/NetBirdSDK files
2025-12-12 14:29:58 -03:00
Bethuel Mmbaga
932c02eaab [management] Approve all pending peers when peer approval is disabled (#4806) 2025-12-12 18:49:57 +03:00
Pascal Fischer
abcbde26f9 [management] remove context from store methods (#4940) 2025-12-11 21:45:47 +01:00
Pascal Fischer
90e3b8009f [management] Fix sync metrics (#4939) 2025-12-11 20:11:12 +01:00
Pascal Fischer
94d34dc0c5 [management] monitoring updates (#4937) 2025-12-11 18:29:15 +01:00
Pascal Fischer
44851e06fb [management] cleanup logs (#4933) 2025-12-10 19:26:51 +01:00
Viktor Liu
3f4f825ec1 [client] Fix DNS forwarder returning broken records on 4 to 6 mapped IP addresses (#4887) 2025-12-05 17:42:49 +01:00
Viktor Liu
f538e6e9ae [client] Use setsid to avoid the parent process from being killed via HUP by login (#4900) 2025-12-05 03:29:27 +01:00
Maycon Santos
cb6b086164 [client] Reorder subsystem shutdown so peer removal goes first (#4914)
Remove peers before DNS and routes
2025-12-04 21:01:22 +01:00
Zoltan Papp
71b6855e09 [client] Fix engine shutdown deadlock and sync-signal message handling races (#4891)
* Fix engine shutdown deadlock and message handling races

- Release syncMsgMux before waiting for shutdownWg to prevent deadlock
- Check context inside lock in handleSync and receiveSignalEvents
- Prevents nil pointer access when messages arrive during engine stop
2025-12-04 19:51:50 +01:00
Viktor Liu
9bdc4908fb [client] Passthrough all non-NetBird chains to prevent them from dropping NetBird traffic (#4899) 2025-12-04 19:16:38 +01:00
50 changed files with 1108 additions and 316 deletions

11
.githooks/pre-push Executable file
View File

@@ -0,0 +1,11 @@
#!/bin/bash
echo "Running pre-push hook..."
if ! make lint; then
echo ""
echo "Hint: To push without verification, run:"
echo " git push --no-verify"
exit 1
fi
echo "All checks passed!"

View File

@@ -136,6 +136,14 @@ checked out and set up:
go mod tidy go mod tidy
``` ```
6. Configure Git hooks for automatic linting:
```bash
make setup-hooks
```
This will configure Git to run linting automatically before each push, helping catch issues early.
### Dev Container Support ### Dev Container Support
If you prefer using a dev container for development, NetBird now includes support for dev containers. If you prefer using a dev container for development, NetBird now includes support for dev containers.

27
Makefile Normal file
View File

@@ -0,0 +1,27 @@
.PHONY: lint lint-all lint-install setup-hooks
GOLANGCI_LINT := $(shell pwd)/bin/golangci-lint
# Install golangci-lint locally if needed
$(GOLANGCI_LINT):
@echo "Installing golangci-lint..."
@mkdir -p ./bin
@GOBIN=$(shell pwd)/bin go install github.com/golangci/golangci-lint/cmd/golangci-lint@latest
# Lint only changed files (fast, for pre-push)
lint: $(GOLANGCI_LINT)
@echo "Running lint on changed files..."
@$(GOLANGCI_LINT) run --new-from-rev=origin/main --timeout=2m
# Lint entire codebase (slow, matches CI)
lint-all: $(GOLANGCI_LINT)
@echo "Running lint on all files..."
@$(GOLANGCI_LINT) run --timeout=12m
# Just install the linter
lint-install: $(GOLANGCI_LINT)
# Setup git hooks for all developers
setup-hooks:
@git config core.hooksPath .githooks
@chmod +x .githooks/pre-push
@echo "✅ Git hooks configured! Pre-push will now run 'make lint'"

View File

@@ -27,7 +27,11 @@ import (
) )
const ( const (
tableNat = "nat" tableNat = "nat"
tableMangle = "mangle"
tableRaw = "raw"
tableSecurity = "security"
chainNameNatPrerouting = "PREROUTING" chainNameNatPrerouting = "PREROUTING"
chainNameRoutingFw = "netbird-rt-fwd" chainNameRoutingFw = "netbird-rt-fwd"
chainNameRoutingNat = "netbird-rt-postrouting" chainNameRoutingNat = "netbird-rt-postrouting"
@@ -91,7 +95,7 @@ func newRouter(workTable *nftables.Table, wgIface iFaceMapper, mtu uint16) (*rou
var err error var err error
r.filterTable, err = r.loadFilterTable() r.filterTable, err = r.loadFilterTable()
if err != nil { if err != nil {
log.Warnf("failed to load filter table, skipping accept rules: %v", err) log.Debugf("ip filter table not found: %v", err)
} }
return r, nil return r, nil
@@ -183,6 +187,33 @@ func (r *router) loadFilterTable() (*nftables.Table, error) {
return nil, errFilterTableNotFound return nil, errFilterTableNotFound
} }
func hookName(hook *nftables.ChainHook) string {
if hook == nil {
return "unknown"
}
switch *hook {
case *nftables.ChainHookForward:
return chainNameForward
case *nftables.ChainHookInput:
return chainNameInput
default:
return fmt.Sprintf("hook(%d)", *hook)
}
}
func familyName(family nftables.TableFamily) string {
switch family {
case nftables.TableFamilyIPv4:
return "ip"
case nftables.TableFamilyIPv6:
return "ip6"
case nftables.TableFamilyINet:
return "inet"
default:
return fmt.Sprintf("family(%d)", family)
}
}
func (r *router) createContainers() error { func (r *router) createContainers() error {
r.chains[chainNameRoutingFw] = r.conn.AddChain(&nftables.Chain{ r.chains[chainNameRoutingFw] = r.conn.AddChain(&nftables.Chain{
Name: chainNameRoutingFw, Name: chainNameRoutingFw,
@@ -930,8 +961,21 @@ func (r *router) RemoveAllLegacyRouteRules() error {
// In case the FORWARD policy is set to "drop", we add an established/related rule to allow return traffic for the inbound rule. // In case the FORWARD policy is set to "drop", we add an established/related rule to allow return traffic for the inbound rule.
// This method also adds INPUT chain rules to allow traffic to the local interface. // This method also adds INPUT chain rules to allow traffic to the local interface.
func (r *router) acceptForwardRules() error { func (r *router) acceptForwardRules() error {
var merr *multierror.Error
if err := r.acceptFilterTableRules(); err != nil {
merr = multierror.Append(merr, err)
}
if err := r.acceptExternalChainsRules(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("add accept rules to external chains: %w", err))
}
return nberrors.FormatErrorOrNil(merr)
}
func (r *router) acceptFilterTableRules() error {
if r.filterTable == nil { if r.filterTable == nil {
log.Debugf("table 'filter' not found for forward rules, skipping accept rules")
return nil return nil
} }
@@ -944,11 +988,11 @@ func (r *router) acceptForwardRules() error {
// Try iptables first and fallback to nftables if iptables is not available // Try iptables first and fallback to nftables if iptables is not available
ipt, err := iptables.New() ipt, err := iptables.New()
if err != nil { if err != nil {
// filter table exists but iptables is not // iptables is not available but the filter table exists
log.Warnf("Will use nftables to manipulate the filter table because iptables is not available: %v", err) log.Warnf("Will use nftables to manipulate the filter table because iptables is not available: %v", err)
fw = "nftables" fw = "nftables"
return r.acceptFilterRulesNftables() return r.acceptFilterRulesNftables(r.filterTable)
} }
return r.acceptFilterRulesIptables(ipt) return r.acceptFilterRulesIptables(ipt)
@@ -959,7 +1003,7 @@ func (r *router) acceptFilterRulesIptables(ipt *iptables.IPTables) error {
for _, rule := range r.getAcceptForwardRules() { for _, rule := range r.getAcceptForwardRules() {
if err := ipt.Insert("filter", chainNameForward, 1, rule...); err != nil { if err := ipt.Insert("filter", chainNameForward, 1, rule...); err != nil {
merr = multierror.Append(err, fmt.Errorf("add iptables forward rule: %v", err)) merr = multierror.Append(merr, fmt.Errorf("add iptables forward rule: %v", err))
} else { } else {
log.Debugf("added iptables forward rule: %v", rule) log.Debugf("added iptables forward rule: %v", rule)
} }
@@ -967,7 +1011,7 @@ func (r *router) acceptFilterRulesIptables(ipt *iptables.IPTables) error {
inputRule := r.getAcceptInputRule() inputRule := r.getAcceptInputRule()
if err := ipt.Insert("filter", chainNameInput, 1, inputRule...); err != nil { if err := ipt.Insert("filter", chainNameInput, 1, inputRule...); err != nil {
merr = multierror.Append(err, fmt.Errorf("add iptables input rule: %v", err)) merr = multierror.Append(merr, fmt.Errorf("add iptables input rule: %v", err))
} else { } else {
log.Debugf("added iptables input rule: %v", inputRule) log.Debugf("added iptables input rule: %v", inputRule)
} }
@@ -987,18 +1031,70 @@ func (r *router) getAcceptInputRule() []string {
return []string{"-i", r.wgIface.Name(), "-j", "ACCEPT"} return []string{"-i", r.wgIface.Name(), "-j", "ACCEPT"}
} }
func (r *router) acceptFilterRulesNftables() error { // acceptFilterRulesNftables adds accept rules to the ip filter table using nftables.
// This is used when iptables is not available.
func (r *router) acceptFilterRulesNftables(table *nftables.Table) error {
intf := ifname(r.wgIface.Name()) intf := ifname(r.wgIface.Name())
forwardChain := &nftables.Chain{
Name: chainNameForward,
Table: table,
Type: nftables.ChainTypeFilter,
Hooknum: nftables.ChainHookForward,
Priority: nftables.ChainPriorityFilter,
}
r.insertForwardAcceptRules(forwardChain, intf)
inputChain := &nftables.Chain{
Name: chainNameInput,
Table: table,
Type: nftables.ChainTypeFilter,
Hooknum: nftables.ChainHookInput,
Priority: nftables.ChainPriorityFilter,
}
r.insertInputAcceptRule(inputChain, intf)
return r.conn.Flush()
}
// acceptExternalChainsRules adds accept rules to external chains (non-netbird, non-iptables tables).
// It dynamically finds chains at call time to handle chains that may have been created after startup.
func (r *router) acceptExternalChainsRules() error {
chains := r.findExternalChains()
if len(chains) == 0 {
return nil
}
intf := ifname(r.wgIface.Name())
for _, chain := range chains {
if chain.Hooknum == nil {
log.Debugf("skipping external chain %s/%s: hooknum is nil", chain.Table.Name, chain.Name)
continue
}
log.Debugf("adding accept rules to external %s chain: %s %s/%s",
hookName(chain.Hooknum), familyName(chain.Table.Family), chain.Table.Name, chain.Name)
switch *chain.Hooknum {
case *nftables.ChainHookForward:
r.insertForwardAcceptRules(chain, intf)
case *nftables.ChainHookInput:
r.insertInputAcceptRule(chain, intf)
}
}
if err := r.conn.Flush(); err != nil {
return fmt.Errorf("flush external chain rules: %w", err)
}
return nil
}
func (r *router) insertForwardAcceptRules(chain *nftables.Chain, intf []byte) {
iifRule := &nftables.Rule{ iifRule := &nftables.Rule{
Table: r.filterTable, Table: chain.Table,
Chain: &nftables.Chain{ Chain: chain,
Name: chainNameForward,
Table: r.filterTable,
Type: nftables.ChainTypeFilter,
Hooknum: nftables.ChainHookForward,
Priority: nftables.ChainPriorityFilter,
},
Exprs: []expr.Any{ Exprs: []expr.Any{
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1}, &expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
&expr.Cmp{ &expr.Cmp{
@@ -1021,30 +1117,19 @@ func (r *router) acceptFilterRulesNftables() error {
Data: intf, Data: intf,
}, },
} }
oifRule := &nftables.Rule{ oifRule := &nftables.Rule{
Table: r.filterTable, Table: chain.Table,
Chain: &nftables.Chain{ Chain: chain,
Name: chainNameForward,
Table: r.filterTable,
Type: nftables.ChainTypeFilter,
Hooknum: nftables.ChainHookForward,
Priority: nftables.ChainPriorityFilter,
},
Exprs: append(oifExprs, getEstablishedExprs(2)...), Exprs: append(oifExprs, getEstablishedExprs(2)...),
UserData: []byte(userDataAcceptForwardRuleOif), UserData: []byte(userDataAcceptForwardRuleOif),
} }
r.conn.InsertRule(oifRule) r.conn.InsertRule(oifRule)
}
func (r *router) insertInputAcceptRule(chain *nftables.Chain, intf []byte) {
inputRule := &nftables.Rule{ inputRule := &nftables.Rule{
Table: r.filterTable, Table: chain.Table,
Chain: &nftables.Chain{ Chain: chain,
Name: chainNameInput,
Table: r.filterTable,
Type: nftables.ChainTypeFilter,
Hooknum: nftables.ChainHookInput,
Priority: nftables.ChainPriorityFilter,
},
Exprs: []expr.Any{ Exprs: []expr.Any{
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1}, &expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
&expr.Cmp{ &expr.Cmp{
@@ -1058,32 +1143,44 @@ func (r *router) acceptFilterRulesNftables() error {
UserData: []byte(userDataAcceptInputRule), UserData: []byte(userDataAcceptInputRule),
} }
r.conn.InsertRule(inputRule) r.conn.InsertRule(inputRule)
return r.conn.Flush()
} }
func (r *router) removeAcceptFilterRules() error { func (r *router) removeAcceptFilterRules() error {
var merr *multierror.Error
if err := r.removeFilterTableRules(); err != nil {
merr = multierror.Append(merr, err)
}
if err := r.removeExternalChainsRules(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove external chain rules: %w", err))
}
return nberrors.FormatErrorOrNil(merr)
}
func (r *router) removeFilterTableRules() error {
if r.filterTable == nil { if r.filterTable == nil {
return nil return nil
} }
ipt, err := iptables.New() ipt, err := iptables.New()
if err != nil { if err != nil {
log.Warnf("Will use nftables to manipulate the filter table because iptables is not available: %v", err) log.Debugf("iptables not available, using nftables to remove filter rules: %v", err)
return r.removeAcceptFilterRulesNftables() return r.removeAcceptRulesFromTable(r.filterTable)
} }
return r.removeAcceptFilterRulesIptables(ipt) return r.removeAcceptFilterRulesIptables(ipt)
} }
func (r *router) removeAcceptFilterRulesNftables() error { func (r *router) removeAcceptRulesFromTable(table *nftables.Table) error {
chains, err := r.conn.ListChainsOfTableFamily(nftables.TableFamilyIPv4) chains, err := r.conn.ListChainsOfTableFamily(table.Family)
if err != nil { if err != nil {
return fmt.Errorf("list chains: %v", err) return fmt.Errorf("list chains: %v", err)
} }
for _, chain := range chains { for _, chain := range chains {
if chain.Table.Name != r.filterTable.Name { if chain.Table.Name != table.Name {
continue continue
} }
@@ -1091,27 +1188,101 @@ func (r *router) removeAcceptFilterRulesNftables() error {
continue continue
} }
rules, err := r.conn.GetRules(r.filterTable, chain) if err := r.removeAcceptRulesFromChain(table, chain); err != nil {
return err
}
}
return r.conn.Flush()
}
func (r *router) removeAcceptRulesFromChain(table *nftables.Table, chain *nftables.Chain) error {
rules, err := r.conn.GetRules(table, chain)
if err != nil {
return fmt.Errorf("get rules from %s/%s: %v", table.Name, chain.Name, err)
}
for _, rule := range rules {
if bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleIif)) ||
bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleOif)) ||
bytes.Equal(rule.UserData, []byte(userDataAcceptInputRule)) {
if err := r.conn.DelRule(rule); err != nil {
return fmt.Errorf("delete rule from %s/%s: %v", table.Name, chain.Name, err)
}
}
}
return nil
}
// removeExternalChainsRules removes our accept rules from all external chains.
// This is deterministic - it scans for chains at removal time rather than relying on saved state,
// ensuring cleanup works even after a crash or if chains changed.
func (r *router) removeExternalChainsRules() error {
chains := r.findExternalChains()
if len(chains) == 0 {
return nil
}
for _, chain := range chains {
if err := r.removeAcceptRulesFromChain(chain.Table, chain); err != nil {
log.Warnf("remove rules from external chain %s/%s: %v", chain.Table.Name, chain.Name, err)
}
}
return r.conn.Flush()
}
// findExternalChains scans for chains from non-netbird tables that have FORWARD or INPUT hooks.
// This is used both at startup (to know where to add rules) and at cleanup (to ensure deterministic removal).
func (r *router) findExternalChains() []*nftables.Chain {
var chains []*nftables.Chain
families := []nftables.TableFamily{nftables.TableFamilyIPv4, nftables.TableFamilyINet}
for _, family := range families {
allChains, err := r.conn.ListChainsOfTableFamily(family)
if err != nil { if err != nil {
return fmt.Errorf("get rules: %v", err) log.Debugf("list chains for family %d: %v", family, err)
continue
} }
for _, rule := range rules { for _, chain := range allChains {
if bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleIif)) || if r.isExternalChain(chain) {
bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleOif)) || chains = append(chains, chain)
bytes.Equal(rule.UserData, []byte(userDataAcceptInputRule)) {
if err := r.conn.DelRule(rule); err != nil {
return fmt.Errorf("delete rule: %v", err)
}
} }
} }
} }
if err := r.conn.Flush(); err != nil { return chains
return fmt.Errorf(flushError, err) }
func (r *router) isExternalChain(chain *nftables.Chain) bool {
if r.workTable != nil && chain.Table.Name == r.workTable.Name {
return false
} }
return nil // Skip all iptables-managed tables in the ip family
if chain.Table.Family == nftables.TableFamilyIPv4 && isIptablesTable(chain.Table.Name) {
return false
}
if chain.Type != nftables.ChainTypeFilter {
return false
}
if chain.Hooknum == nil {
return false
}
return *chain.Hooknum == *nftables.ChainHookForward || *chain.Hooknum == *nftables.ChainHookInput
}
func isIptablesTable(name string) bool {
switch name {
case tableNameFilter, tableNat, tableMangle, tableRaw, tableSecurity:
return true
}
return false
} }
func (r *router) removeAcceptFilterRulesIptables(ipt *iptables.IPTables) error { func (r *router) removeAcceptFilterRulesIptables(ipt *iptables.IPTables) error {
@@ -1119,13 +1290,13 @@ func (r *router) removeAcceptFilterRulesIptables(ipt *iptables.IPTables) error {
for _, rule := range r.getAcceptForwardRules() { for _, rule := range r.getAcceptForwardRules() {
if err := ipt.DeleteIfExists("filter", chainNameForward, rule...); err != nil { if err := ipt.DeleteIfExists("filter", chainNameForward, rule...); err != nil {
merr = multierror.Append(err, fmt.Errorf("remove iptables forward rule: %v", err)) merr = multierror.Append(merr, fmt.Errorf("remove iptables forward rule: %v", err))
} }
} }
inputRule := r.getAcceptInputRule() inputRule := r.getAcceptInputRule()
if err := ipt.DeleteIfExists("filter", chainNameInput, inputRule...); err != nil { if err := ipt.DeleteIfExists("filter", chainNameInput, inputRule...); err != nil {
merr = multierror.Append(err, fmt.Errorf("remove iptables input rule: %v", err)) merr = multierror.Append(merr, fmt.Errorf("remove iptables input rule: %v", err))
} }
return nberrors.FormatErrorOrNil(merr) return nberrors.FormatErrorOrNil(merr)

View File

@@ -273,11 +273,12 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
checks := loginResp.GetChecks() checks := loginResp.GetChecks()
c.engineMutex.Lock() c.engineMutex.Lock()
c.engine = NewEngine(engineCtx, cancel, signalClient, mgmClient, relayManager, engineConfig, mobileDependency, c.statusRecorder, checks) engine := NewEngine(engineCtx, cancel, signalClient, mgmClient, relayManager, engineConfig, mobileDependency, c.statusRecorder, checks)
c.engine.SetSyncResponsePersistence(c.persistSyncResponse) engine.SetSyncResponsePersistence(c.persistSyncResponse)
c.engine = engine
c.engineMutex.Unlock() c.engineMutex.Unlock()
if err := c.engine.Start(loginResp.GetNetbirdConfig(), c.config.ManagementURL); err != nil { if err := engine.Start(loginResp.GetNetbirdConfig(), c.config.ManagementURL); err != nil {
log.Errorf("error while starting Netbird Connection Engine: %s", err) log.Errorf("error while starting Netbird Connection Engine: %s", err)
return wrapErr(err) return wrapErr(err)
} }
@@ -293,12 +294,14 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
<-engineCtx.Done() <-engineCtx.Done()
c.engineMutex.Lock() c.engineMutex.Lock()
engine := c.engine
c.engine = nil c.engine = nil
c.engineMutex.Unlock() c.engineMutex.Unlock()
if engine != nil && engine.wgInterface != nil { // todo: consider to remove this condition. Is not thread safe.
// We should always call Stop(), but we need to verify that it is idempotent
if engine.wgInterface != nil {
log.Infof("ensuring %s is removed, Netbird engine context cancelled", engine.wgInterface.Name()) log.Infof("ensuring %s is removed, Netbird engine context cancelled", engine.wgInterface.Name())
if err := engine.Stop(); err != nil { if err := engine.Stop(); err != nil {
log.Errorf("Failed to stop engine: %v", err) log.Errorf("Failed to stop engine: %v", err)
} }

View File

@@ -56,6 +56,7 @@ block.prof: Block profiling information.
heap.prof: Heap profiling information (snapshot of memory allocations). heap.prof: Heap profiling information (snapshot of memory allocations).
allocs.prof: Allocations profiling information. allocs.prof: Allocations profiling information.
threadcreate.prof: Thread creation profiling information. threadcreate.prof: Thread creation profiling information.
stack_trace.txt: Complete stack traces of all goroutines at the time of bundle creation.
Anonymization Process Anonymization Process
@@ -109,6 +110,9 @@ go tool pprof -http=:8088 heap.prof
This will open a web browser tab with the profiling information. This will open a web browser tab with the profiling information.
Stack Trace
The stack_trace.txt file contains a complete snapshot of all goroutine stack traces at the time the debug bundle was created.
Routes Routes
The routes.txt file contains detailed routing table information in a tabular format: The routes.txt file contains detailed routing table information in a tabular format:
@@ -327,6 +331,10 @@ func (g *BundleGenerator) createArchive() error {
log.Errorf("failed to add profiles to debug bundle: %v", err) log.Errorf("failed to add profiles to debug bundle: %v", err)
} }
if err := g.addStackTrace(); err != nil {
log.Errorf("failed to add stack trace to debug bundle: %v", err)
}
if err := g.addSyncResponse(); err != nil { if err := g.addSyncResponse(); err != nil {
return fmt.Errorf("add sync response: %w", err) return fmt.Errorf("add sync response: %w", err)
} }
@@ -522,6 +530,18 @@ func (g *BundleGenerator) addProf() (err error) {
return nil return nil
} }
func (g *BundleGenerator) addStackTrace() error {
buf := make([]byte, 5242880) // 5 MB buffer
n := runtime.Stack(buf, true)
stackTrace := bytes.NewReader(buf[:n])
if err := g.addFileToZip(stackTrace, "stack_trace.txt"); err != nil {
return fmt.Errorf("add stack trace file to zip: %w", err)
}
return nil
}
func (g *BundleGenerator) addInterfaces() error { func (g *BundleGenerator) addInterfaces() error {
interfaces, err := net.Interfaces() interfaces, err := net.Interfaces()
if err != nil { if err != nil {

View File

@@ -234,6 +234,11 @@ func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) *dns
return nil return nil
} }
// Unmap IPv4-mapped IPv6 addresses that some resolvers may return
for i, ip := range ips {
ips[i] = ip.Unmap()
}
f.updateInternalState(ips, mostSpecificResId, matchingEntries) f.updateInternalState(ips, mostSpecificResId, matchingEntries)
f.addIPsToResponse(resp, domain, ips) f.addIPsToResponse(resp, domain, ips)
f.cache.set(domain, question.Qtype, ips) f.cache.set(domain, question.Qtype, ips)

View File

@@ -280,7 +280,6 @@ func (e *Engine) Stop() error {
return nil return nil
} }
e.syncMsgMux.Lock() e.syncMsgMux.Lock()
defer e.syncMsgMux.Unlock()
if e.connMgr != nil { if e.connMgr != nil {
e.connMgr.Close() e.connMgr.Close()
@@ -292,21 +291,12 @@ func (e *Engine) Stop() error {
} }
log.Info("Network monitor: stopped") log.Info("Network monitor: stopped")
if os.Getenv("NB_REMOVE_BEFORE_DNS") == "true" && os.Getenv("NB_REMOVE_BEFORE_ROUTES") != "true" {
log.Info("removing peers before dns")
if err := e.removeAllPeers(); err != nil {
return fmt.Errorf("failed to remove all peers: %s", err)
}
}
if err := e.stopSSHServer(); err != nil { if err := e.stopSSHServer(); err != nil {
log.Warnf("failed to stop SSH server: %v", err) log.Warnf("failed to stop SSH server: %v", err)
} }
e.cleanupSSHConfig() e.cleanupSSHConfig()
// stop/restore DNS first so dbus and friends don't complain because of a missing interface
e.stopDNSServer()
if e.ingressGatewayMgr != nil { if e.ingressGatewayMgr != nil {
if err := e.ingressGatewayMgr.Close(); err != nil { if err := e.ingressGatewayMgr.Close(); err != nil {
log.Warnf("failed to cleanup forward rules: %v", err) log.Warnf("failed to cleanup forward rules: %v", err)
@@ -314,33 +304,28 @@ func (e *Engine) Stop() error {
e.ingressGatewayMgr = nil e.ingressGatewayMgr = nil
} }
e.stopDNSForwarder() if e.srWatcher != nil {
e.srWatcher.Close()
}
if os.Getenv("NB_REMOVE_BEFORE_ROUTES") == "true" && os.Getenv("NB_REMOVE_BEFORE_DNS") != "true" { log.Info("cleaning up status recorder states")
log.Info("removing peers before routes") e.statusRecorder.ReplaceOfflinePeers([]peer.State{})
if err := e.removeAllPeers(); err != nil { e.statusRecorder.UpdateDNSStates([]peer.NSGroupState{})
return fmt.Errorf("failed to remove all peers: %s", err) e.statusRecorder.UpdateRelayStates([]relay.ProbeResult{})
}
if err := e.removeAllPeers(); err != nil {
log.Errorf("failed to remove all peers: %s", err)
} }
if e.routeManager != nil { if e.routeManager != nil {
e.routeManager.Stop(e.stateManager) e.routeManager.Stop(e.stateManager)
} }
if e.srWatcher != nil { e.stopDNSForwarder()
e.srWatcher.Close()
}
log.Info("cleaning up status recorder states")
e.statusRecorder.ReplaceOfflinePeers([]peer.State{})
e.statusRecorder.UpdateDNSStates([]peer.NSGroupState{})
e.statusRecorder.UpdateRelayStates([]relay.ProbeResult{})
if os.Getenv("NB_REMOVE_BEFORE_DNS") != "true" && os.Getenv("NB_REMOVE_BEFORE_ROUTES") != "true" { // stop/restore DNS after peers are closed but before interface goes down
log.Info("removing peers after dns and routes") // so dbus and friends don't complain because of a missing interface
if err := e.removeAllPeers(); err != nil { e.stopDNSServer()
return fmt.Errorf("failed to remove all peers: %s", err)
}
}
if e.cancel != nil { if e.cancel != nil {
e.cancel() e.cancel()
@@ -353,16 +338,18 @@ func (e *Engine) Stop() error {
e.flowManager.Close() e.flowManager.Close()
} }
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) stateCtx, stateCancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel() defer stateCancel()
if err := e.stateManager.Stop(ctx); err != nil { if err := e.stateManager.Stop(stateCtx); err != nil {
return fmt.Errorf("failed to stop state manager: %w", err) log.Errorf("failed to stop state manager: %v", err)
} }
if err := e.stateManager.PersistState(context.Background()); err != nil { if err := e.stateManager.PersistState(context.Background()); err != nil {
log.Errorf("failed to persist state: %v", err) log.Errorf("failed to persist state: %v", err)
} }
e.syncMsgMux.Unlock()
timeout := e.calculateShutdownTimeout() timeout := e.calculateShutdownTimeout()
log.Debugf("waiting for goroutines to finish with timeout: %v", timeout) log.Debugf("waiting for goroutines to finish with timeout: %v", timeout)
shutdownCtx, cancel := context.WithTimeout(context.Background(), timeout) shutdownCtx, cancel := context.WithTimeout(context.Background(), timeout)
@@ -448,8 +435,7 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
if err != nil { if err != nil {
return fmt.Errorf("create rosenpass manager: %w", err) return fmt.Errorf("create rosenpass manager: %w", err)
} }
err := e.rpManager.Run() if err := e.rpManager.Run(); err != nil {
if err != nil {
return fmt.Errorf("run rosenpass manager: %w", err) return fmt.Errorf("run rosenpass manager: %w", err)
} }
} }
@@ -501,6 +487,7 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
} }
if err := e.createFirewall(); err != nil { if err := e.createFirewall(); err != nil {
e.close()
return err return err
} }
@@ -766,6 +753,11 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
e.syncMsgMux.Lock() e.syncMsgMux.Lock()
defer e.syncMsgMux.Unlock() defer e.syncMsgMux.Unlock()
// Check context INSIDE lock to ensure atomicity with shutdown
if e.ctx.Err() != nil {
return e.ctx.Err()
}
if update.GetNetbirdConfig() != nil { if update.GetNetbirdConfig() != nil {
wCfg := update.GetNetbirdConfig() wCfg := update.GetNetbirdConfig()
err := e.updateTURNs(wCfg.GetTurns()) err := e.updateTURNs(wCfg.GetTurns())
@@ -1385,6 +1377,11 @@ func (e *Engine) receiveSignalEvents() {
e.syncMsgMux.Lock() e.syncMsgMux.Lock()
defer e.syncMsgMux.Unlock() defer e.syncMsgMux.Unlock()
// Check context INSIDE lock to ensure atomicity with shutdown
if e.ctx.Err() != nil {
return e.ctx.Err()
}
conn, ok := e.peerStore.PeerConn(msg.Key) conn, ok := e.peerStore.PeerConn(msg.Key)
if !ok { if !ok {
return fmt.Errorf("wrongly addressed message %s", msg.Key) return fmt.Errorf("wrongly addressed message %s", msg.Key)

View File

@@ -20,7 +20,7 @@ type EndpointUpdater struct {
wgConfig WgConfig wgConfig WgConfig
initiator bool initiator bool
// mu protects updateWireGuardPeer and cancelFunc // mu protects cancelFunc
mu sync.Mutex mu sync.Mutex
cancelFunc func() cancelFunc func()
updateWg sync.WaitGroup updateWg sync.WaitGroup
@@ -86,11 +86,9 @@ func (e *EndpointUpdater) scheduleDelayedUpdate(ctx context.Context, addr *net.U
case <-ctx.Done(): case <-ctx.Done():
return return
case <-t.C: case <-t.C:
e.mu.Lock()
if err := e.updateWireGuardPeer(addr, presharedKey); err != nil { if err := e.updateWireGuardPeer(addr, presharedKey); err != nil {
e.log.Errorf("failed to update WireGuard peer, address: %s, error: %v", addr, err) e.log.Errorf("failed to update WireGuard peer, address: %s, error: %v", addr, err)
} }
e.mu.Unlock()
} }
} }

View File

@@ -1,9 +1,12 @@
//go:build ios
package NetBirdSDK package NetBirdSDK
import ( import (
"context" "context"
"fmt" "fmt"
"net/netip" "net/netip"
"os"
"sort" "sort"
"strings" "strings"
"sync" "sync"
@@ -90,7 +93,8 @@ func NewClient(cfgFile, stateFile, deviceName string, osVersion string, osName s
} }
// Run start the internal client. It is a blocker function // Run start the internal client. It is a blocker function
func (c *Client) Run(fd int32, interfaceName string) error { func (c *Client) Run(fd int32, interfaceName string, envList *EnvList) error {
exportEnvList(envList)
log.Infof("Starting NetBird client") log.Infof("Starting NetBird client")
log.Debugf("Tunnel uses interface: %s", interfaceName) log.Debugf("Tunnel uses interface: %s", interfaceName)
cfg, err := profilemanager.UpdateOrCreateConfig(profilemanager.ConfigInput{ cfg, err := profilemanager.UpdateOrCreateConfig(profilemanager.ConfigInput{
@@ -433,3 +437,19 @@ func toNetIDs(routes []string) []route.NetID {
} }
return netIDs return netIDs
} }
func exportEnvList(list *EnvList) {
if list == nil {
return
}
for k, v := range list.AllItems() {
log.Debugf("Env variable %s's value is currently: %s", k, os.Getenv(k))
log.Debugf("Setting env variable %s: %s", k, v)
if err := os.Setenv(k, v); err != nil {
log.Errorf("could not set env variable %s: %v", k, err)
} else {
log.Debugf("Env variable %s was set successfully", k)
}
}
}

View File

@@ -0,0 +1,34 @@
//go:build ios
package NetBirdSDK
import "github.com/netbirdio/netbird/client/internal/peer"
// EnvList is an exported struct to be bound by gomobile
type EnvList struct {
data map[string]string
}
// NewEnvList creates a new EnvList
func NewEnvList() *EnvList {
return &EnvList{data: make(map[string]string)}
}
// Put adds a key-value pair
func (el *EnvList) Put(key, value string) {
el.data[key] = value
}
// Get retrieves a value by key
func (el *EnvList) Get(key string) string {
return el.data[key]
}
func (el *EnvList) AllItems() map[string]string {
return el.data
}
// GetEnvKeyNBForceRelay Exports the environment variable for the iOS client
func GetEnvKeyNBForceRelay() string {
return peer.EnvKeyNBForceRelay
}

View File

@@ -1,3 +1,5 @@
//go:build ios
package NetBirdSDK package NetBirdSDK
import _ "golang.org/x/mobile/bind" import _ "golang.org/x/mobile/bind"

View File

@@ -1,3 +1,5 @@
//go:build ios
package NetBirdSDK package NetBirdSDK
import ( import (

View File

@@ -1,3 +1,5 @@
//go:build ios
package NetBirdSDK package NetBirdSDK
import ( import (

View File

@@ -1,3 +1,5 @@
//go:build ios
package NetBirdSDK package NetBirdSDK
// PeerInfo describe information about the peers. It designed for the UI usage // PeerInfo describe information about the peers. It designed for the UI usage

View File

@@ -1,3 +1,5 @@
//go:build ios
package NetBirdSDK package NetBirdSDK
import ( import (

View File

@@ -1,3 +1,5 @@
//go:build ios
package NetBirdSDK package NetBirdSDK
import ( import (

View File

@@ -1,3 +1,5 @@
//go:build ios
package NetBirdSDK package NetBirdSDK
// RoutesSelectionInfoCollection made for Java layer to get non default types as collection // RoutesSelectionInfoCollection made for Java layer to get non default types as collection

View File

@@ -822,6 +822,7 @@ func (s *Server) Down(ctx context.Context, _ *proto.DownRequest) (*proto.DownRes
defer s.mutex.Unlock() defer s.mutex.Unlock()
if err := s.cleanupConnection(); err != nil { if err := s.cleanupConnection(); err != nil {
// todo review to update the status in case any type of error
log.Errorf("failed to shut down properly: %v", err) log.Errorf("failed to shut down properly: %v", err)
return nil, err return nil, err
} }
@@ -914,6 +915,7 @@ func (s *Server) handleActiveProfileLogout(ctx context.Context) (*proto.LogoutRe
} }
if err := s.cleanupConnection(); err != nil && !errors.Is(err, ErrServiceNotUp) { if err := s.cleanupConnection(); err != nil && !errors.Is(err, ErrServiceNotUp) {
// todo review to update the status in case any type of error
log.Errorf("failed to cleanup connection: %v", err) log.Errorf("failed to cleanup connection: %v", err)
return nil, err return nil, err
} }

View File

@@ -42,6 +42,11 @@ func (s *Server) detectSuPtySupport(context.Context) bool {
return false return false
} }
// detectUtilLinuxLogin always returns false on JS/WASM
func (s *Server) detectUtilLinuxLogin(context.Context) bool {
return false
}
// executeCommandWithPty is not supported on JS/WASM // executeCommandWithPty is not supported on JS/WASM
func (s *Server) executeCommandWithPty(logger *log.Entry, session ssh.Session, execCmd *exec.Cmd, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, winCh <-chan ssh.Window) bool { func (s *Server) executeCommandWithPty(logger *log.Entry, session ssh.Session, execCmd *exec.Cmd, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, winCh <-chan ssh.Window) bool {
logger.Errorf("PTY command execution not supported on JS/WASM") logger.Errorf("PTY command execution not supported on JS/WASM")

View File

@@ -10,6 +10,7 @@ import (
"os" "os"
"os/exec" "os/exec"
"os/user" "os/user"
"runtime"
"strings" "strings"
"sync" "sync"
"syscall" "syscall"
@@ -75,6 +76,29 @@ func (s *Server) detectSuPtySupport(ctx context.Context) bool {
return supported return supported
} }
// detectUtilLinuxLogin checks if login is from util-linux (vs shadow-utils).
// util-linux login uses vhangup() which requires setsid wrapper to avoid killing parent.
// See https://bugs.debian.org/1078023 for details.
func (s *Server) detectUtilLinuxLogin(ctx context.Context) bool {
if runtime.GOOS != "linux" {
return false
}
ctx, cancel := context.WithTimeout(ctx, 500*time.Millisecond)
defer cancel()
cmd := exec.CommandContext(ctx, "login", "--version")
output, err := cmd.CombinedOutput()
if err != nil {
log.Debugf("login --version failed (likely shadow-utils): %v", err)
return false
}
isUtilLinux := strings.Contains(string(output), "util-linux")
log.Debugf("util-linux login detected: %v", isUtilLinux)
return isUtilLinux
}
// createSuCommand creates a command using su -l -c for privilege switching // createSuCommand creates a command using su -l -c for privilege switching
func (s *Server) createSuCommand(session ssh.Session, localUser *user.User, hasPty bool) (*exec.Cmd, error) { func (s *Server) createSuCommand(session ssh.Session, localUser *user.User, hasPty bool) (*exec.Cmd, error) {
suPath, err := exec.LookPath("su") suPath, err := exec.LookPath("su")
@@ -144,7 +168,7 @@ func (s *Server) handlePty(logger *log.Entry, session ssh.Session, privilegeResu
return false return false
} }
logger.Infof("starting interactive shell: %s", execCmd.Path) logger.Infof("starting interactive shell: %s", strings.Join(execCmd.Args, " "))
return s.runPtyCommand(logger, session, execCmd, ptyReq, winCh) return s.runPtyCommand(logger, session, execCmd, ptyReq, winCh)
} }

View File

@@ -383,6 +383,11 @@ func (s *Server) detectSuPtySupport(context.Context) bool {
return false return false
} }
// detectUtilLinuxLogin always returns false on Windows
func (s *Server) detectUtilLinuxLogin(context.Context) bool {
return false
}
// executeCommandWithPty executes a command with PTY allocation on Windows using ConPty // executeCommandWithPty executes a command with PTY allocation on Windows using ConPty
func (s *Server) executeCommandWithPty(logger *log.Entry, session ssh.Session, execCmd *exec.Cmd, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, winCh <-chan ssh.Window) bool { func (s *Server) executeCommandWithPty(logger *log.Entry, session ssh.Session, execCmd *exec.Cmd, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, winCh <-chan ssh.Window) bool {
command := session.RawCommand() command := session.RawCommand()

View File

@@ -138,7 +138,8 @@ type Server struct {
jwtExtractor *jwt.ClaimsExtractor jwtExtractor *jwt.ClaimsExtractor
jwtConfig *JWTConfig jwtConfig *JWTConfig
suSupportsPty bool suSupportsPty bool
loginIsUtilLinux bool
} }
type JWTConfig struct { type JWTConfig struct {
@@ -193,6 +194,7 @@ func (s *Server) Start(ctx context.Context, addr netip.AddrPort) error {
} }
s.suSupportsPty = s.detectSuPtySupport(ctx) s.suSupportsPty = s.detectSuPtySupport(ctx)
s.loginIsUtilLinux = s.detectUtilLinuxLogin(ctx)
ln, addrDesc, err := s.createListener(ctx, addr) ln, addrDesc, err := s.createListener(ctx, addr)
if err != nil { if err != nil {

View File

@@ -87,11 +87,8 @@ func (s *Server) getLoginCmd(username string, remoteAddr net.Addr) (string, []st
switch runtime.GOOS { switch runtime.GOOS {
case "linux": case "linux":
// Special handling for Arch Linux without /etc/pam.d/remote p, a := s.getLinuxLoginCmd(loginPath, username, addrPort.Addr().String())
if s.fileExists("/etc/arch-release") && !s.fileExists("/etc/pam.d/remote") { return p, a, nil
return loginPath, []string{"-f", username, "-p"}, nil
}
return loginPath, []string{"-f", username, "-h", addrPort.Addr().String(), "-p"}, nil
case "darwin", "freebsd", "openbsd", "netbsd", "dragonfly": case "darwin", "freebsd", "openbsd", "netbsd", "dragonfly":
return loginPath, []string{"-fp", "-h", addrPort.Addr().String(), username}, nil return loginPath, []string{"-fp", "-h", addrPort.Addr().String(), username}, nil
default: default:
@@ -99,7 +96,37 @@ func (s *Server) getLoginCmd(username string, remoteAddr net.Addr) (string, []st
} }
} }
// fileExists checks if a file exists (helper for login command logic) // getLinuxLoginCmd returns the login command for Linux systems.
// Handles differences between util-linux and shadow-utils login implementations.
func (s *Server) getLinuxLoginCmd(loginPath, username, remoteIP string) (string, []string) {
// Special handling for Arch Linux without /etc/pam.d/remote
var loginArgs []string
if s.fileExists("/etc/arch-release") && !s.fileExists("/etc/pam.d/remote") {
loginArgs = []string{"-f", username, "-p"}
} else {
loginArgs = []string{"-f", username, "-h", remoteIP, "-p"}
}
// util-linux login requires setsid -c to create a new session and set the
// controlling terminal. Without this, vhangup() kills the parent process.
// See https://bugs.debian.org/1078023 for details.
// TODO: handle this via the executor using syscall.Setsid() + TIOCSCTTY + syscall.Exec()
// to avoid external setsid dependency.
if !s.loginIsUtilLinux {
return loginPath, loginArgs
}
setsidPath, err := exec.LookPath("setsid")
if err != nil {
log.Warnf("setsid not available but util-linux login detected, login may fail: %v", err)
return loginPath, loginArgs
}
args := append([]string{"-w", "-c", loginPath}, loginArgs...)
return setsidPath, args
}
// fileExists checks if a file exists
func (s *Server) fileExists(path string) bool { func (s *Server) fileExists(path string) bool {
_, err := os.Stat(path) _, err := os.Stat(path)
return err == nil return err == nil

View File

@@ -60,14 +60,7 @@ func (hook ContextHook) Fire(entry *logrus.Entry) error {
entry.Data["context"] = source entry.Data["context"] = source
switch source { addFields(entry)
case HTTPSource:
addHTTPFields(entry)
case GRPCSource:
addGRPCFields(entry)
case SystemSource:
addSystemFields(entry)
}
return nil return nil
} }
@@ -99,7 +92,7 @@ func (hook ContextHook) parseSrc(filePath string) string {
return fmt.Sprintf("%s/%s", pkg, file) return fmt.Sprintf("%s/%s", pkg, file)
} }
func addHTTPFields(entry *logrus.Entry) { func addFields(entry *logrus.Entry) {
if ctxReqID, ok := entry.Context.Value(context.RequestIDKey).(string); ok { if ctxReqID, ok := entry.Context.Value(context.RequestIDKey).(string); ok {
entry.Data[context.RequestIDKey] = ctxReqID entry.Data[context.RequestIDKey] = ctxReqID
} }
@@ -109,30 +102,6 @@ func addHTTPFields(entry *logrus.Entry) {
if ctxInitiatorID, ok := entry.Context.Value(context.UserIDKey).(string); ok { if ctxInitiatorID, ok := entry.Context.Value(context.UserIDKey).(string); ok {
entry.Data[context.UserIDKey] = ctxInitiatorID entry.Data[context.UserIDKey] = ctxInitiatorID
} }
}
func addGRPCFields(entry *logrus.Entry) {
if ctxReqID, ok := entry.Context.Value(context.RequestIDKey).(string); ok {
entry.Data[context.RequestIDKey] = ctxReqID
}
if ctxAccountID, ok := entry.Context.Value(context.AccountIDKey).(string); ok {
entry.Data[context.AccountIDKey] = ctxAccountID
}
if ctxDeviceID, ok := entry.Context.Value(context.PeerIDKey).(string); ok {
entry.Data[context.PeerIDKey] = ctxDeviceID
}
}
func addSystemFields(entry *logrus.Entry) {
if ctxReqID, ok := entry.Context.Value(context.RequestIDKey).(string); ok {
entry.Data[context.RequestIDKey] = ctxReqID
}
if ctxInitiatorID, ok := entry.Context.Value(context.UserIDKey).(string); ok {
entry.Data[context.UserIDKey] = ctxInitiatorID
}
if ctxAccountID, ok := entry.Context.Value(context.AccountIDKey).(string); ok {
entry.Data[context.AccountIDKey] = ctxAccountID
}
if ctxDeviceID, ok := entry.Context.Value(context.PeerIDKey).(string); ok { if ctxDeviceID, ok := entry.Context.Value(context.PeerIDKey).(string); ok {
entry.Data[context.PeerIDKey] = ctxDeviceID entry.Data[context.PeerIDKey] = ctxDeviceID
} }

2
go.mod
View File

@@ -64,7 +64,7 @@ require (
github.com/mdlayher/socket v0.5.1 github.com/mdlayher/socket v0.5.1
github.com/miekg/dns v1.1.59 github.com/miekg/dns v1.1.59
github.com/mitchellh/hashstructure/v2 v2.0.2 github.com/mitchellh/hashstructure/v2 v2.0.2
github.com/netbirdio/management-integrations/integrations v0.0.0-20251202114414-534cf891e0ba github.com/netbirdio/management-integrations/integrations v0.0.0-20251203183432-d5400f030847
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45 github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45
github.com/okta/okta-sdk-golang/v2 v2.18.0 github.com/okta/okta-sdk-golang/v2 v2.18.0
github.com/oschwald/maxminddb-golang v1.12.0 github.com/oschwald/maxminddb-golang v1.12.0

4
go.sum
View File

@@ -368,8 +368,8 @@ github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944 h1:TDtJKmM6S
github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944/go.mod h1:sHA6TRxjQ6RLbnI+3R4DZo2Eseg/iKiPRfNmcuNySVQ= github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944/go.mod h1:sHA6TRxjQ6RLbnI+3R4DZo2Eseg/iKiPRfNmcuNySVQ=
github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51 h1:Ov4qdafATOgGMB1wbSuh+0aAHcwz9hdvB6VZjh1mVMI= github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51 h1:Ov4qdafATOgGMB1wbSuh+0aAHcwz9hdvB6VZjh1mVMI=
github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51/go.mod h1:ZSIbPdBn5hePO8CpF1PekH2SfpTxg1PDhEwtbqZS7R8= github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51/go.mod h1:ZSIbPdBn5hePO8CpF1PekH2SfpTxg1PDhEwtbqZS7R8=
github.com/netbirdio/management-integrations/integrations v0.0.0-20251202114414-534cf891e0ba h1:pD6eygRJ5EYAlgzeNskPU3WqszMz6/HhPuc6/Bc/580= github.com/netbirdio/management-integrations/integrations v0.0.0-20251203183432-d5400f030847 h1:V0zsYYMU5d2UN1m9zOLPEZCGWpnhtkYcxQVi9Rrx3bY=
github.com/netbirdio/management-integrations/integrations v0.0.0-20251202114414-534cf891e0ba/go.mod h1:qzLCKeR253jtsWhfZTt4fyegI5zei32jKZykV+oSQOo= github.com/netbirdio/management-integrations/integrations v0.0.0-20251203183432-d5400f030847/go.mod h1:qzLCKeR253jtsWhfZTt4fyegI5zei32jKZykV+oSQOo=
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8= github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8=
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM= github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45 h1:ujgviVYmx243Ksy7NdSwrdGPSRNE3pb8kEDSpH0QuAQ= github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45 h1:ujgviVYmx243Ksy7NdSwrdGPSRNE3pb8kEDSpH0QuAQ=

View File

@@ -144,7 +144,7 @@ func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID strin
if c.experimentalNetworkMap(accountID) { if c.experimentalNetworkMap(accountID) {
account = c.getAccountFromHolderOrInit(accountID) account = c.getAccountFromHolderOrInit(accountID)
} else { } else {
account, err = c.requestBuffer.GetAccountWithBackpressure(ctx, accountID) account, err = c.requestBuffer.GetAccountLightWithBackpressure(ctx, accountID)
if err != nil { if err != nil {
return fmt.Errorf("failed to get account: %v", err) return fmt.Errorf("failed to get account: %v", err)
} }
@@ -300,7 +300,7 @@ func (c *Controller) UpdateAccountPeer(ctx context.Context, accountId string, pe
return fmt.Errorf("peer %s doesn't have a channel, skipping network map update", peerId) return fmt.Errorf("peer %s doesn't have a channel, skipping network map update", peerId)
} }
account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountId) account, err := c.requestBuffer.GetAccountLightWithBackpressure(ctx, accountId)
if err != nil { if err != nil {
return fmt.Errorf("failed to send out updates to peer %s: %v", peerId, err) return fmt.Errorf("failed to send out updates to peer %s: %v", peerId, err)
} }
@@ -414,7 +414,7 @@ func (c *Controller) GetValidatedPeerWithMap(ctx context.Context, isRequiresAppr
if c.experimentalNetworkMap(accountID) { if c.experimentalNetworkMap(accountID) {
account = c.getAccountFromHolderOrInit(accountID) account = c.getAccountFromHolderOrInit(accountID)
} else { } else {
account, err = c.requestBuffer.GetAccountWithBackpressure(ctx, accountID) account, err = c.requestBuffer.GetAccountLightWithBackpressure(ctx, accountID)
if err != nil { if err != nil {
return nil, nil, nil, 0, err return nil, nil, nil, 0, err
} }
@@ -506,7 +506,7 @@ func (c *Controller) recalculateNetworkMapCache(account *types.Account, validate
func (c *Controller) RecalculateNetworkMapCache(ctx context.Context, accountId string) error { func (c *Controller) RecalculateNetworkMapCache(ctx context.Context, accountId string) error {
if c.experimentalNetworkMap(accountId) { if c.experimentalNetworkMap(accountId) {
account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountId) account, err := c.requestBuffer.GetAccountLightWithBackpressure(ctx, accountId)
if err != nil { if err != nil {
return err return err
} }
@@ -548,7 +548,7 @@ func (c *Controller) getAccountFromHolderOrInit(accountID string) *types.Account
if a != nil { if a != nil {
return a return a
} }
account, err := c.holder.LoadOrStoreFunc(accountID, c.requestBuffer.GetAccountWithBackpressure) account, err := c.holder.LoadOrStoreFunc(accountID, c.requestBuffer.GetAccountLightWithBackpressure)
if err != nil { if err != nil {
return nil return nil
} }
@@ -715,7 +715,7 @@ func (c *Controller) OnPeersUpdated(ctx context.Context, accountID string, peerI
func (c *Controller) OnPeersAdded(ctx context.Context, accountID string, peerIDs []string) error { func (c *Controller) OnPeersAdded(ctx context.Context, accountID string, peerIDs []string) error {
for _, peerID := range peerIDs { for _, peerID := range peerIDs {
if c.experimentalNetworkMap(accountID) { if c.experimentalNetworkMap(accountID) {
account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountID) account, err := c.requestBuffer.GetAccountLightWithBackpressure(ctx, accountID)
if err != nil { if err != nil {
return err return err
} }
@@ -761,7 +761,7 @@ func (c *Controller) OnPeersDeleted(ctx context.Context, accountID string, peerI
c.peersUpdateManager.CloseChannel(ctx, peerID) c.peersUpdateManager.CloseChannel(ctx, peerID)
if c.experimentalNetworkMap(accountID) { if c.experimentalNetworkMap(accountID) {
account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountID) account, err := c.requestBuffer.GetAccountLightWithBackpressure(ctx, accountID)
if err != nil { if err != nil {
log.WithContext(ctx).Errorf("failed to get account %s: %v", accountID, err) log.WithContext(ctx).Errorf("failed to get account %s: %v", accountID, err)
continue continue

View File

@@ -10,9 +10,9 @@ import (
"slices" "slices"
"time" "time"
"github.com/google/uuid"
grpcMiddleware "github.com/grpc-ecosystem/go-grpc-middleware/v2" grpcMiddleware "github.com/grpc-ecosystem/go-grpc-middleware/v2"
"github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/realip" "github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/realip"
"github.com/rs/xid"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"google.golang.org/grpc" "google.golang.org/grpc"
"google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials"
@@ -180,7 +180,7 @@ func unaryInterceptor(
info *grpc.UnaryServerInfo, info *grpc.UnaryServerInfo,
handler grpc.UnaryHandler, handler grpc.UnaryHandler,
) (interface{}, error) { ) (interface{}, error) {
reqID := uuid.New().String() reqID := xid.New().String()
//nolint //nolint
ctx = context.WithValue(ctx, hook.ExecutionContextKey, hook.GRPCSource) ctx = context.WithValue(ctx, hook.ExecutionContextKey, hook.GRPCSource)
//nolint //nolint
@@ -194,7 +194,7 @@ func streamInterceptor(
info *grpc.StreamServerInfo, info *grpc.StreamServerInfo,
handler grpc.StreamHandler, handler grpc.StreamHandler,
) error { ) error {
reqID := uuid.New().String() reqID := xid.New().String()
wrapped := grpcMiddleware.WrapServerStream(ss) wrapped := grpcMiddleware.WrapServerStream(ss)
//nolint //nolint
ctx := context.WithValue(ss.Context(), hook.ExecutionContextKey, hook.GRPCSource) ctx := context.WithValue(ss.Context(), hook.ExecutionContextKey, hook.GRPCSource)

View File

@@ -134,10 +134,6 @@ func (s *Server) GetServerKey(ctx context.Context, req *proto.Empty) (*proto.Ser
} }
log.WithContext(ctx).Tracef("GetServerKey request from %s", ip) log.WithContext(ctx).Tracef("GetServerKey request from %s", ip)
start := time.Now()
defer func() {
log.WithContext(ctx).Tracef("GetServerKey from %s took %v", ip, time.Since(start))
}()
// todo introduce something more meaningful with the key expiration/rotation // todo introduce something more meaningful with the key expiration/rotation
if s.appMetrics != nil { if s.appMetrics != nil {
@@ -194,7 +190,7 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S
s.appMetrics.GRPCMetrics().CountSyncRequestBlocked() s.appMetrics.GRPCMetrics().CountSyncRequestBlocked()
} }
if s.logBlockedPeers { if s.logBlockedPeers {
log.WithContext(ctx).Warnf("peer %s with meta hash %d is blocked from syncing", peerKey.String(), metahashed) log.WithContext(ctx).Tracef("peer %s with meta hash %d is blocked from syncing", peerKey.String(), metahashed)
} }
if s.blockPeersWithSameConfig { if s.blockPeersWithSameConfig {
s.syncSem.Add(-1) s.syncSem.Add(-1)
@@ -222,8 +218,6 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S
return err return err
} }
log.WithContext(ctx).Debugf("Sync: GetAccountIDForPeerKey since start %v", time.Since(reqStart))
// nolint:staticcheck // nolint:staticcheck
ctx = context.WithValue(ctx, nbContext.AccountIDKey, accountID) ctx = context.WithValue(ctx, nbContext.AccountIDKey, accountID)
@@ -235,7 +229,6 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S
} }
}() }()
log.WithContext(ctx).Tracef("acquired peer lock for peer %s took %v", peerKey.String(), time.Since(start)) log.WithContext(ctx).Tracef("acquired peer lock for peer %s took %v", peerKey.String(), time.Since(start))
log.WithContext(ctx).Debugf("Sync: acquirePeerLockByUID since start %v", time.Since(reqStart))
log.WithContext(ctx).Debugf("Sync request from peer [%s] [%s]", req.WgPubKey, sRealIP) log.WithContext(ctx).Debugf("Sync request from peer [%s] [%s]", req.WgPubKey, sRealIP)
@@ -352,7 +345,7 @@ func (s *Server) cancelPeerRoutines(ctx context.Context, accountID string, peer
s.networkMapController.OnPeerDisconnected(ctx, accountID, peer.ID) s.networkMapController.OnPeerDisconnected(ctx, accountID, peer.ID)
s.secretsManager.CancelRefresh(peer.ID) s.secretsManager.CancelRefresh(peer.ID)
log.WithContext(ctx).Tracef("peer %s has been disconnected", peer.Key) log.WithContext(ctx).Debugf("peer %s has been disconnected", peer.Key)
} }
func (s *Server) validateToken(ctx context.Context, jwtToken string) (string, error) { func (s *Server) validateToken(ctx context.Context, jwtToken string) (string, error) {
@@ -525,7 +518,6 @@ func (s *Server) Login(ctx context.Context, req *proto.EncryptedMessage) (*proto
reqStart := time.Now() reqStart := time.Now()
realIP := getRealIP(ctx) realIP := getRealIP(ctx)
sRealIP := realIP.String() sRealIP := realIP.String()
log.WithContext(ctx).Debugf("Login request from peer [%s] [%s]", req.WgPubKey, sRealIP)
loginReq := &proto.LoginRequest{} loginReq := &proto.LoginRequest{}
peerKey, err := s.parseRequest(ctx, req, loginReq) peerKey, err := s.parseRequest(ctx, req, loginReq)
@@ -537,7 +529,7 @@ func (s *Server) Login(ctx context.Context, req *proto.EncryptedMessage) (*proto
metahashed := metaHash(peerMeta, sRealIP) metahashed := metaHash(peerMeta, sRealIP)
if !s.loginFilter.allowLogin(peerKey.String(), metahashed) { if !s.loginFilter.allowLogin(peerKey.String(), metahashed) {
if s.logBlockedPeers { if s.logBlockedPeers {
log.WithContext(ctx).Warnf("peer %s with meta hash %d is blocked from login", peerKey.String(), metahashed) log.WithContext(ctx).Tracef("peer %s with meta hash %d is blocked from login", peerKey.String(), metahashed)
} }
if s.appMetrics != nil { if s.appMetrics != nil {
s.appMetrics.GRPCMetrics().CountLoginRequestBlocked() s.appMetrics.GRPCMetrics().CountLoginRequestBlocked()
@@ -561,16 +553,12 @@ func (s *Server) Login(ctx context.Context, req *proto.EncryptedMessage) (*proto
//nolint //nolint
ctx = context.WithValue(ctx, nbContext.AccountIDKey, accountID) ctx = context.WithValue(ctx, nbContext.AccountIDKey, accountID)
log.WithContext(ctx).Debugf("Login: GetAccountIDForPeerKey since start %v", time.Since(reqStart)) log.WithContext(ctx).Debugf("Login request from peer [%s] [%s]", req.WgPubKey, sRealIP)
defer func() { defer func() {
if s.appMetrics != nil { if s.appMetrics != nil {
s.appMetrics.GRPCMetrics().CountLoginRequestDuration(time.Since(reqStart), accountID) s.appMetrics.GRPCMetrics().CountLoginRequestDuration(time.Since(reqStart), accountID)
} }
took := time.Since(reqStart)
if took > 7*time.Second {
log.WithContext(ctx).Debugf("Login: took %v", time.Since(reqStart))
}
}() }()
if loginReq.GetMeta() == nil { if loginReq.GetMeta() == nil {
@@ -604,16 +592,12 @@ func (s *Server) Login(ctx context.Context, req *proto.EncryptedMessage) (*proto
return nil, mapError(ctx, err) return nil, mapError(ctx, err)
} }
log.WithContext(ctx).Debugf("Login: LoginPeer since start %v", time.Since(reqStart))
loginResp, err := s.prepareLoginResponse(ctx, peer, netMap, postureChecks) loginResp, err := s.prepareLoginResponse(ctx, peer, netMap, postureChecks)
if err != nil { if err != nil {
log.WithContext(ctx).Warnf("failed preparing login response for peer %s: %s", peerKey, err) log.WithContext(ctx).Warnf("failed preparing login response for peer %s: %s", peerKey, err)
return nil, status.Errorf(codes.Internal, "failed logging in peer") return nil, status.Errorf(codes.Internal, "failed logging in peer")
} }
log.WithContext(ctx).Debugf("Login: prepareLoginResponse since start %v", time.Since(reqStart))
key, err := s.secretsManager.GetWGKey() key, err := s.secretsManager.GetWGKey()
if err != nil { if err != nil {
log.WithContext(ctx).Warnf("failed getting server's WireGuard private key: %s", err) log.WithContext(ctx).Warnf("failed getting server's WireGuard private key: %s", err)
@@ -730,12 +714,10 @@ func (s *Server) sendInitialSync(ctx context.Context, peerKey wgtypes.Key, peer
return status.Errorf(codes.Internal, "error handling request") return status.Errorf(codes.Internal, "error handling request")
} }
sendStart := time.Now()
err = srv.Send(&proto.EncryptedMessage{ err = srv.Send(&proto.EncryptedMessage{
WgPubKey: key.PublicKey().String(), WgPubKey: key.PublicKey().String(),
Body: encryptedResp, Body: encryptedResp,
}) })
log.WithContext(ctx).Debugf("sendInitialSync: sending response took %s", time.Since(sendStart))
if err != nil { if err != nil {
log.WithContext(ctx).Errorf("failed sending SyncResponse %v", err) log.WithContext(ctx).Errorf("failed sending SyncResponse %v", err)
@@ -750,10 +732,6 @@ func (s *Server) sendInitialSync(ctx context.Context, peerKey wgtypes.Key, peer
// which will be used by our clients to Login // which will be used by our clients to Login
func (s *Server) GetDeviceAuthorizationFlow(ctx context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error) { func (s *Server) GetDeviceAuthorizationFlow(ctx context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error) {
log.WithContext(ctx).Tracef("GetDeviceAuthorizationFlow request for pubKey: %s", req.WgPubKey) log.WithContext(ctx).Tracef("GetDeviceAuthorizationFlow request for pubKey: %s", req.WgPubKey)
start := time.Now()
defer func() {
log.WithContext(ctx).Tracef("GetDeviceAuthorizationFlow for pubKey: %s took %v", req.WgPubKey, time.Since(start))
}()
peerKey, err := wgtypes.ParseKey(req.GetWgPubKey()) peerKey, err := wgtypes.ParseKey(req.GetWgPubKey())
if err != nil { if err != nil {
@@ -813,10 +791,6 @@ func (s *Server) GetDeviceAuthorizationFlow(ctx context.Context, req *proto.Encr
// which will be used by our clients to Login // which will be used by our clients to Login
func (s *Server) GetPKCEAuthorizationFlow(ctx context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error) { func (s *Server) GetPKCEAuthorizationFlow(ctx context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error) {
log.WithContext(ctx).Tracef("GetPKCEAuthorizationFlow request for pubKey: %s", req.WgPubKey) log.WithContext(ctx).Tracef("GetPKCEAuthorizationFlow request for pubKey: %s", req.WgPubKey)
start := time.Now()
defer func() {
log.WithContext(ctx).Tracef("GetPKCEAuthorizationFlow for pubKey %s took %v", req.WgPubKey, time.Since(start))
}()
peerKey, err := wgtypes.ParseKey(req.GetWgPubKey()) peerKey, err := wgtypes.ParseKey(req.GetWgPubKey())
if err != nil { if err != nil {

View File

@@ -167,7 +167,7 @@ func (m *TimeBasedAuthSecretsManager) SetupRefresh(ctx context.Context, accountI
relayCancel := make(chan struct{}, 1) relayCancel := make(chan struct{}, 1)
m.relayCancelMap[peerID] = relayCancel m.relayCancelMap[peerID] = relayCancel
go m.refreshRelayTokens(ctx, accountID, peerID, relayCancel) go m.refreshRelayTokens(ctx, accountID, peerID, relayCancel)
log.WithContext(ctx).Debugf("starting relay refresh for %s", peerID) log.WithContext(ctx).Tracef("starting relay refresh for %s", peerID)
} }
} }
@@ -178,7 +178,7 @@ func (m *TimeBasedAuthSecretsManager) refreshTURNTokens(ctx context.Context, acc
for { for {
select { select {
case <-cancel: case <-cancel:
log.WithContext(ctx).Debugf("stopping TURN refresh for %s", peerID) log.WithContext(ctx).Tracef("stopping TURN refresh for %s", peerID)
return return
case <-ticker.C: case <-ticker.C:
m.pushNewTURNAndRelayTokens(ctx, accountID, peerID) m.pushNewTURNAndRelayTokens(ctx, accountID, peerID)
@@ -193,7 +193,7 @@ func (m *TimeBasedAuthSecretsManager) refreshRelayTokens(ctx context.Context, ac
for { for {
select { select {
case <-cancel: case <-cancel:
log.WithContext(ctx).Debugf("stopping relay refresh for %s", peerID) log.WithContext(ctx).Tracef("stopping relay refresh for %s", peerID)
return return
case <-ticker.C: case <-ticker.C:
m.pushNewRelayTokens(ctx, accountID, peerID) m.pushNewRelayTokens(ctx, accountID, peerID)

View File

@@ -295,10 +295,23 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
return err return err
} }
if err = am.validateSettingsUpdate(ctx, transaction, newSettings, oldSettings, userID, accountID); err != nil { if err = am.validateSettingsUpdate(ctx, newSettings, oldSettings, userID, accountID); err != nil {
return err return err
} }
if oldSettings.Extra != nil && newSettings.Extra != nil &&
oldSettings.Extra.PeerApprovalEnabled && !newSettings.Extra.PeerApprovalEnabled {
approvedCount, err := transaction.ApproveAccountPeers(ctx, accountID)
if err != nil {
return fmt.Errorf("failed to approve pending peers: %w", err)
}
if approvedCount > 0 {
log.WithContext(ctx).Debugf("approved %d pending peers in account %s", approvedCount, accountID)
updateAccountPeers = true
}
}
if oldSettings.NetworkRange != newSettings.NetworkRange { if oldSettings.NetworkRange != newSettings.NetworkRange {
if err = am.reallocateAccountPeerIPs(ctx, transaction, accountID, newSettings.NetworkRange); err != nil { if err = am.reallocateAccountPeerIPs(ctx, transaction, accountID, newSettings.NetworkRange); err != nil {
return err return err
@@ -372,7 +385,7 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
return newSettings, nil return newSettings, nil
} }
func (am *DefaultAccountManager) validateSettingsUpdate(ctx context.Context, transaction store.Store, newSettings, oldSettings *types.Settings, userID, accountID string) error { func (am *DefaultAccountManager) validateSettingsUpdate(ctx context.Context, newSettings, oldSettings *types.Settings, userID, accountID string) error {
halfYearLimit := 180 * 24 * time.Hour halfYearLimit := 180 * 24 * time.Hour
if newSettings.PeerLoginExpiration > halfYearLimit { if newSettings.PeerLoginExpiration > halfYearLimit {
return status.Errorf(status.InvalidArgument, "peer login expiration can't be larger than 180 days") return status.Errorf(status.InvalidArgument, "peer login expiration can't be larger than 180 days")
@@ -386,17 +399,7 @@ func (am *DefaultAccountManager) validateSettingsUpdate(ctx context.Context, tra
return status.Errorf(status.InvalidArgument, "invalid domain \"%s\" provided for DNS domain", newSettings.DNSDomain) return status.Errorf(status.InvalidArgument, "invalid domain \"%s\" provided for DNS domain", newSettings.DNSDomain)
} }
peers, err := transaction.GetAccountPeers(ctx, store.LockingStrengthNone, accountID, "", "") return am.integratedPeerValidator.ValidateExtraSettings(ctx, newSettings.Extra, oldSettings.Extra, userID, accountID)
if err != nil {
return err
}
peersMap := make(map[string]*nbpeer.Peer, len(peers))
for _, peer := range peers {
peersMap[peer.ID] = peer
}
return am.integratedPeerValidator.ValidateExtraSettings(ctx, newSettings.Extra, oldSettings.Extra, peersMap, userID, accountID)
} }
func (am *DefaultAccountManager) handleRoutingPeerDNSResolutionSettings(ctx context.Context, oldSettings, newSettings *types.Settings, userID, accountID string) { func (am *DefaultAccountManager) handleRoutingPeerDNSResolutionSettings(ctx context.Context, oldSettings, newSettings *types.Settings, userID, accountID string) {
@@ -787,6 +790,13 @@ func (am *DefaultAccountManager) loadAccount(ctx context.Context, accountID any)
log.WithContext(ctx).Debugf("account %s not found in cache, reloading", accountID) log.WithContext(ctx).Debugf("account %s not found in cache, reloading", accountID)
accountIDString := fmt.Sprintf("%v", accountID) accountIDString := fmt.Sprintf("%v", accountID)
if ctx == nil {
ctx = context.Background()
}
// nolint:staticcheck
ctx = context.WithValue(ctx, nbcontext.AccountIDKey, accountID)
accountUsers, err := am.Store.GetAccountUsers(ctx, store.LockingStrengthNone, accountIDString) accountUsers, err := am.Store.GetAccountUsers(ctx, store.LockingStrengthNone, accountIDString)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err

View File

@@ -7,5 +7,6 @@ import (
) )
type RequestBuffer interface { type RequestBuffer interface {
GetAccountWithBackpressure(ctx context.Context, accountID string) (*types.Account, error) // GetAccountLightWithBackpressure returns account without users, setup keys, and onboarding data with request buffering
GetAccountLightWithBackpressure(ctx context.Context, accountID string) (*types.Account, error)
} }

View File

@@ -25,11 +25,13 @@ type AccountResult struct {
} }
type AccountRequestBuffer struct { type AccountRequestBuffer struct {
store store.Store store store.Store
getAccountRequests map[string][]*AccountRequest getAccountRequests map[string][]*AccountRequest
mu sync.Mutex getAccountLightRequests map[string][]*AccountRequest
getAccountRequestCh chan *AccountRequest mu sync.Mutex
bufferInterval time.Duration getAccountRequestCh chan *AccountRequest
getAccountLightRequestCh chan *AccountRequest
bufferInterval time.Duration
} }
func NewAccountRequestBuffer(ctx context.Context, store store.Store) *AccountRequestBuffer { func NewAccountRequestBuffer(ctx context.Context, store store.Store) *AccountRequestBuffer {
@@ -45,13 +47,16 @@ func NewAccountRequestBuffer(ctx context.Context, store store.Store) *AccountReq
log.WithContext(ctx).Infof("set account request buffer interval to %s", bufferInterval) log.WithContext(ctx).Infof("set account request buffer interval to %s", bufferInterval)
ac := AccountRequestBuffer{ ac := AccountRequestBuffer{
store: store, store: store,
getAccountRequests: make(map[string][]*AccountRequest), getAccountRequests: make(map[string][]*AccountRequest),
getAccountRequestCh: make(chan *AccountRequest), getAccountLightRequests: make(map[string][]*AccountRequest),
bufferInterval: bufferInterval, getAccountRequestCh: make(chan *AccountRequest),
getAccountLightRequestCh: make(chan *AccountRequest),
bufferInterval: bufferInterval,
} }
go ac.processGetAccountRequests(ctx) go ac.processGetAccountRequests(ctx)
go ac.processGetAccountLightRequests(ctx)
return &ac return &ac
} }
@@ -70,6 +75,22 @@ func (ac *AccountRequestBuffer) GetAccountWithBackpressure(ctx context.Context,
return result.Account, result.Err return result.Account, result.Err
} }
// GetAccountLightWithBackpressure returns account without users, setup keys, and onboarding data with request buffering
func (ac *AccountRequestBuffer) GetAccountLightWithBackpressure(ctx context.Context, accountID string) (*types.Account, error) {
req := &AccountRequest{
AccountID: accountID,
ResultChan: make(chan *AccountResult, 1),
}
log.WithContext(ctx).Tracef("requesting account light %s with backpressure", accountID)
startTime := time.Now()
ac.getAccountLightRequestCh <- req
result := <-req.ResultChan
log.WithContext(ctx).Tracef("got account light with backpressure after %s", time.Since(startTime))
return result.Account, result.Err
}
func (ac *AccountRequestBuffer) processGetAccountBatch(ctx context.Context, accountID string) { func (ac *AccountRequestBuffer) processGetAccountBatch(ctx context.Context, accountID string) {
ac.mu.Lock() ac.mu.Lock()
requests := ac.getAccountRequests[accountID] requests := ac.getAccountRequests[accountID]
@@ -109,3 +130,43 @@ func (ac *AccountRequestBuffer) processGetAccountRequests(ctx context.Context) {
} }
} }
} }
func (ac *AccountRequestBuffer) processGetAccountLightBatch(ctx context.Context, accountID string) {
ac.mu.Lock()
requests := ac.getAccountLightRequests[accountID]
delete(ac.getAccountLightRequests, accountID)
ac.mu.Unlock()
if len(requests) == 0 {
return
}
startTime := time.Now()
account, err := ac.store.GetAccountLight(ctx, accountID)
log.WithContext(ctx).Tracef("getting account light %s in batch took %s", accountID, time.Since(startTime))
result := &AccountResult{Account: account, Err: err}
for _, req := range requests {
req.ResultChan <- result
close(req.ResultChan)
}
}
func (ac *AccountRequestBuffer) processGetAccountLightRequests(ctx context.Context) {
for {
select {
case req := <-ac.getAccountLightRequestCh:
ac.mu.Lock()
ac.getAccountLightRequests[req.AccountID] = append(ac.getAccountLightRequests[req.AccountID], req)
if len(ac.getAccountLightRequests[req.AccountID]) == 1 {
go func(ctx context.Context, accountID string) {
time.Sleep(ac.bufferInterval)
ac.processGetAccountLightBatch(ctx, accountID)
}(ctx, req.AccountID)
}
ac.mu.Unlock()
case <-ctx.Done():
return
}
}
}

View File

@@ -2058,6 +2058,43 @@ func TestDefaultAccountManager_UpdateAccountSettings(t *testing.T) {
require.Error(t, err, "expecting to fail when providing PeerLoginExpiration more than 180 days") require.Error(t, err, "expecting to fail when providing PeerLoginExpiration more than 180 days")
} }
func TestDefaultAccountManager_UpdateAccountSettings_PeerApproval(t *testing.T) {
manager, _, account, peer1, peer2, peer3 := setupNetworkMapTest(t)
accountID := account.Id
userID := account.Users[account.CreatedBy].Id
ctx := context.Background()
newSettings := account.Settings.Copy()
newSettings.Extra = &types.ExtraSettings{
PeerApprovalEnabled: true,
}
_, err := manager.UpdateAccountSettings(ctx, accountID, userID, newSettings)
require.NoError(t, err)
peer1.Status.RequiresApproval = true
peer2.Status.RequiresApproval = true
peer3.Status.RequiresApproval = false
require.NoError(t, manager.Store.SavePeer(ctx, accountID, peer1))
require.NoError(t, manager.Store.SavePeer(ctx, accountID, peer2))
require.NoError(t, manager.Store.SavePeer(ctx, accountID, peer3))
newSettings = account.Settings.Copy()
newSettings.Extra = &types.ExtraSettings{
PeerApprovalEnabled: false,
}
_, err = manager.UpdateAccountSettings(ctx, accountID, userID, newSettings)
require.NoError(t, err)
accountPeers, err := manager.Store.GetAccountPeers(ctx, store.LockingStrengthNone, accountID, "", "")
require.NoError(t, err)
for _, peer := range accountPeers {
assert.False(t, peer.Status.RequiresApproval, "peer %s should not require approval after disabling peer approval", peer.ID)
}
}
func TestAccount_GetExpiredPeers(t *testing.T) { func TestAccount_GetExpiredPeers(t *testing.T) {
type test struct { type test struct {
name string name string

View File

@@ -141,7 +141,7 @@ func (m *AuthMiddleware) checkJWTFromRequest(r *http.Request, authHeaderParts []
} }
if userAuth.AccountId != accountId { if userAuth.AccountId != accountId {
log.WithContext(ctx).Debugf("Auth middleware sets accountId from ensure, before %s, now %s", userAuth.AccountId, accountId) log.WithContext(ctx).Tracef("Auth middleware sets accountId from ensure, before %s, now %s", userAuth.AccountId, accountId)
userAuth.AccountId = accountId userAuth.AccountId = accountId
} }

View File

@@ -127,7 +127,7 @@ type MockIntegratedValidator struct {
ValidatePeerFunc func(_ context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *types.ExtraSettings) (*nbpeer.Peer, bool, error) ValidatePeerFunc func(_ context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *types.ExtraSettings) (*nbpeer.Peer, bool, error)
} }
func (a MockIntegratedValidator) ValidateExtraSettings(_ context.Context, newExtraSettings *types.ExtraSettings, oldExtraSettings *types.ExtraSettings, peers map[string]*nbpeer.Peer, userID string, accountID string) error { func (a MockIntegratedValidator) ValidateExtraSettings(_ context.Context, newExtraSettings *types.ExtraSettings, oldExtraSettings *types.ExtraSettings, userID string, accountID string) error {
return nil return nil
} }

View File

@@ -10,7 +10,7 @@ import (
// IntegratedValidator interface exists to avoid the circle dependencies // IntegratedValidator interface exists to avoid the circle dependencies
type IntegratedValidator interface { type IntegratedValidator interface {
ValidateExtraSettings(ctx context.Context, newExtraSettings *types.ExtraSettings, oldExtraSettings *types.ExtraSettings, peers map[string]*nbpeer.Peer, userID string, accountID string) error ValidateExtraSettings(ctx context.Context, newExtraSettings *types.ExtraSettings, oldExtraSettings *types.ExtraSettings, userID string, accountID string) error
ValidatePeer(ctx context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *types.ExtraSettings) (*nbpeer.Peer, bool, error) ValidatePeer(ctx context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *types.ExtraSettings) (*nbpeer.Peer, bool, error)
PreparePeer(ctx context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *types.ExtraSettings, temporary bool) *nbpeer.Peer PreparePeer(ctx context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *types.ExtraSettings, temporary bool) *nbpeer.Peer
IsNotValidPeer(ctx context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *types.ExtraSettings) (bool, bool, error) IsNotValidPeer(ctx context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *types.ExtraSettings) (bool, bool, error)

View File

@@ -172,7 +172,7 @@ func updatePeerStatusAndLocation(ctx context.Context, geo geolocation.Geolocatio
} }
} }
log.WithContext(ctx).Tracef("saving peer status for peer %s is connected: %t", peer.ID, connected) log.WithContext(ctx).Debugf("saving peer status for peer %s is connected: %t", peer.ID, connected)
err := transaction.SavePeerStatus(ctx, accountID, peer.ID, *newStatus) err := transaction.SavePeerStatus(ctx, accountID, peer.ID, *newStatus)
if err != nil { if err != nil {
@@ -783,7 +783,6 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login types.Peer
return nil, nil, nil, err return nil, nil, nil, err
} }
startTransaction := time.Now()
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
peer, err = transaction.GetPeerByPeerPubKey(ctx, store.LockingStrengthUpdate, login.WireGuardPubKey) peer, err = transaction.GetPeerByPeerPubKey(ctx, store.LockingStrengthUpdate, login.WireGuardPubKey)
if err != nil { if err != nil {
@@ -853,8 +852,6 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login types.Peer
return nil, nil, nil, err return nil, nil, nil, err
} }
log.WithContext(ctx).Debugf("LoginPeer: transaction took %v", time.Since(startTransaction))
if updateRemotePeers || isStatusChanged || (isPeerUpdated && len(postureChecks) > 0) { if updateRemotePeers || isStatusChanged || (isPeerUpdated && len(postureChecks) > 0) {
err = am.networkMapController.OnPeersUpdated(ctx, accountID, []string{peer.ID}) err = am.networkMapController.OnPeersUpdated(ctx, accountID, []string{peer.ID})
if err != nil { if err != nil {

View File

@@ -82,7 +82,7 @@ func (c *OSVersionCheck) Validate() error {
func checkMinVersion(ctx context.Context, peerGoOS, peerVersion string, check *MinVersionCheck) (bool, error) { func checkMinVersion(ctx context.Context, peerGoOS, peerVersion string, check *MinVersionCheck) (bool, error) {
if check == nil { if check == nil {
log.WithContext(ctx).Debugf("peer %s OS is not allowed in the check", peerGoOS) log.WithContext(ctx).Tracef("peer %s OS is not allowed in the check", peerGoOS)
return false, nil return false, nil
} }
@@ -107,7 +107,7 @@ func checkMinVersion(ctx context.Context, peerGoOS, peerVersion string, check *M
func checkMinKernelVersion(ctx context.Context, peerGoOS, peerVersion string, check *MinKernelVersionCheck) (bool, error) { func checkMinKernelVersion(ctx context.Context, peerGoOS, peerVersion string, check *MinKernelVersionCheck) (bool, error) {
if check == nil { if check == nil {
log.WithContext(ctx).Debugf("peer %s OS is not allowed in the check", peerGoOS) log.WithContext(ctx).Tracef("peer %s OS is not allowed in the check", peerGoOS)
return false, nil return false, nil
} }

View File

@@ -5,9 +5,6 @@ package settings
import ( import (
"context" "context"
"fmt" "fmt"
"time"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/integrations/extra_settings" "github.com/netbirdio/netbird/management/server/integrations/extra_settings"
@@ -48,11 +45,6 @@ func (m *managerImpl) GetExtraSettingsManager() extra_settings.Manager {
} }
func (m *managerImpl) GetSettings(ctx context.Context, accountID, userID string) (*types.Settings, error) { func (m *managerImpl) GetSettings(ctx context.Context, accountID, userID string) (*types.Settings, error) {
start := time.Now()
defer func() {
log.WithContext(ctx).Debugf("GetSettings took %s", time.Since(start))
}()
if userID != activity.SystemInitiator { if userID != activity.SystemInitiator {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Settings, operations.Read) ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Settings, operations.Read)
if err != nil { if err != nil {

View File

@@ -27,7 +27,6 @@ import (
"gorm.io/gorm/logger" "gorm.io/gorm/logger"
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
nbcontext "github.com/netbirdio/netbird/management/server/context"
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types" resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
networkTypes "github.com/netbirdio/netbird/management/server/networks/types" networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
@@ -288,7 +287,7 @@ func (s *SqlStore) DeleteAccount(ctx context.Context, account *types.Account) er
if s.metrics != nil { if s.metrics != nil {
s.metrics.StoreMetrics().CountPersistenceDuration(took) s.metrics.StoreMetrics().CountPersistenceDuration(took)
} }
log.WithContext(ctx).Debugf("took %d ms to delete an account to the store", took.Milliseconds()) log.WithContext(ctx).Tracef("took %d ms to delete an account to the store", took.Milliseconds())
return err return err
} }
@@ -413,6 +412,18 @@ func (s *SqlStore) SavePeerLocation(ctx context.Context, accountID string, peerW
return nil return nil
} }
// ApproveAccountPeers marks all peers that currently require approval in the given account as approved.
func (s *SqlStore) ApproveAccountPeers(ctx context.Context, accountID string) (int, error) {
result := s.db.Model(&nbpeer.Peer{}).
Where("account_id = ? AND peer_status_requires_approval = ?", accountID, true).
Update("peer_status_requires_approval", false)
if result.Error != nil {
return 0, status.Errorf(status.Internal, "failed to approve pending account peers: %v", result.Error)
}
return int(result.RowsAffected), nil
}
// SaveUsers saves the given list of users to the database. // SaveUsers saves the given list of users to the database.
func (s *SqlStore) SaveUsers(ctx context.Context, users []*types.User) error { func (s *SqlStore) SaveUsers(ctx context.Context, users []*types.User) error {
if len(users) == 0 { if len(users) == 0 {
@@ -583,16 +594,13 @@ func (s *SqlStore) GetUserByPATID(ctx context.Context, lockStrength LockingStren
} }
func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*types.User, error) { func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*types.User, error) {
ctx, cancel := getDebuggingCtx(ctx)
defer cancel()
tx := s.db tx := s.db
if lockStrength != LockingStrengthNone { if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
} }
var user types.User var user types.User
result := tx.WithContext(ctx).Take(&user, idQueryCondition, userID) result := tx.Take(&user, idQueryCondition, userID)
if result.Error != nil { if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) { if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.NewUserNotFoundError(userID) return nil, status.NewUserNotFoundError(userID)
@@ -788,6 +796,14 @@ func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*types.Acc
return s.getAccountGorm(ctx, accountID) return s.getAccountGorm(ctx, accountID)
} }
// GetAccountLight returns account without users, setup keys, and onboarding data
func (s *SqlStore) GetAccountLight(ctx context.Context, accountID string) (*types.Account, error) {
if s.pool != nil {
return s.getAccountLightPgx(ctx, accountID)
}
return s.getAccountLightGorm(ctx, accountID)
}
func (s *SqlStore) getAccountGorm(ctx context.Context, accountID string) (*types.Account, error) { func (s *SqlStore) getAccountGorm(ctx context.Context, accountID string) (*types.Account, error) {
start := time.Now() start := time.Now()
defer func() { defer func() {
@@ -889,6 +905,82 @@ func (s *SqlStore) getAccountGorm(ctx context.Context, accountID string) (*types
return &account, nil return &account, nil
} }
func (s *SqlStore) getAccountLightGorm(ctx context.Context, accountID string) (*types.Account, error) {
start := time.Now()
defer func() {
elapsed := time.Since(start)
if elapsed > 1*time.Second {
log.WithContext(ctx).Tracef("GetAccountLight for account %s exceeded 1s, took: %v", accountID, elapsed)
}
}()
var account types.Account
result := s.db.Model(&account).
Preload("Policies.Rules").
Preload("PeersG").
Preload("GroupsG.GroupPeers").
Preload("RoutesG").
Preload("NameServerGroupsG").
Preload("PostureChecks").
Preload("Networks").
Preload("NetworkRouters").
Preload("NetworkResources").
Take(&account, idQueryCondition, accountID)
if result.Error != nil {
log.WithContext(ctx).Errorf("error when getting account %s from the store: %s", accountID, result.Error)
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.NewAccountNotFoundError(accountID)
}
return nil, status.NewGetAccountFromStoreError(result.Error)
}
account.SetupKeys = make(map[string]*types.SetupKey)
account.Peers = make(map[string]*nbpeer.Peer, len(account.PeersG))
for _, peer := range account.PeersG {
account.Peers[peer.ID] = &peer
}
account.PeersG = nil
account.Users = make(map[string]*types.User)
account.Groups = make(map[string]*types.Group, len(account.GroupsG))
for _, group := range account.GroupsG {
group.Peers = make([]string, len(group.GroupPeers))
for i, gp := range group.GroupPeers {
group.Peers[i] = gp.PeerID
}
if group.Resources == nil {
group.Resources = []types.Resource{}
}
account.Groups[group.ID] = group
}
account.GroupsG = nil
account.Routes = make(map[route.ID]*route.Route, len(account.RoutesG))
for _, route := range account.RoutesG {
account.Routes[route.ID] = &route
}
account.RoutesG = nil
account.NameServerGroups = make(map[string]*nbdns.NameServerGroup, len(account.NameServerGroupsG))
for _, ns := range account.NameServerGroupsG {
ns.AccountID = ""
if ns.NameServers == nil {
ns.NameServers = []nbdns.NameServer{}
}
if ns.Groups == nil {
ns.Groups = []string{}
}
if ns.Domains == nil {
ns.Domains = []string{}
}
account.NameServerGroups[ns.ID] = &ns
}
account.NameServerGroupsG = nil
account.InitOnce()
return &account, nil
}
func (s *SqlStore) getAccountPgx(ctx context.Context, accountID string) (*types.Account, error) { func (s *SqlStore) getAccountPgx(ctx context.Context, accountID string) (*types.Account, error) {
account, err := s.getAccount(ctx, accountID) account, err := s.getAccount(ctx, accountID)
if err != nil { if err != nil {
@@ -1172,6 +1264,221 @@ func (s *SqlStore) getAccountPgx(ctx context.Context, accountID string) (*types.
return account, nil return account, nil
} }
func (s *SqlStore) getAccountLightPgx(ctx context.Context, accountID string) (*types.Account, error) {
account, err := s.getAccount(ctx, accountID)
if err != nil {
return nil, err
}
var wg sync.WaitGroup
errChan := make(chan error, 9)
wg.Add(1)
go func() {
defer wg.Done()
peers, err := s.getPeers(ctx, accountID)
if err != nil {
errChan <- err
return
}
account.PeersG = peers
}()
wg.Add(1)
go func() {
defer wg.Done()
groups, err := s.getGroups(ctx, accountID)
if err != nil {
errChan <- err
return
}
account.GroupsG = groups
}()
wg.Add(1)
go func() {
defer wg.Done()
policies, err := s.getPolicies(ctx, accountID)
if err != nil {
errChan <- err
return
}
account.Policies = policies
}()
wg.Add(1)
go func() {
defer wg.Done()
routes, err := s.getRoutes(ctx, accountID)
if err != nil {
errChan <- err
return
}
account.RoutesG = routes
}()
wg.Add(1)
go func() {
defer wg.Done()
nsgs, err := s.getNameServerGroups(ctx, accountID)
if err != nil {
errChan <- err
return
}
account.NameServerGroupsG = nsgs
}()
wg.Add(1)
go func() {
defer wg.Done()
checks, err := s.getPostureChecks(ctx, accountID)
if err != nil {
errChan <- err
return
}
account.PostureChecks = checks
}()
wg.Add(1)
go func() {
defer wg.Done()
networks, err := s.getNetworks(ctx, accountID)
if err != nil {
errChan <- err
return
}
account.Networks = networks
}()
wg.Add(1)
go func() {
defer wg.Done()
routers, err := s.getNetworkRouters(ctx, accountID)
if err != nil {
errChan <- err
return
}
account.NetworkRouters = routers
}()
wg.Add(1)
go func() {
defer wg.Done()
resources, err := s.getNetworkResources(ctx, accountID)
if err != nil {
errChan <- err
return
}
account.NetworkResources = resources
}()
wg.Wait()
close(errChan)
for e := range errChan {
if e != nil {
return nil, e
}
}
var policyIDs []string
for _, p := range account.Policies {
policyIDs = append(policyIDs, p.ID)
}
var groupIDs []string
for _, g := range account.GroupsG {
groupIDs = append(groupIDs, g.ID)
}
wg.Add(2)
errChan = make(chan error, 2)
var rules []*types.PolicyRule
go func() {
defer wg.Done()
var err error
rules, err = s.getPolicyRules(ctx, policyIDs)
if err != nil {
errChan <- err
}
}()
var groupPeers []types.GroupPeer
go func() {
defer wg.Done()
var err error
groupPeers, err = s.getGroupPeers(ctx, groupIDs)
if err != nil {
errChan <- err
}
}()
wg.Wait()
close(errChan)
for e := range errChan {
if e != nil {
return nil, e
}
}
rulesByPolicyID := make(map[string][]*types.PolicyRule)
for _, rule := range rules {
rulesByPolicyID[rule.PolicyID] = append(rulesByPolicyID[rule.PolicyID], rule)
}
peersByGroupID := make(map[string][]string)
for _, gp := range groupPeers {
peersByGroupID[gp.GroupID] = append(peersByGroupID[gp.GroupID], gp.PeerID)
}
account.SetupKeys = make(map[string]*types.SetupKey)
account.Peers = make(map[string]*nbpeer.Peer, len(account.PeersG))
for i := range account.PeersG {
peer := &account.PeersG[i]
account.Peers[peer.ID] = peer
}
account.Users = make(map[string]*types.User)
for i := range account.Policies {
policy := account.Policies[i]
if policyRules, ok := rulesByPolicyID[policy.ID]; ok {
policy.Rules = policyRules
}
}
account.Groups = make(map[string]*types.Group, len(account.GroupsG))
for i := range account.GroupsG {
group := account.GroupsG[i]
if peerIDs, ok := peersByGroupID[group.ID]; ok {
group.Peers = peerIDs
}
account.Groups[group.ID] = group
}
account.Routes = make(map[route.ID]*route.Route, len(account.RoutesG))
for i := range account.RoutesG {
route := &account.RoutesG[i]
account.Routes[route.ID] = route
}
account.NameServerGroups = make(map[string]*nbdns.NameServerGroup, len(account.NameServerGroupsG))
for i := range account.NameServerGroupsG {
nsg := &account.NameServerGroupsG[i]
nsg.AccountID = ""
account.NameServerGroups[nsg.ID] = nsg
}
account.SetupKeysG = nil
account.PeersG = nil
account.UsersG = nil
account.GroupsG = nil
account.RoutesG = nil
account.NameServerGroupsG = nil
return account, nil
}
func (s *SqlStore) getAccount(ctx context.Context, accountID string) (*types.Account, error) { func (s *SqlStore) getAccount(ctx context.Context, accountID string) (*types.Account, error) {
var account types.Account var account types.Account
account.Network = &types.Network{} account.Network = &types.Network{}
@@ -2152,16 +2459,13 @@ func (s *SqlStore) GetPeerLabelsInAccount(ctx context.Context, lockStrength Lock
} }
func (s *SqlStore) GetAccountNetwork(ctx context.Context, lockStrength LockingStrength, accountID string) (*types.Network, error) { func (s *SqlStore) GetAccountNetwork(ctx context.Context, lockStrength LockingStrength, accountID string) (*types.Network, error) {
ctx, cancel := getDebuggingCtx(ctx)
defer cancel()
tx := s.db tx := s.db
if lockStrength != LockingStrengthNone { if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
} }
var accountNetwork types.AccountNetwork var accountNetwork types.AccountNetwork
if err := tx.WithContext(ctx).Model(&types.Account{}).Where(idQueryCondition, accountID).Take(&accountNetwork).Error; err != nil { if err := tx.Model(&types.Account{}).Where(idQueryCondition, accountID).Take(&accountNetwork).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) { if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, status.NewAccountNotFoundError(accountID) return nil, status.NewAccountNotFoundError(accountID)
} }
@@ -2171,16 +2475,13 @@ func (s *SqlStore) GetAccountNetwork(ctx context.Context, lockStrength LockingSt
} }
func (s *SqlStore) GetPeerByPeerPubKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (*nbpeer.Peer, error) { func (s *SqlStore) GetPeerByPeerPubKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (*nbpeer.Peer, error) {
ctx, cancel := getDebuggingCtx(ctx)
defer cancel()
tx := s.db tx := s.db
if lockStrength != LockingStrengthNone { if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
} }
var peer nbpeer.Peer var peer nbpeer.Peer
result := tx.WithContext(ctx).Take(&peer, GetKeyQueryCondition(s), peerKey) result := tx.Take(&peer, GetKeyQueryCondition(s), peerKey)
if result.Error != nil { if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) { if errors.Is(result.Error, gorm.ErrRecordNotFound) {
@@ -2229,11 +2530,8 @@ func (s *SqlStore) GetAccountCreatedBy(ctx context.Context, lockStrength Locking
// SaveUserLastLogin stores the last login time for a user in DB. // SaveUserLastLogin stores the last login time for a user in DB.
func (s *SqlStore) SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error { func (s *SqlStore) SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error {
ctx, cancel := getDebuggingCtx(ctx)
defer cancel()
var user types.User var user types.User
result := s.db.WithContext(ctx).Take(&user, accountAndIDQueryCondition, accountID, userID) result := s.db.Take(&user, accountAndIDQueryCondition, accountID, userID)
if result.Error != nil { if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) { if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return status.NewUserNotFoundError(userID) return status.NewUserNotFoundError(userID)
@@ -2491,16 +2789,13 @@ func NewMysqlStoreFromSqlStore(ctx context.Context, sqliteStore *SqlStore, dsn s
} }
func (s *SqlStore) GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*types.SetupKey, error) { func (s *SqlStore) GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*types.SetupKey, error) {
ctx, cancel := getDebuggingCtx(ctx)
defer cancel()
tx := s.db tx := s.db
if lockStrength != LockingStrengthNone { if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
} }
var setupKey types.SetupKey var setupKey types.SetupKey
result := tx.WithContext(ctx). result := tx.
Take(&setupKey, GetKeyQueryCondition(s), key) Take(&setupKey, GetKeyQueryCondition(s), key)
if result.Error != nil { if result.Error != nil {
@@ -2514,10 +2809,7 @@ func (s *SqlStore) GetSetupKeyBySecret(ctx context.Context, lockStrength Locking
} }
func (s *SqlStore) IncrementSetupKeyUsage(ctx context.Context, setupKeyID string) error { func (s *SqlStore) IncrementSetupKeyUsage(ctx context.Context, setupKeyID string) error {
ctx, cancel := getDebuggingCtx(ctx) result := s.db.Model(&types.SetupKey{}).
defer cancel()
result := s.db.WithContext(ctx).Model(&types.SetupKey{}).
Where(idQueryCondition, setupKeyID). Where(idQueryCondition, setupKeyID).
Updates(map[string]interface{}{ Updates(map[string]interface{}{
"used_times": gorm.Expr("used_times + 1"), "used_times": gorm.Expr("used_times + 1"),
@@ -2537,11 +2829,8 @@ func (s *SqlStore) IncrementSetupKeyUsage(ctx context.Context, setupKeyID string
// AddPeerToAllGroup adds a peer to the 'All' group. Method always needs to run in a transaction // AddPeerToAllGroup adds a peer to the 'All' group. Method always needs to run in a transaction
func (s *SqlStore) AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error { func (s *SqlStore) AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error {
ctx, cancel := getDebuggingCtx(ctx)
defer cancel()
var groupID string var groupID string
_ = s.db.WithContext(ctx).Model(types.Group{}). _ = s.db.Model(types.Group{}).
Select("id"). Select("id").
Where("account_id = ? AND name = ?", accountID, "All"). Where("account_id = ? AND name = ?", accountID, "All").
Limit(1). Limit(1).
@@ -2569,9 +2858,6 @@ func (s *SqlStore) AddPeerToAllGroup(ctx context.Context, accountID string, peer
// AddPeerToGroup adds a peer to a group // AddPeerToGroup adds a peer to a group
func (s *SqlStore) AddPeerToGroup(ctx context.Context, accountID, peerID, groupID string) error { func (s *SqlStore) AddPeerToGroup(ctx context.Context, accountID, peerID, groupID string) error {
ctx, cancel := getDebuggingCtx(ctx)
defer cancel()
peer := &types.GroupPeer{ peer := &types.GroupPeer{
AccountID: accountID, AccountID: accountID,
GroupID: groupID, GroupID: groupID,
@@ -2768,10 +3054,7 @@ func (s *SqlStore) GetUserPeers(ctx context.Context, lockStrength LockingStrengt
} }
func (s *SqlStore) AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error { func (s *SqlStore) AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error {
ctx, cancel := getDebuggingCtx(ctx) if err := s.db.Create(peer).Error; err != nil {
defer cancel()
if err := s.db.WithContext(ctx).Create(peer).Error; err != nil {
return status.Errorf(status.Internal, "issue adding peer to account: %s", err) return status.Errorf(status.Internal, "issue adding peer to account: %s", err)
} }
@@ -2897,10 +3180,7 @@ func (s *SqlStore) DeletePeer(ctx context.Context, accountID string, peerID stri
} }
func (s *SqlStore) IncrementNetworkSerial(ctx context.Context, accountId string) error { func (s *SqlStore) IncrementNetworkSerial(ctx context.Context, accountId string) error {
ctx, cancel := getDebuggingCtx(ctx) result := s.db.Model(&types.Account{}).Where(idQueryCondition, accountId).Update("network_serial", gorm.Expr("network_serial + 1"))
defer cancel()
result := s.db.WithContext(ctx).Model(&types.Account{}).Where(idQueryCondition, accountId).Update("network_serial", gorm.Expr("network_serial + 1"))
if result.Error != nil { if result.Error != nil {
log.WithContext(ctx).Errorf("failed to increment network serial count in store: %v", result.Error) log.WithContext(ctx).Errorf("failed to increment network serial count in store: %v", result.Error)
return status.Errorf(status.Internal, "failed to increment network serial count in store") return status.Errorf(status.Internal, "failed to increment network serial count in store")
@@ -4022,36 +4302,6 @@ func (s *SqlStore) GetAccountGroupPeers(ctx context.Context, lockStrength Lockin
return groupPeers, nil return groupPeers, nil
} }
func getDebuggingCtx(grpcCtx context.Context) (context.Context, context.CancelFunc) {
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
userID, ok := grpcCtx.Value(nbcontext.UserIDKey).(string)
if ok {
//nolint
ctx = context.WithValue(ctx, nbcontext.UserIDKey, userID)
}
requestID, ok := grpcCtx.Value(nbcontext.RequestIDKey).(string)
if ok {
//nolint
ctx = context.WithValue(ctx, nbcontext.RequestIDKey, requestID)
}
accountID, ok := grpcCtx.Value(nbcontext.AccountIDKey).(string)
if ok {
//nolint
ctx = context.WithValue(ctx, nbcontext.AccountIDKey, accountID)
}
go func() {
select {
case <-ctx.Done():
case <-grpcCtx.Done():
log.WithContext(grpcCtx).Warnf("grpc context ended early, error: %v", grpcCtx.Err())
}
}()
return ctx, cancel
}
func (s *SqlStore) IsPrimaryAccount(ctx context.Context, accountID string) (bool, string, error) { func (s *SqlStore) IsPrimaryAccount(ctx context.Context, accountID string) (bool, string, error) {
var info types.PrimaryAccountInfo var info types.PrimaryAccountInfo
result := s.db.Model(&types.Account{}). result := s.db.Model(&types.Account{}).
@@ -4091,7 +4341,7 @@ func (s *SqlStore) UpdateAccountNetwork(ctx context.Context, accountID string, i
Network: &types.Network{Net: ipNet}, Network: &types.Network{Net: ipNet},
} }
result := s.db.WithContext(ctx). result := s.db.
Model(&types.Account{}). Model(&types.Account{}).
Where(idQueryCondition, accountID). Where(idQueryCondition, accountID).
Updates(&patch) Updates(&patch)

View File

@@ -3717,3 +3717,80 @@ func TestSqlStore_GetPeersByGroupIDs(t *testing.T) {
}) })
} }
} }
func TestSqlStore_ApproveAccountPeers(t *testing.T) {
runTestForAllEngines(t, "", func(t *testing.T, store Store) {
accountID := "test-account"
ctx := context.Background()
account := newAccountWithId(ctx, accountID, "testuser", "example.com")
err := store.SaveAccount(ctx, account)
require.NoError(t, err)
peers := []*nbpeer.Peer{
{
ID: "peer1",
AccountID: accountID,
DNSLabel: "peer1.netbird.cloud",
Key: "peer1-key",
IP: net.ParseIP("100.64.0.1"),
Status: &nbpeer.PeerStatus{
RequiresApproval: true,
LastSeen: time.Now().UTC(),
},
},
{
ID: "peer2",
AccountID: accountID,
DNSLabel: "peer2.netbird.cloud",
Key: "peer2-key",
IP: net.ParseIP("100.64.0.2"),
Status: &nbpeer.PeerStatus{
RequiresApproval: true,
LastSeen: time.Now().UTC(),
},
},
{
ID: "peer3",
AccountID: accountID,
DNSLabel: "peer3.netbird.cloud",
Key: "peer3-key",
IP: net.ParseIP("100.64.0.3"),
Status: &nbpeer.PeerStatus{
RequiresApproval: false,
LastSeen: time.Now().UTC(),
},
},
}
for _, peer := range peers {
err = store.AddPeerToAccount(ctx, peer)
require.NoError(t, err)
}
t.Run("approve all pending peers", func(t *testing.T) {
count, err := store.ApproveAccountPeers(ctx, accountID)
require.NoError(t, err)
assert.Equal(t, 2, count)
allPeers, err := store.GetAccountPeers(ctx, LockingStrengthNone, accountID, "", "")
require.NoError(t, err)
for _, peer := range allPeers {
assert.False(t, peer.Status.RequiresApproval, "peer %s should not require approval", peer.ID)
}
})
t.Run("no peers to approve", func(t *testing.T) {
count, err := store.ApproveAccountPeers(ctx, accountID)
require.NoError(t, err)
assert.Equal(t, 0, count)
})
t.Run("non-existent account", func(t *testing.T) {
count, err := store.ApproveAccountPeers(ctx, "non-existent")
require.NoError(t, err)
assert.Equal(t, 0, count)
})
})
}

View File

@@ -51,6 +51,8 @@ type Store interface {
GetAccountsCounter(ctx context.Context) (int64, error) GetAccountsCounter(ctx context.Context) (int64, error)
GetAllAccounts(ctx context.Context) []*types.Account GetAllAccounts(ctx context.Context) []*types.Account
GetAccount(ctx context.Context, accountID string) (*types.Account, error) GetAccount(ctx context.Context, accountID string) (*types.Account, error)
// GetAccountLight returns account without users, setup keys, and onboarding data
GetAccountLight(ctx context.Context, accountID string) (*types.Account, error)
GetAccountMeta(ctx context.Context, lockStrength LockingStrength, accountID string) (*types.AccountMeta, error) GetAccountMeta(ctx context.Context, lockStrength LockingStrength, accountID string) (*types.AccountMeta, error)
GetAccountOnboarding(ctx context.Context, accountID string) (*types.AccountOnboarding, error) GetAccountOnboarding(ctx context.Context, accountID string) (*types.AccountOnboarding, error)
AccountExists(ctx context.Context, lockStrength LockingStrength, id string) (bool, error) AccountExists(ctx context.Context, lockStrength LockingStrength, id string) (bool, error)
@@ -143,6 +145,7 @@ type Store interface {
SavePeer(ctx context.Context, accountID string, peer *nbpeer.Peer) error SavePeer(ctx context.Context, accountID string, peer *nbpeer.Peer) error
SavePeerStatus(ctx context.Context, accountID, peerID string, status nbpeer.PeerStatus) error SavePeerStatus(ctx context.Context, accountID, peerID string, status nbpeer.PeerStatus) error
SavePeerLocation(ctx context.Context, accountID string, peer *nbpeer.Peer) error SavePeerLocation(ctx context.Context, accountID string, peer *nbpeer.Peer) error
ApproveAccountPeers(ctx context.Context, accountID string) (int, error)
DeletePeer(ctx context.Context, accountID string, peerID string) error DeletePeer(ctx context.Context, accountID string, peerID string) error
GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*types.SetupKey, error) GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*types.SetupKey, error)

View File

@@ -16,7 +16,6 @@ type GRPCMetrics struct {
meter metric.Meter meter metric.Meter
syncRequestsCounter metric.Int64Counter syncRequestsCounter metric.Int64Counter
syncRequestsBlockedCounter metric.Int64Counter syncRequestsBlockedCounter metric.Int64Counter
syncRequestHighLatencyCounter metric.Int64Counter
loginRequestsCounter metric.Int64Counter loginRequestsCounter metric.Int64Counter
loginRequestsBlockedCounter metric.Int64Counter loginRequestsBlockedCounter metric.Int64Counter
loginRequestHighLatencyCounter metric.Int64Counter loginRequestHighLatencyCounter metric.Int64Counter
@@ -46,14 +45,6 @@ func NewGRPCMetrics(ctx context.Context, meter metric.Meter) (*GRPCMetrics, erro
return nil, err return nil, err
} }
syncRequestHighLatencyCounter, err := meter.Int64Counter("management.grpc.sync.request.high.latency.counter",
metric.WithUnit("1"),
metric.WithDescription("Number of sync gRPC requests from the peers that took longer than the threshold to establish a connection and receive network map updates (update channel)"),
)
if err != nil {
return nil, err
}
loginRequestsCounter, err := meter.Int64Counter("management.grpc.login.request.counter", loginRequestsCounter, err := meter.Int64Counter("management.grpc.login.request.counter",
metric.WithUnit("1"), metric.WithUnit("1"),
metric.WithDescription("Number of login gRPC requests from the peers to authenticate and receive initial configuration and relay credentials"), metric.WithDescription("Number of login gRPC requests from the peers to authenticate and receive initial configuration and relay credentials"),
@@ -126,7 +117,6 @@ func NewGRPCMetrics(ctx context.Context, meter metric.Meter) (*GRPCMetrics, erro
meter: meter, meter: meter,
syncRequestsCounter: syncRequestsCounter, syncRequestsCounter: syncRequestsCounter,
syncRequestsBlockedCounter: syncRequestsBlockedCounter, syncRequestsBlockedCounter: syncRequestsBlockedCounter,
syncRequestHighLatencyCounter: syncRequestHighLatencyCounter,
loginRequestsCounter: loginRequestsCounter, loginRequestsCounter: loginRequestsCounter,
loginRequestsBlockedCounter: loginRequestsBlockedCounter, loginRequestsBlockedCounter: loginRequestsBlockedCounter,
loginRequestHighLatencyCounter: loginRequestHighLatencyCounter, loginRequestHighLatencyCounter: loginRequestHighLatencyCounter,
@@ -175,9 +165,6 @@ func (grpcMetrics *GRPCMetrics) CountLoginRequestDuration(duration time.Duration
// CountSyncRequestDuration counts the duration of the sync gRPC requests // CountSyncRequestDuration counts the duration of the sync gRPC requests
func (grpcMetrics *GRPCMetrics) CountSyncRequestDuration(duration time.Duration, accountID string) { func (grpcMetrics *GRPCMetrics) CountSyncRequestDuration(duration time.Duration, accountID string) {
grpcMetrics.syncRequestDuration.Record(grpcMetrics.ctx, duration.Milliseconds()) grpcMetrics.syncRequestDuration.Record(grpcMetrics.ctx, duration.Milliseconds())
if duration > HighLatencyThreshold {
grpcMetrics.syncRequestHighLatencyCounter.Add(grpcMetrics.ctx, 1, metric.WithAttributes(attribute.String(AccountIDLabel, accountID)))
}
} }
// RegisterConnectedStreams registers a function that collects number of active streams and feeds it to the metrics gauge. // RegisterConnectedStreams registers a function that collects number of active streams and feeds it to the metrics gauge.

View File

@@ -7,8 +7,8 @@ import (
"strings" "strings"
"time" "time"
"github.com/google/uuid"
"github.com/gorilla/mux" "github.com/gorilla/mux"
"github.com/rs/xid"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"go.opentelemetry.io/otel/attribute" "go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/metric" "go.opentelemetry.io/otel/metric"
@@ -169,7 +169,7 @@ func (m *HTTPMiddleware) Handler(h http.Handler) http.Handler {
//nolint //nolint
ctx := context.WithValue(r.Context(), hook.ExecutionContextKey, hook.HTTPSource) ctx := context.WithValue(r.Context(), hook.ExecutionContextKey, hook.HTTPSource)
reqID := uuid.New().String() reqID := xid.New().String()
//nolint //nolint
ctx = context.WithValue(ctx, nbContext.RequestIDKey, reqID) ctx = context.WithValue(ctx, nbContext.RequestIDKey, reqID)
@@ -185,6 +185,18 @@ func (m *HTTPMiddleware) Handler(h http.Handler) http.Handler {
h.ServeHTTP(w, r.WithContext(ctx)) h.ServeHTTP(w, r.WithContext(ctx))
userAuth, err := nbContext.GetUserAuthFromContext(r.Context())
if err == nil {
if userAuth.AccountId != "" {
//nolint
ctx = context.WithValue(ctx, nbContext.AccountIDKey, userAuth.AccountId)
}
if userAuth.UserId != "" {
//nolint
ctx = context.WithValue(ctx, nbContext.UserIDKey, userAuth.UserId)
}
}
if w.Status() > 399 { if w.Status() > 399 {
log.WithContext(ctx).Errorf("HTTP response %v: %v %v status %v", reqID, r.Method, r.URL, w.Status()) log.WithContext(ctx).Errorf("HTTP response %v: %v %v status %v", reqID, r.Method, r.URL, w.Status())
} else { } else {

View File

@@ -0,0 +1,31 @@
package peerid
import (
"crypto/sha256"
v2 "github.com/netbirdio/netbird/shared/relay/auth/hmac/v2"
"github.com/netbirdio/netbird/shared/relay/messages"
)
var (
// HealthCheckPeerID is the hashed peer ID for health check connections
HealthCheckPeerID = messages.HashID("healthcheck-agent")
// DummyAuthToken is a structurally valid auth token for health check.
// The signature is not valid but the format is correct (1 byte algo + 32 bytes signature + payload).
DummyAuthToken = createDummyToken()
)
func createDummyToken() []byte {
token := v2.Token{
AuthAlgo: v2.AuthAlgoHMACSHA256,
Signature: make([]byte, sha256.Size),
Payload: []byte("healthcheck"),
}
return token.Marshal()
}
// IsHealthCheck checks if the given peer ID is the health check agent
func IsHealthCheck(peerID *messages.PeerID) bool {
return peerID != nil && *peerID == HealthCheckPeerID
}

View File

@@ -7,8 +7,10 @@ import (
"github.com/coder/websocket" "github.com/coder/websocket"
"github.com/netbirdio/netbird/relay/healthcheck/peerid"
"github.com/netbirdio/netbird/relay/server" "github.com/netbirdio/netbird/relay/server"
"github.com/netbirdio/netbird/shared/relay" "github.com/netbirdio/netbird/shared/relay"
"github.com/netbirdio/netbird/shared/relay/messages"
) )
func dialWS(ctx context.Context, address url.URL) error { func dialWS(ctx context.Context, address url.URL) error {
@@ -30,7 +32,18 @@ func dialWS(ctx context.Context, address url.URL) error {
if err != nil { if err != nil {
return fmt.Errorf("failed to connect to websocket: %w", err) return fmt.Errorf("failed to connect to websocket: %w", err)
} }
defer func() {
_ = conn.CloseNow()
}()
authMsg, err := messages.MarshalAuthMsg(peerid.HealthCheckPeerID, peerid.DummyAuthToken)
if err != nil {
return fmt.Errorf("failed to marshal auth message: %w", err)
}
if err := conn.Write(ctx, websocket.MessageBinary, authMsg); err != nil {
return fmt.Errorf("failed to write auth message: %w", err)
}
_ = conn.Close(websocket.StatusNormalClosure, "availability check complete")
return nil return nil
} }

View File

@@ -97,7 +97,7 @@ func (h *handshake) handshakeReceive() (*messages.PeerID, error) {
return nil, fmt.Errorf("invalid message type %d from %s", msgType, h.conn.RemoteAddr()) return nil, fmt.Errorf("invalid message type %d from %s", msgType, h.conn.RemoteAddr())
} }
if err != nil { if err != nil {
return nil, err return peerID, err
} }
h.peerID = peerID h.peerID = peerID
return peerID, nil return peerID, nil
@@ -147,7 +147,7 @@ func (h *handshake) handleAuthMsg(buf []byte) (*messages.PeerID, error) {
} }
if err := h.validator.Validate(authPayload); err != nil { if err := h.validator.Validate(authPayload); err != nil {
return nil, fmt.Errorf("validate %s (%s): %w", rawPeerID.String(), h.conn.RemoteAddr(), err) return rawPeerID, fmt.Errorf("validate %s (%s): %w", rawPeerID.String(), h.conn.RemoteAddr(), err)
} }
return rawPeerID, nil return rawPeerID, nil

View File

@@ -12,6 +12,7 @@ import (
"go.opentelemetry.io/otel" "go.opentelemetry.io/otel"
"go.opentelemetry.io/otel/metric" "go.opentelemetry.io/otel/metric"
"github.com/netbirdio/netbird/relay/healthcheck/peerid"
//nolint:staticcheck //nolint:staticcheck
"github.com/netbirdio/netbird/relay/metrics" "github.com/netbirdio/netbird/relay/metrics"
"github.com/netbirdio/netbird/relay/server/store" "github.com/netbirdio/netbird/relay/server/store"
@@ -123,7 +124,11 @@ func (r *Relay) Accept(conn net.Conn) {
} }
peerID, err := h.handshakeReceive() peerID, err := h.handshakeReceive()
if err != nil { if err != nil {
log.Errorf("failed to handshake: %s", err) if peerid.IsHealthCheck(peerID) {
log.Debugf("health check connection from %s", conn.RemoteAddr())
} else {
log.Errorf("failed to handshake: %s", err)
}
if cErr := conn.Close(); cErr != nil { if cErr := conn.Close(); cErr != nil {
log.Errorf("failed to close connection, %s: %s", conn.RemoteAddr(), cErr) log.Errorf("failed to close connection, %s: %s", conn.RemoteAddr(), cErr)
} }