Compare commits

...

18 Commits

Author SHA1 Message Date
crn4
06da6dce38 removed even setup keys and onboarding from light version of getaccount 2025-12-17 18:24:58 +01:00
crn4
7484c777e9 implement separate method for getaccount without users 2025-12-17 17:49:17 +01:00
Zoltan Papp
537151e0f3 Remove redundant lock in peer update logic to avoid deadlock with exported functions (#4953) 2025-12-17 13:55:33 +01:00
Zoltan Papp
a9c28ef723 Add stack trace for bundle (#4957) 2025-12-17 13:49:02 +01:00
Pascal Fischer
c29bb1a289 [management] use xid as request id for logging (#4955) 2025-12-16 14:02:37 +01:00
Zoltan Papp
447cd287f5 [ci] Add local lint setup with pre-push hook to catch issues early (#4925)
* Add local lint setup with pre-push hook to catch issues early

Developers can now catch lint issues before pushing, reducing CI failures
and iteration time. The setup uses golangci-lint locally with the same
configuration as CI.

Setup:
- Run `make setup-hooks` once after cloning
- Pre-push hook automatically lints changed files (~90s)
- Use `make lint` to manually check changed files
- Use `make lint-all` to run full CI-equivalent lint

The Makefile auto-installs golangci-lint to ./bin/ using go install to
match the Go version in go.mod, avoiding version compatibility issues.

---------

Co-authored-by: mlsmaycon <mlsmaycon@gmail.com>
2025-12-15 10:34:48 +01:00
Zoltan Papp
5748bdd64e Add health-check agent recognition to avoid error logs (#4917)
Health-check connections now send a properly formatted auth message
with a well-known peer ID instead of immediately closing. The server
recognizes this peer ID and handles the connection gracefully with a
debug log instead of error logs.
2025-12-15 10:28:25 +01:00
Diego Romar
08f31fbcb3 [iOS] Add force relay connection on iOS (#4928)
* [ios] Add a bogus test to check iOS behavior when setting environment variables

* [ios] Revert "Add a bogus test to check iOS behavior when setting environment variables"

This reverts commit 90ca01105a6b0f4471aac07a63fc95e5d4eaef9b.

* [ios] Add EnvList struct to export and import environment variables

* [ios] Add envList parameter to the iOS Client Run method

* [ios] Add some debug logging to exportEnvVarList

* Add "//go:build ios" to client/ios/NetBirdSDK files
2025-12-12 14:29:58 -03:00
Bethuel Mmbaga
932c02eaab [management] Approve all pending peers when peer approval is disabled (#4806) 2025-12-12 18:49:57 +03:00
Pascal Fischer
abcbde26f9 [management] remove context from store methods (#4940) 2025-12-11 21:45:47 +01:00
Pascal Fischer
90e3b8009f [management] Fix sync metrics (#4939) 2025-12-11 20:11:12 +01:00
Pascal Fischer
94d34dc0c5 [management] monitoring updates (#4937) 2025-12-11 18:29:15 +01:00
Pascal Fischer
44851e06fb [management] cleanup logs (#4933) 2025-12-10 19:26:51 +01:00
Viktor Liu
3f4f825ec1 [client] Fix DNS forwarder returning broken records on 4 to 6 mapped IP addresses (#4887) 2025-12-05 17:42:49 +01:00
Viktor Liu
f538e6e9ae [client] Use setsid to avoid the parent process from being killed via HUP by login (#4900) 2025-12-05 03:29:27 +01:00
Maycon Santos
cb6b086164 [client] Reorder subsystem shutdown so peer removal goes first (#4914)
Remove peers before DNS and routes
2025-12-04 21:01:22 +01:00
Zoltan Papp
71b6855e09 [client] Fix engine shutdown deadlock and sync-signal message handling races (#4891)
* Fix engine shutdown deadlock and message handling races

- Release syncMsgMux before waiting for shutdownWg to prevent deadlock
- Check context inside lock in handleSync and receiveSignalEvents
- Prevents nil pointer access when messages arrive during engine stop
2025-12-04 19:51:50 +01:00
Viktor Liu
9bdc4908fb [client] Passthrough all non-NetBird chains to prevent them from dropping NetBird traffic (#4899) 2025-12-04 19:16:38 +01:00
50 changed files with 1108 additions and 316 deletions

11
.githooks/pre-push Executable file
View File

@@ -0,0 +1,11 @@
#!/bin/bash
# Git pre-push hook: runs `make lint` and aborts the push when it fails.

echo "Running pre-push hook..."

# Success path first: nothing else to do when lint passes.
if make lint; then
    echo "All checks passed!"
    exit 0
fi

# Lint failed: tell the developer how to bypass the hook, then block the push.
echo ""
echo "Hint: To push without verification, run:"
echo " git push --no-verify"
exit 1

View File

@@ -136,6 +136,14 @@ checked out and set up:
go mod tidy
```
6. Configure Git hooks for automatic linting:
```bash
make setup-hooks
```
This will configure Git to run linting automatically before each push, helping catch issues early.
### Dev Container Support
If you prefer using a dev container for development, NetBird now includes support for dev containers.

27
Makefile Normal file
View File

@@ -0,0 +1,27 @@
# Developer lint tooling: a project-local golangci-lint binary plus Git hook setup.
.PHONY: lint lint-all lint-install setup-hooks
# Absolute path of the project-local golangci-lint binary (<repo>/bin/golangci-lint).
GOLANGCI_LINT := $(shell pwd)/bin/golangci-lint
# Install golangci-lint into ./bin with `go install` when the binary is missing.
# NOTE(review): @latest is unpinned, so the installed version can drift over time — confirm it matches CI.
$(GOLANGCI_LINT):
@echo "Installing golangci-lint..."
@mkdir -p ./bin
@GOBIN=$(shell pwd)/bin go install github.com/golangci/golangci-lint/cmd/golangci-lint@latest
# Lint only changes relative to origin/main (fast, used by the pre-push hook).
# NOTE(review): presumably requires origin/main to be fetched locally — confirm.
lint: $(GOLANGCI_LINT)
@echo "Running lint on changed files..."
@$(GOLANGCI_LINT) run --new-from-rev=origin/main --timeout=2m
# Lint the entire codebase with a longer timeout (slow, matches CI).
lint-all: $(GOLANGCI_LINT)
@echo "Running lint on all files..."
@$(GOLANGCI_LINT) run --timeout=12m
# Only install the linter; has no recipe beyond its prerequisite.
lint-install: $(GOLANGCI_LINT)
# Point Git at the versioned .githooks directory and make the pre-push hook executable.
setup-hooks:
@git config core.hooksPath .githooks
@chmod +x .githooks/pre-push
@echo "✅ Git hooks configured! Pre-push will now run 'make lint'"

View File

@@ -27,7 +27,11 @@ import (
)
const (
tableNat = "nat"
tableNat = "nat"
tableMangle = "mangle"
tableRaw = "raw"
tableSecurity = "security"
chainNameNatPrerouting = "PREROUTING"
chainNameRoutingFw = "netbird-rt-fwd"
chainNameRoutingNat = "netbird-rt-postrouting"
@@ -91,7 +95,7 @@ func newRouter(workTable *nftables.Table, wgIface iFaceMapper, mtu uint16) (*rou
var err error
r.filterTable, err = r.loadFilterTable()
if err != nil {
log.Warnf("failed to load filter table, skipping accept rules: %v", err)
log.Debugf("ip filter table not found: %v", err)
}
return r, nil
@@ -183,6 +187,33 @@ func (r *router) loadFilterTable() (*nftables.Table, error) {
return nil, errFilterTableNotFound
}
// hookName returns a human-readable label for an nftables chain hook.
// A nil hook yields "unknown", the forward and input hooks map to their
// chain-name constants, and any other hook is rendered as "hook(<number>)".
func hookName(hook *nftables.ChainHook) string {
	// Cases are evaluated top to bottom, so the nil guard always runs
	// before the pointer is dereferenced.
	switch {
	case hook == nil:
		return "unknown"
	case *hook == *nftables.ChainHookForward:
		return chainNameForward
	case *hook == *nftables.ChainHookInput:
		return chainNameInput
	}
	return fmt.Sprintf("hook(%d)", *hook)
}
// familyName returns the textual name of an nftables table family:
// "ip", "ip6" or "inet". Any other family is rendered as "family(<number>)".
func familyName(family nftables.TableFamily) string {
	if family == nftables.TableFamilyIPv4 {
		return "ip"
	}
	if family == nftables.TableFamilyIPv6 {
		return "ip6"
	}
	if family == nftables.TableFamilyINet {
		return "inet"
	}
	return fmt.Sprintf("family(%d)", family)
}
func (r *router) createContainers() error {
r.chains[chainNameRoutingFw] = r.conn.AddChain(&nftables.Chain{
Name: chainNameRoutingFw,
@@ -930,8 +961,21 @@ func (r *router) RemoveAllLegacyRouteRules() error {
// In case the FORWARD policy is set to "drop", we add an established/related rule to allow return traffic for the inbound rule.
// This method also adds INPUT chain rules to allow traffic to the local interface.
func (r *router) acceptForwardRules() error {
	var result *multierror.Error

	// Both steps are always attempted; a failure in the first does not
	// prevent the second from running. All failures are aggregated.
	filterErr := r.acceptFilterTableRules()
	if filterErr != nil {
		result = multierror.Append(result, filterErr)
	}

	externalErr := r.acceptExternalChainsRules()
	if externalErr != nil {
		result = multierror.Append(result, fmt.Errorf("add accept rules to external chains: %w", externalErr))
	}

	return nberrors.FormatErrorOrNil(result)
}
func (r *router) acceptFilterTableRules() error {
if r.filterTable == nil {
log.Debugf("table 'filter' not found for forward rules, skipping accept rules")
return nil
}
@@ -944,11 +988,11 @@ func (r *router) acceptForwardRules() error {
// Try iptables first and fallback to nftables if iptables is not available
ipt, err := iptables.New()
if err != nil {
// filter table exists but iptables is not
// iptables is not available but the filter table exists
log.Warnf("Will use nftables to manipulate the filter table because iptables is not available: %v", err)
fw = "nftables"
return r.acceptFilterRulesNftables()
return r.acceptFilterRulesNftables(r.filterTable)
}
return r.acceptFilterRulesIptables(ipt)
@@ -959,7 +1003,7 @@ func (r *router) acceptFilterRulesIptables(ipt *iptables.IPTables) error {
for _, rule := range r.getAcceptForwardRules() {
if err := ipt.Insert("filter", chainNameForward, 1, rule...); err != nil {
merr = multierror.Append(err, fmt.Errorf("add iptables forward rule: %v", err))
merr = multierror.Append(merr, fmt.Errorf("add iptables forward rule: %v", err))
} else {
log.Debugf("added iptables forward rule: %v", rule)
}
@@ -967,7 +1011,7 @@ func (r *router) acceptFilterRulesIptables(ipt *iptables.IPTables) error {
inputRule := r.getAcceptInputRule()
if err := ipt.Insert("filter", chainNameInput, 1, inputRule...); err != nil {
merr = multierror.Append(err, fmt.Errorf("add iptables input rule: %v", err))
merr = multierror.Append(merr, fmt.Errorf("add iptables input rule: %v", err))
} else {
log.Debugf("added iptables input rule: %v", inputRule)
}
@@ -987,18 +1031,70 @@ func (r *router) getAcceptInputRule() []string {
return []string{"-i", r.wgIface.Name(), "-j", "ACCEPT"}
}
func (r *router) acceptFilterRulesNftables() error {
// acceptFilterRulesNftables inserts the forward and input accept rules into
// the given table through nftables and flushes the connection.
// It is the fallback path for when iptables is not available.
func (r *router) acceptFilterRulesNftables(table *nftables.Table) error {
	ifaceName := ifname(r.wgIface.Name())

	// Both chains share everything except their name and hook.
	filterChain := func(name string, hook *nftables.ChainHook) *nftables.Chain {
		return &nftables.Chain{
			Name:     name,
			Table:    table,
			Type:     nftables.ChainTypeFilter,
			Hooknum:  hook,
			Priority: nftables.ChainPriorityFilter,
		}
	}

	r.insertForwardAcceptRules(filterChain(chainNameForward, nftables.ChainHookForward), ifaceName)
	r.insertInputAcceptRule(filterChain(chainNameInput, nftables.ChainHookInput), ifaceName)

	return r.conn.Flush()
}
// acceptExternalChainsRules inserts accept rules into every external chain
// reported by findExternalChains. The chains are looked up on each call, so
// chains created after startup are picked up as well. Forward-hook chains get
// the forward accept rules, input-hook chains get the input accept rule, and
// the connection is flushed once at the end.
func (r *router) acceptExternalChainsRules() error {
	externalChains := r.findExternalChains()
	if len(externalChains) == 0 {
		return nil
	}

	ifaceName := ifname(r.wgIface.Name())

	for _, c := range externalChains {
		hook := c.Hooknum
		if hook == nil {
			log.Debugf("skipping external chain %s/%s: hooknum is nil", c.Table.Name, c.Name)
			continue
		}

		log.Debugf("adding accept rules to external %s chain: %s %s/%s",
			hookName(hook), familyName(c.Table.Family), c.Table.Name, c.Name)

		// Chains with any other hook are logged above but left untouched.
		if *hook == *nftables.ChainHookForward {
			r.insertForwardAcceptRules(c, ifaceName)
		} else if *hook == *nftables.ChainHookInput {
			r.insertInputAcceptRule(c, ifaceName)
		}
	}

	if err := r.conn.Flush(); err != nil {
		return fmt.Errorf("flush external chain rules: %w", err)
	}
	return nil
}
func (r *router) insertForwardAcceptRules(chain *nftables.Chain, intf []byte) {
iifRule := &nftables.Rule{
Table: r.filterTable,
Chain: &nftables.Chain{
Name: chainNameForward,
Table: r.filterTable,
Type: nftables.ChainTypeFilter,
Hooknum: nftables.ChainHookForward,
Priority: nftables.ChainPriorityFilter,
},
Table: chain.Table,
Chain: chain,
Exprs: []expr.Any{
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
&expr.Cmp{
@@ -1021,30 +1117,19 @@ func (r *router) acceptFilterRulesNftables() error {
Data: intf,
},
}
oifRule := &nftables.Rule{
Table: r.filterTable,
Chain: &nftables.Chain{
Name: chainNameForward,
Table: r.filterTable,
Type: nftables.ChainTypeFilter,
Hooknum: nftables.ChainHookForward,
Priority: nftables.ChainPriorityFilter,
},
Table: chain.Table,
Chain: chain,
Exprs: append(oifExprs, getEstablishedExprs(2)...),
UserData: []byte(userDataAcceptForwardRuleOif),
}
r.conn.InsertRule(oifRule)
}
func (r *router) insertInputAcceptRule(chain *nftables.Chain, intf []byte) {
inputRule := &nftables.Rule{
Table: r.filterTable,
Chain: &nftables.Chain{
Name: chainNameInput,
Table: r.filterTable,
Type: nftables.ChainTypeFilter,
Hooknum: nftables.ChainHookInput,
Priority: nftables.ChainPriorityFilter,
},
Table: chain.Table,
Chain: chain,
Exprs: []expr.Any{
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
&expr.Cmp{
@@ -1058,32 +1143,44 @@ func (r *router) acceptFilterRulesNftables() error {
UserData: []byte(userDataAcceptInputRule),
}
r.conn.InsertRule(inputRule)
return r.conn.Flush()
}
// removeAcceptFilterRules removes the accept rules from the filter table and
// from all external chains. Both removals are always attempted and their
// errors are aggregated into a single result.
func (r *router) removeAcceptFilterRules() error {
	var result *multierror.Error

	filterErr := r.removeFilterTableRules()
	if filterErr != nil {
		result = multierror.Append(result, filterErr)
	}

	externalErr := r.removeExternalChainsRules()
	if externalErr != nil {
		result = multierror.Append(result, fmt.Errorf("remove external chain rules: %w", externalErr))
	}

	return nberrors.FormatErrorOrNil(result)
}
func (r *router) removeFilterTableRules() error {
if r.filterTable == nil {
return nil
}
ipt, err := iptables.New()
if err != nil {
log.Warnf("Will use nftables to manipulate the filter table because iptables is not available: %v", err)
return r.removeAcceptFilterRulesNftables()
log.Debugf("iptables not available, using nftables to remove filter rules: %v", err)
return r.removeAcceptRulesFromTable(r.filterTable)
}
return r.removeAcceptFilterRulesIptables(ipt)
}
func (r *router) removeAcceptFilterRulesNftables() error {
chains, err := r.conn.ListChainsOfTableFamily(nftables.TableFamilyIPv4)
func (r *router) removeAcceptRulesFromTable(table *nftables.Table) error {
chains, err := r.conn.ListChainsOfTableFamily(table.Family)
if err != nil {
return fmt.Errorf("list chains: %v", err)
}
for _, chain := range chains {
if chain.Table.Name != r.filterTable.Name {
if chain.Table.Name != table.Name {
continue
}
@@ -1091,27 +1188,101 @@ func (r *router) removeAcceptFilterRulesNftables() error {
continue
}
rules, err := r.conn.GetRules(r.filterTable, chain)
if err := r.removeAcceptRulesFromChain(table, chain); err != nil {
return err
}
}
return r.conn.Flush()
}
// removeAcceptRulesFromChain deletes every rule in the given chain whose
// UserData equals one of the accept-rule markers (forward iif, forward oif,
// or input). It stops at the first failure and returns that error wrapped
// with the table and chain name, so callers can still inspect the underlying
// cause with errors.Is / errors.As. It does not flush the connection itself.
func (r *router) removeAcceptRulesFromChain(table *nftables.Table, chain *nftables.Chain) error {
	rules, err := r.conn.GetRules(table, chain)
	if err != nil {
		return fmt.Errorf("get rules from %s/%s: %w", table.Name, chain.Name, err)
	}

	for _, rule := range rules {
		isAcceptRule := bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleIif)) ||
			bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleOif)) ||
			bytes.Equal(rule.UserData, []byte(userDataAcceptInputRule))
		if !isAcceptRule {
			// Rules that do not carry one of our markers are left alone.
			continue
		}

		if err := r.conn.DelRule(rule); err != nil {
			return fmt.Errorf("delete rule from %s/%s: %w", table.Name, chain.Name, err)
		}
	}

	return nil
}
// removeExternalChainsRules strips our accept rules from every external chain.
// The chains are discovered at removal time via findExternalChains rather than
// taken from saved state. A failure on one chain is logged as a warning and
// does not stop the remaining chains from being processed; only the final
// flush error is returned.
func (r *router) removeExternalChainsRules() error {
	externalChains := r.findExternalChains()
	if len(externalChains) == 0 {
		return nil
	}

	for i := range externalChains {
		c := externalChains[i]
		err := r.removeAcceptRulesFromChain(c.Table, c)
		if err == nil {
			continue
		}
		log.Warnf("remove rules from external chain %s/%s: %v", c.Table.Name, c.Name, err)
	}

	return r.conn.Flush()
}
// findExternalChains scans for chains from non-netbird tables that have FORWARD or INPUT hooks.
// This is used both at startup (to know where to add rules) and at cleanup (to ensure deterministic removal).
func (r *router) findExternalChains() []*nftables.Chain {
var chains []*nftables.Chain
families := []nftables.TableFamily{nftables.TableFamilyIPv4, nftables.TableFamilyINet}
for _, family := range families {
allChains, err := r.conn.ListChainsOfTableFamily(family)
if err != nil {
return fmt.Errorf("get rules: %v", err)
log.Debugf("list chains for family %d: %v", family, err)
continue
}
for _, rule := range rules {
if bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleIif)) ||
bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleOif)) ||
bytes.Equal(rule.UserData, []byte(userDataAcceptInputRule)) {
if err := r.conn.DelRule(rule); err != nil {
return fmt.Errorf("delete rule: %v", err)
}
for _, chain := range allChains {
if r.isExternalChain(chain) {
chains = append(chains, chain)
}
}
}
if err := r.conn.Flush(); err != nil {
return fmt.Errorf(flushError, err)
return chains
}
func (r *router) isExternalChain(chain *nftables.Chain) bool {
if r.workTable != nil && chain.Table.Name == r.workTable.Name {
return false
}
return nil
// Skip all iptables-managed tables in the ip family
if chain.Table.Family == nftables.TableFamilyIPv4 && isIptablesTable(chain.Table.Name) {
return false
}
if chain.Type != nftables.ChainTypeFilter {
return false
}
if chain.Hooknum == nil {
return false
}
return *chain.Hooknum == *nftables.ChainHookForward || *chain.Hooknum == *nftables.ChainHookInput
}
// isIptablesTable reports whether name is one of the table names listed here
// as iptables tables: filter, nat, mangle, raw or security.
func isIptablesTable(name string) bool {
	return name == tableNameFilter ||
		name == tableNat ||
		name == tableMangle ||
		name == tableRaw ||
		name == tableSecurity
}
func (r *router) removeAcceptFilterRulesIptables(ipt *iptables.IPTables) error {
@@ -1119,13 +1290,13 @@ func (r *router) removeAcceptFilterRulesIptables(ipt *iptables.IPTables) error {
for _, rule := range r.getAcceptForwardRules() {
if err := ipt.DeleteIfExists("filter", chainNameForward, rule...); err != nil {
merr = multierror.Append(err, fmt.Errorf("remove iptables forward rule: %v", err))
merr = multierror.Append(merr, fmt.Errorf("remove iptables forward rule: %v", err))
}
}
inputRule := r.getAcceptInputRule()
if err := ipt.DeleteIfExists("filter", chainNameInput, inputRule...); err != nil {
merr = multierror.Append(err, fmt.Errorf("remove iptables input rule: %v", err))
merr = multierror.Append(merr, fmt.Errorf("remove iptables input rule: %v", err))
}
return nberrors.FormatErrorOrNil(merr)

View File

@@ -273,11 +273,12 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
checks := loginResp.GetChecks()
c.engineMutex.Lock()
c.engine = NewEngine(engineCtx, cancel, signalClient, mgmClient, relayManager, engineConfig, mobileDependency, c.statusRecorder, checks)
c.engine.SetSyncResponsePersistence(c.persistSyncResponse)
engine := NewEngine(engineCtx, cancel, signalClient, mgmClient, relayManager, engineConfig, mobileDependency, c.statusRecorder, checks)
engine.SetSyncResponsePersistence(c.persistSyncResponse)
c.engine = engine
c.engineMutex.Unlock()
if err := c.engine.Start(loginResp.GetNetbirdConfig(), c.config.ManagementURL); err != nil {
if err := engine.Start(loginResp.GetNetbirdConfig(), c.config.ManagementURL); err != nil {
log.Errorf("error while starting Netbird Connection Engine: %s", err)
return wrapErr(err)
}
@@ -293,12 +294,14 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
<-engineCtx.Done()
c.engineMutex.Lock()
engine := c.engine
c.engine = nil
c.engineMutex.Unlock()
if engine != nil && engine.wgInterface != nil {
// TODO: consider removing this condition; it is not thread safe.
// We should always call Stop(), but we first need to verify that it is idempotent.
if engine.wgInterface != nil {
log.Infof("ensuring %s is removed, Netbird engine context cancelled", engine.wgInterface.Name())
if err := engine.Stop(); err != nil {
log.Errorf("Failed to stop engine: %v", err)
}

View File

@@ -56,6 +56,7 @@ block.prof: Block profiling information.
heap.prof: Heap profiling information (snapshot of memory allocations).
allocs.prof: Allocations profiling information.
threadcreate.prof: Thread creation profiling information.
stack_trace.txt: Complete stack traces of all goroutines at the time of bundle creation.
Anonymization Process
@@ -109,6 +110,9 @@ go tool pprof -http=:8088 heap.prof
This will open a web browser tab with the profiling information.
Stack Trace
The stack_trace.txt file contains a complete snapshot of all goroutine stack traces at the time the debug bundle was created.
Routes
The routes.txt file contains detailed routing table information in a tabular format:
@@ -327,6 +331,10 @@ func (g *BundleGenerator) createArchive() error {
log.Errorf("failed to add profiles to debug bundle: %v", err)
}
if err := g.addStackTrace(); err != nil {
log.Errorf("failed to add stack trace to debug bundle: %v", err)
}
if err := g.addSyncResponse(); err != nil {
return fmt.Errorf("add sync response: %w", err)
}
@@ -522,6 +530,18 @@ func (g *BundleGenerator) addProf() (err error) {
return nil
}
// addStackTrace captures the stack traces of all goroutines and stores them
// in the bundle as stack_trace.txt.
//
// runtime.Stack truncates its output when the buffer is too small, so the
// buffer starts at 1 MiB and is doubled until the dump fits. Growth is capped
// at 64 MiB to bound memory use; beyond that the trace is written truncated.
func (g *BundleGenerator) addStackTrace() error {
	const (
		initialBufSize = 1 << 20  // 1 MiB
		maxBufSize     = 64 << 20 // 64 MiB
	)

	buf := make([]byte, initialBufSize)
	n := runtime.Stack(buf, true)
	// A completely filled buffer means the dump may have been cut off.
	for n == len(buf) && len(buf) < maxBufSize {
		buf = make([]byte, 2*len(buf))
		n = runtime.Stack(buf, true)
	}

	stackTrace := bytes.NewReader(buf[:n])
	if err := g.addFileToZip(stackTrace, "stack_trace.txt"); err != nil {
		return fmt.Errorf("add stack trace file to zip: %w", err)
	}

	return nil
}
func (g *BundleGenerator) addInterfaces() error {
interfaces, err := net.Interfaces()
if err != nil {

View File

@@ -234,6 +234,11 @@ func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) *dns
return nil
}
// Unmap IPv4-mapped IPv6 addresses that some resolvers may return
for i, ip := range ips {
ips[i] = ip.Unmap()
}
f.updateInternalState(ips, mostSpecificResId, matchingEntries)
f.addIPsToResponse(resp, domain, ips)
f.cache.set(domain, question.Qtype, ips)

View File

@@ -280,7 +280,6 @@ func (e *Engine) Stop() error {
return nil
}
e.syncMsgMux.Lock()
defer e.syncMsgMux.Unlock()
if e.connMgr != nil {
e.connMgr.Close()
@@ -292,21 +291,12 @@ func (e *Engine) Stop() error {
}
log.Info("Network monitor: stopped")
if os.Getenv("NB_REMOVE_BEFORE_DNS") == "true" && os.Getenv("NB_REMOVE_BEFORE_ROUTES") != "true" {
log.Info("removing peers before dns")
if err := e.removeAllPeers(); err != nil {
return fmt.Errorf("failed to remove all peers: %s", err)
}
}
if err := e.stopSSHServer(); err != nil {
log.Warnf("failed to stop SSH server: %v", err)
}
e.cleanupSSHConfig()
// stop/restore DNS first so dbus and friends don't complain because of a missing interface
e.stopDNSServer()
if e.ingressGatewayMgr != nil {
if err := e.ingressGatewayMgr.Close(); err != nil {
log.Warnf("failed to cleanup forward rules: %v", err)
@@ -314,33 +304,28 @@ func (e *Engine) Stop() error {
e.ingressGatewayMgr = nil
}
e.stopDNSForwarder()
if e.srWatcher != nil {
e.srWatcher.Close()
}
if os.Getenv("NB_REMOVE_BEFORE_ROUTES") == "true" && os.Getenv("NB_REMOVE_BEFORE_DNS") != "true" {
log.Info("removing peers before routes")
if err := e.removeAllPeers(); err != nil {
return fmt.Errorf("failed to remove all peers: %s", err)
}
log.Info("cleaning up status recorder states")
e.statusRecorder.ReplaceOfflinePeers([]peer.State{})
e.statusRecorder.UpdateDNSStates([]peer.NSGroupState{})
e.statusRecorder.UpdateRelayStates([]relay.ProbeResult{})
if err := e.removeAllPeers(); err != nil {
log.Errorf("failed to remove all peers: %s", err)
}
if e.routeManager != nil {
e.routeManager.Stop(e.stateManager)
}
if e.srWatcher != nil {
e.srWatcher.Close()
}
log.Info("cleaning up status recorder states")
e.statusRecorder.ReplaceOfflinePeers([]peer.State{})
e.statusRecorder.UpdateDNSStates([]peer.NSGroupState{})
e.statusRecorder.UpdateRelayStates([]relay.ProbeResult{})
e.stopDNSForwarder()
if os.Getenv("NB_REMOVE_BEFORE_DNS") != "true" && os.Getenv("NB_REMOVE_BEFORE_ROUTES") != "true" {
log.Info("removing peers after dns and routes")
if err := e.removeAllPeers(); err != nil {
return fmt.Errorf("failed to remove all peers: %s", err)
}
}
// stop/restore DNS after peers are closed but before interface goes down
// so dbus and friends don't complain because of a missing interface
e.stopDNSServer()
if e.cancel != nil {
e.cancel()
@@ -353,16 +338,18 @@ func (e *Engine) Stop() error {
e.flowManager.Close()
}
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()
stateCtx, stateCancel := context.WithTimeout(context.Background(), 3*time.Second)
defer stateCancel()
if err := e.stateManager.Stop(ctx); err != nil {
return fmt.Errorf("failed to stop state manager: %w", err)
if err := e.stateManager.Stop(stateCtx); err != nil {
log.Errorf("failed to stop state manager: %v", err)
}
if err := e.stateManager.PersistState(context.Background()); err != nil {
log.Errorf("failed to persist state: %v", err)
}
e.syncMsgMux.Unlock()
timeout := e.calculateShutdownTimeout()
log.Debugf("waiting for goroutines to finish with timeout: %v", timeout)
shutdownCtx, cancel := context.WithTimeout(context.Background(), timeout)
@@ -448,8 +435,7 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
if err != nil {
return fmt.Errorf("create rosenpass manager: %w", err)
}
err := e.rpManager.Run()
if err != nil {
if err := e.rpManager.Run(); err != nil {
return fmt.Errorf("run rosenpass manager: %w", err)
}
}
@@ -501,6 +487,7 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
}
if err := e.createFirewall(); err != nil {
e.close()
return err
}
@@ -766,6 +753,11 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
e.syncMsgMux.Lock()
defer e.syncMsgMux.Unlock()
// Check context INSIDE lock to ensure atomicity with shutdown
if e.ctx.Err() != nil {
return e.ctx.Err()
}
if update.GetNetbirdConfig() != nil {
wCfg := update.GetNetbirdConfig()
err := e.updateTURNs(wCfg.GetTurns())
@@ -1385,6 +1377,11 @@ func (e *Engine) receiveSignalEvents() {
e.syncMsgMux.Lock()
defer e.syncMsgMux.Unlock()
// Check context INSIDE lock to ensure atomicity with shutdown
if e.ctx.Err() != nil {
return e.ctx.Err()
}
conn, ok := e.peerStore.PeerConn(msg.Key)
if !ok {
return fmt.Errorf("wrongly addressed message %s", msg.Key)

View File

@@ -20,7 +20,7 @@ type EndpointUpdater struct {
wgConfig WgConfig
initiator bool
// mu protects updateWireGuardPeer and cancelFunc
// mu protects cancelFunc
mu sync.Mutex
cancelFunc func()
updateWg sync.WaitGroup
@@ -86,11 +86,9 @@ func (e *EndpointUpdater) scheduleDelayedUpdate(ctx context.Context, addr *net.U
case <-ctx.Done():
return
case <-t.C:
e.mu.Lock()
if err := e.updateWireGuardPeer(addr, presharedKey); err != nil {
e.log.Errorf("failed to update WireGuard peer, address: %s, error: %v", addr, err)
}
e.mu.Unlock()
}
}

View File

@@ -1,9 +1,12 @@
//go:build ios
package NetBirdSDK
import (
"context"
"fmt"
"net/netip"
"os"
"sort"
"strings"
"sync"
@@ -90,7 +93,8 @@ func NewClient(cfgFile, stateFile, deviceName string, osVersion string, osName s
}
// Run starts the internal client. It is a blocking function.
func (c *Client) Run(fd int32, interfaceName string) error {
func (c *Client) Run(fd int32, interfaceName string, envList *EnvList) error {
exportEnvList(envList)
log.Infof("Starting NetBird client")
log.Debugf("Tunnel uses interface: %s", interfaceName)
cfg, err := profilemanager.UpdateOrCreateConfig(profilemanager.ConfigInput{
@@ -433,3 +437,19 @@ func toNetIDs(routes []string) []route.NetID {
}
return netIDs
}
// exportEnvList applies every key/value pair of list to the process
// environment via os.Setenv. A nil list is a no-op. The previous value, the
// new value, and the outcome of each assignment are logged.
func exportEnvList(list *EnvList) {
	if list == nil {
		return
	}

	items := list.AllItems()
	for name, value := range items {
		log.Debugf("Env variable %s's value is currently: %s", name, os.Getenv(name))
		log.Debugf("Setting env variable %s: %s", name, value)

		err := os.Setenv(name, value)
		if err != nil {
			log.Errorf("could not set env variable %s: %v", name, err)
			continue
		}
		log.Debugf("Env variable %s was set successfully", name)
	}
}

View File

@@ -0,0 +1,34 @@
//go:build ios
package NetBirdSDK
import "github.com/netbirdio/netbird/client/internal/peer"
// EnvList is an exported struct to be bound by gomobile. It holds a set of
// environment variable names and their values.
//
// The zero value is usable: Put allocates the underlying map on first use.
// All methods tolerate a nil receiver.
type EnvList struct {
	data map[string]string
}

// NewEnvList creates a new, empty EnvList.
func NewEnvList() *EnvList {
	return &EnvList{data: make(map[string]string)}
}

// Put adds a key-value pair, replacing any value previously stored under key.
// Calling Put on a nil *EnvList is a no-op.
func (el *EnvList) Put(key, value string) {
	if el == nil {
		return
	}
	// Writing to a nil map panics, so allocate lazily for zero-value lists.
	if el.data == nil {
		el.data = make(map[string]string)
	}
	el.data[key] = value
}

// Get retrieves a value by key. It returns the empty string when the key is
// absent or the receiver is nil.
func (el *EnvList) Get(key string) string {
	if el == nil {
		return ""
	}
	return el.data[key]
}

// AllItems returns the underlying key-value map (not a copy), so changes made
// through the returned map are visible in the EnvList. It returns nil for a
// nil receiver or for a list that never had an item stored.
func (el *EnvList) AllItems() map[string]string {
	if el == nil {
		return nil
	}
	return el.data
}
// GetEnvKeyNBForceRelay returns peer.EnvKeyNBForceRelay, exposing that
// environment variable key to the iOS client.
func GetEnvKeyNBForceRelay() string {
return peer.EnvKeyNBForceRelay
}

View File

@@ -1,3 +1,5 @@
//go:build ios
package NetBirdSDK
import _ "golang.org/x/mobile/bind"

View File

@@ -1,3 +1,5 @@
//go:build ios
package NetBirdSDK
import (

View File

@@ -1,3 +1,5 @@
//go:build ios
package NetBirdSDK
import (

View File

@@ -1,3 +1,5 @@
//go:build ios
package NetBirdSDK
// PeerInfo describes information about the peers. It is designed for UI usage.

View File

@@ -1,3 +1,5 @@
//go:build ios
package NetBirdSDK
import (

View File

@@ -1,3 +1,5 @@
//go:build ios
package NetBirdSDK
import (

View File

@@ -1,3 +1,5 @@
//go:build ios
package NetBirdSDK
// RoutesSelectionInfoCollection is made for the Java layer to get non-default types as a collection

View File

@@ -822,6 +822,7 @@ func (s *Server) Down(ctx context.Context, _ *proto.DownRequest) (*proto.DownRes
defer s.mutex.Unlock()
if err := s.cleanupConnection(); err != nil {
// TODO: review whether the status should be updated on any type of error
log.Errorf("failed to shut down properly: %v", err)
return nil, err
}
@@ -914,6 +915,7 @@ func (s *Server) handleActiveProfileLogout(ctx context.Context) (*proto.LogoutRe
}
if err := s.cleanupConnection(); err != nil && !errors.Is(err, ErrServiceNotUp) {
// TODO: review whether the status should be updated on any type of error
log.Errorf("failed to cleanup connection: %v", err)
return nil, err
}

View File

@@ -42,6 +42,11 @@ func (s *Server) detectSuPtySupport(context.Context) bool {
return false
}
// detectUtilLinuxLogin always returns false on JS/WASM.
// The context argument is ignored.
func (s *Server) detectUtilLinuxLogin(context.Context) bool {
return false
}
// executeCommandWithPty is not supported on JS/WASM
func (s *Server) executeCommandWithPty(logger *log.Entry, session ssh.Session, execCmd *exec.Cmd, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, winCh <-chan ssh.Window) bool {
logger.Errorf("PTY command execution not supported on JS/WASM")

View File

@@ -10,6 +10,7 @@ import (
"os"
"os/exec"
"os/user"
"runtime"
"strings"
"sync"
"syscall"
@@ -75,6 +76,29 @@ func (s *Server) detectSuPtySupport(ctx context.Context) bool {
return supported
}
// detectUtilLinuxLogin reports whether the system's login binary comes from
// util-linux (as opposed to shadow-utils). It runs `login --version` with a
// 500ms timeout and looks for "util-linux" in the combined output.
// util-linux login uses vhangup(), which requires a setsid wrapper to avoid
// killing the parent process. See https://bugs.debian.org/1078023 for details.
// On non-Linux systems, or when the probe command fails, it returns false.
func (s *Server) detectUtilLinuxLogin(ctx context.Context) bool {
	if runtime.GOOS != "linux" {
		return false
	}

	probeCtx, cancel := context.WithTimeout(ctx, 500*time.Millisecond)
	defer cancel()

	output, err := exec.CommandContext(probeCtx, "login", "--version").CombinedOutput()
	if err != nil {
		log.Debugf("login --version failed (likely shadow-utils): %v", err)
		return false
	}

	detected := strings.Contains(string(output), "util-linux")
	log.Debugf("util-linux login detected: %v", detected)
	return detected
}
// createSuCommand creates a command using su -l -c for privilege switching
func (s *Server) createSuCommand(session ssh.Session, localUser *user.User, hasPty bool) (*exec.Cmd, error) {
suPath, err := exec.LookPath("su")
@@ -144,7 +168,7 @@ func (s *Server) handlePty(logger *log.Entry, session ssh.Session, privilegeResu
return false
}
logger.Infof("starting interactive shell: %s", execCmd.Path)
logger.Infof("starting interactive shell: %s", strings.Join(execCmd.Args, " "))
return s.runPtyCommand(logger, session, execCmd, ptyReq, winCh)
}

View File

@@ -383,6 +383,11 @@ func (s *Server) detectSuPtySupport(context.Context) bool {
return false
}
// detectUtilLinuxLogin always returns false on Windows.
// The context argument is ignored.
func (s *Server) detectUtilLinuxLogin(context.Context) bool {
return false
}
// executeCommandWithPty executes a command with PTY allocation on Windows using ConPty
func (s *Server) executeCommandWithPty(logger *log.Entry, session ssh.Session, execCmd *exec.Cmd, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, winCh <-chan ssh.Window) bool {
command := session.RawCommand()

View File

@@ -138,7 +138,8 @@ type Server struct {
jwtExtractor *jwt.ClaimsExtractor
jwtConfig *JWTConfig
suSupportsPty bool
suSupportsPty bool
loginIsUtilLinux bool
}
type JWTConfig struct {
@@ -193,6 +194,7 @@ func (s *Server) Start(ctx context.Context, addr netip.AddrPort) error {
}
s.suSupportsPty = s.detectSuPtySupport(ctx)
s.loginIsUtilLinux = s.detectUtilLinuxLogin(ctx)
ln, addrDesc, err := s.createListener(ctx, addr)
if err != nil {

View File

@@ -87,11 +87,8 @@ func (s *Server) getLoginCmd(username string, remoteAddr net.Addr) (string, []st
switch runtime.GOOS {
case "linux":
// Special handling for Arch Linux without /etc/pam.d/remote
if s.fileExists("/etc/arch-release") && !s.fileExists("/etc/pam.d/remote") {
return loginPath, []string{"-f", username, "-p"}, nil
}
return loginPath, []string{"-f", username, "-h", addrPort.Addr().String(), "-p"}, nil
p, a := s.getLinuxLoginCmd(loginPath, username, addrPort.Addr().String())
return p, a, nil
case "darwin", "freebsd", "openbsd", "netbsd", "dragonfly":
return loginPath, []string{"-fp", "-h", addrPort.Addr().String(), username}, nil
default:
@@ -99,7 +96,37 @@ func (s *Server) getLoginCmd(username string, remoteAddr net.Addr) (string, []st
}
}
// fileExists checks if a file exists (helper for login command logic)
// getLinuxLoginCmd builds the executable path and argument list used to start
// `login` on Linux, covering the differences between the util-linux and
// shadow-utils login implementations.
//
// The returned path is either loginPath itself or, when util-linux login was
// detected and `setsid` can be found, the setsid binary wrapping loginPath.
func (s *Server) getLinuxLoginCmd(loginPath, username, remoteIP string) (string, []string) {
	// Special handling for Arch Linux without /etc/pam.d/remote: drop the -h <remoteIP> pair.
	archWithoutRemotePam := s.fileExists("/etc/arch-release") && !s.fileExists("/etc/pam.d/remote")

	loginArgs := []string{"-f", username, "-h", remoteIP, "-p"}
	if archWithoutRemotePam {
		loginArgs = []string{"-f", username, "-p"}
	}

	if !s.loginIsUtilLinux {
		return loginPath, loginArgs
	}

	// util-linux login requires setsid -c to create a new session and set the
	// controlling terminal. Without this, vhangup() kills the parent process.
	// See https://bugs.debian.org/1078023 for details.
	// TODO: handle this via the executor using syscall.Setsid() + TIOCSCTTY + syscall.Exec()
	// to avoid external setsid dependency.
	setsidPath, err := exec.LookPath("setsid")
	if err != nil {
		log.Warnf("setsid not available but util-linux login detected, login may fail: %v", err)
		return loginPath, loginArgs
	}

	// Final command line: setsid -w -c <loginPath> <loginArgs...>
	wrapped := make([]string, 0, 3+len(loginArgs))
	wrapped = append(wrapped, "-w", "-c", loginPath)
	wrapped = append(wrapped, loginArgs...)
	return setsidPath, wrapped
}
// fileExists checks if a file exists
func (s *Server) fileExists(path string) bool {
_, err := os.Stat(path)
return err == nil

View File

@@ -60,14 +60,7 @@ func (hook ContextHook) Fire(entry *logrus.Entry) error {
entry.Data["context"] = source
switch source {
case HTTPSource:
addHTTPFields(entry)
case GRPCSource:
addGRPCFields(entry)
case SystemSource:
addSystemFields(entry)
}
addFields(entry)
return nil
}
@@ -99,7 +92,7 @@ func (hook ContextHook) parseSrc(filePath string) string {
return fmt.Sprintf("%s/%s", pkg, file)
}
func addHTTPFields(entry *logrus.Entry) {
func addFields(entry *logrus.Entry) {
if ctxReqID, ok := entry.Context.Value(context.RequestIDKey).(string); ok {
entry.Data[context.RequestIDKey] = ctxReqID
}
@@ -109,30 +102,6 @@ func addHTTPFields(entry *logrus.Entry) {
if ctxInitiatorID, ok := entry.Context.Value(context.UserIDKey).(string); ok {
entry.Data[context.UserIDKey] = ctxInitiatorID
}
}
func addGRPCFields(entry *logrus.Entry) {
if ctxReqID, ok := entry.Context.Value(context.RequestIDKey).(string); ok {
entry.Data[context.RequestIDKey] = ctxReqID
}
if ctxAccountID, ok := entry.Context.Value(context.AccountIDKey).(string); ok {
entry.Data[context.AccountIDKey] = ctxAccountID
}
if ctxDeviceID, ok := entry.Context.Value(context.PeerIDKey).(string); ok {
entry.Data[context.PeerIDKey] = ctxDeviceID
}
}
func addSystemFields(entry *logrus.Entry) {
if ctxReqID, ok := entry.Context.Value(context.RequestIDKey).(string); ok {
entry.Data[context.RequestIDKey] = ctxReqID
}
if ctxInitiatorID, ok := entry.Context.Value(context.UserIDKey).(string); ok {
entry.Data[context.UserIDKey] = ctxInitiatorID
}
if ctxAccountID, ok := entry.Context.Value(context.AccountIDKey).(string); ok {
entry.Data[context.AccountIDKey] = ctxAccountID
}
if ctxDeviceID, ok := entry.Context.Value(context.PeerIDKey).(string); ok {
entry.Data[context.PeerIDKey] = ctxDeviceID
}

2
go.mod
View File

@@ -64,7 +64,7 @@ require (
github.com/mdlayher/socket v0.5.1
github.com/miekg/dns v1.1.59
github.com/mitchellh/hashstructure/v2 v2.0.2
github.com/netbirdio/management-integrations/integrations v0.0.0-20251202114414-534cf891e0ba
github.com/netbirdio/management-integrations/integrations v0.0.0-20251203183432-d5400f030847
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45
github.com/okta/okta-sdk-golang/v2 v2.18.0
github.com/oschwald/maxminddb-golang v1.12.0

4
go.sum
View File

@@ -368,8 +368,8 @@ github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944 h1:TDtJKmM6S
github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944/go.mod h1:sHA6TRxjQ6RLbnI+3R4DZo2Eseg/iKiPRfNmcuNySVQ=
github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51 h1:Ov4qdafATOgGMB1wbSuh+0aAHcwz9hdvB6VZjh1mVMI=
github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51/go.mod h1:ZSIbPdBn5hePO8CpF1PekH2SfpTxg1PDhEwtbqZS7R8=
github.com/netbirdio/management-integrations/integrations v0.0.0-20251202114414-534cf891e0ba h1:pD6eygRJ5EYAlgzeNskPU3WqszMz6/HhPuc6/Bc/580=
github.com/netbirdio/management-integrations/integrations v0.0.0-20251202114414-534cf891e0ba/go.mod h1:qzLCKeR253jtsWhfZTt4fyegI5zei32jKZykV+oSQOo=
github.com/netbirdio/management-integrations/integrations v0.0.0-20251203183432-d5400f030847 h1:V0zsYYMU5d2UN1m9zOLPEZCGWpnhtkYcxQVi9Rrx3bY=
github.com/netbirdio/management-integrations/integrations v0.0.0-20251203183432-d5400f030847/go.mod h1:qzLCKeR253jtsWhfZTt4fyegI5zei32jKZykV+oSQOo=
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8=
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45 h1:ujgviVYmx243Ksy7NdSwrdGPSRNE3pb8kEDSpH0QuAQ=

View File

@@ -144,7 +144,7 @@ func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID strin
if c.experimentalNetworkMap(accountID) {
account = c.getAccountFromHolderOrInit(accountID)
} else {
account, err = c.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
account, err = c.requestBuffer.GetAccountLightWithBackpressure(ctx, accountID)
if err != nil {
return fmt.Errorf("failed to get account: %v", err)
}
@@ -300,7 +300,7 @@ func (c *Controller) UpdateAccountPeer(ctx context.Context, accountId string, pe
return fmt.Errorf("peer %s doesn't have a channel, skipping network map update", peerId)
}
account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountId)
account, err := c.requestBuffer.GetAccountLightWithBackpressure(ctx, accountId)
if err != nil {
return fmt.Errorf("failed to send out updates to peer %s: %v", peerId, err)
}
@@ -414,7 +414,7 @@ func (c *Controller) GetValidatedPeerWithMap(ctx context.Context, isRequiresAppr
if c.experimentalNetworkMap(accountID) {
account = c.getAccountFromHolderOrInit(accountID)
} else {
account, err = c.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
account, err = c.requestBuffer.GetAccountLightWithBackpressure(ctx, accountID)
if err != nil {
return nil, nil, nil, 0, err
}
@@ -506,7 +506,7 @@ func (c *Controller) recalculateNetworkMapCache(account *types.Account, validate
func (c *Controller) RecalculateNetworkMapCache(ctx context.Context, accountId string) error {
if c.experimentalNetworkMap(accountId) {
account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountId)
account, err := c.requestBuffer.GetAccountLightWithBackpressure(ctx, accountId)
if err != nil {
return err
}
@@ -548,7 +548,7 @@ func (c *Controller) getAccountFromHolderOrInit(accountID string) *types.Account
if a != nil {
return a
}
account, err := c.holder.LoadOrStoreFunc(accountID, c.requestBuffer.GetAccountWithBackpressure)
account, err := c.holder.LoadOrStoreFunc(accountID, c.requestBuffer.GetAccountLightWithBackpressure)
if err != nil {
return nil
}
@@ -715,7 +715,7 @@ func (c *Controller) OnPeersUpdated(ctx context.Context, accountID string, peerI
func (c *Controller) OnPeersAdded(ctx context.Context, accountID string, peerIDs []string) error {
for _, peerID := range peerIDs {
if c.experimentalNetworkMap(accountID) {
account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
account, err := c.requestBuffer.GetAccountLightWithBackpressure(ctx, accountID)
if err != nil {
return err
}
@@ -761,7 +761,7 @@ func (c *Controller) OnPeersDeleted(ctx context.Context, accountID string, peerI
c.peersUpdateManager.CloseChannel(ctx, peerID)
if c.experimentalNetworkMap(accountID) {
account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
account, err := c.requestBuffer.GetAccountLightWithBackpressure(ctx, accountID)
if err != nil {
log.WithContext(ctx).Errorf("failed to get account %s: %v", accountID, err)
continue

View File

@@ -10,9 +10,9 @@ import (
"slices"
"time"
"github.com/google/uuid"
grpcMiddleware "github.com/grpc-ecosystem/go-grpc-middleware/v2"
"github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/realip"
"github.com/rs/xid"
log "github.com/sirupsen/logrus"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
@@ -180,7 +180,7 @@ func unaryInterceptor(
info *grpc.UnaryServerInfo,
handler grpc.UnaryHandler,
) (interface{}, error) {
reqID := uuid.New().String()
reqID := xid.New().String()
//nolint
ctx = context.WithValue(ctx, hook.ExecutionContextKey, hook.GRPCSource)
//nolint
@@ -194,7 +194,7 @@ func streamInterceptor(
info *grpc.StreamServerInfo,
handler grpc.StreamHandler,
) error {
reqID := uuid.New().String()
reqID := xid.New().String()
wrapped := grpcMiddleware.WrapServerStream(ss)
//nolint
ctx := context.WithValue(ss.Context(), hook.ExecutionContextKey, hook.GRPCSource)

View File

@@ -134,10 +134,6 @@ func (s *Server) GetServerKey(ctx context.Context, req *proto.Empty) (*proto.Ser
}
log.WithContext(ctx).Tracef("GetServerKey request from %s", ip)
start := time.Now()
defer func() {
log.WithContext(ctx).Tracef("GetServerKey from %s took %v", ip, time.Since(start))
}()
// todo introduce something more meaningful with the key expiration/rotation
if s.appMetrics != nil {
@@ -194,7 +190,7 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S
s.appMetrics.GRPCMetrics().CountSyncRequestBlocked()
}
if s.logBlockedPeers {
log.WithContext(ctx).Warnf("peer %s with meta hash %d is blocked from syncing", peerKey.String(), metahashed)
log.WithContext(ctx).Tracef("peer %s with meta hash %d is blocked from syncing", peerKey.String(), metahashed)
}
if s.blockPeersWithSameConfig {
s.syncSem.Add(-1)
@@ -222,8 +218,6 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S
return err
}
log.WithContext(ctx).Debugf("Sync: GetAccountIDForPeerKey since start %v", time.Since(reqStart))
// nolint:staticcheck
ctx = context.WithValue(ctx, nbContext.AccountIDKey, accountID)
@@ -235,7 +229,6 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S
}
}()
log.WithContext(ctx).Tracef("acquired peer lock for peer %s took %v", peerKey.String(), time.Since(start))
log.WithContext(ctx).Debugf("Sync: acquirePeerLockByUID since start %v", time.Since(reqStart))
log.WithContext(ctx).Debugf("Sync request from peer [%s] [%s]", req.WgPubKey, sRealIP)
@@ -352,7 +345,7 @@ func (s *Server) cancelPeerRoutines(ctx context.Context, accountID string, peer
s.networkMapController.OnPeerDisconnected(ctx, accountID, peer.ID)
s.secretsManager.CancelRefresh(peer.ID)
log.WithContext(ctx).Tracef("peer %s has been disconnected", peer.Key)
log.WithContext(ctx).Debugf("peer %s has been disconnected", peer.Key)
}
func (s *Server) validateToken(ctx context.Context, jwtToken string) (string, error) {
@@ -525,7 +518,6 @@ func (s *Server) Login(ctx context.Context, req *proto.EncryptedMessage) (*proto
reqStart := time.Now()
realIP := getRealIP(ctx)
sRealIP := realIP.String()
log.WithContext(ctx).Debugf("Login request from peer [%s] [%s]", req.WgPubKey, sRealIP)
loginReq := &proto.LoginRequest{}
peerKey, err := s.parseRequest(ctx, req, loginReq)
@@ -537,7 +529,7 @@ func (s *Server) Login(ctx context.Context, req *proto.EncryptedMessage) (*proto
metahashed := metaHash(peerMeta, sRealIP)
if !s.loginFilter.allowLogin(peerKey.String(), metahashed) {
if s.logBlockedPeers {
log.WithContext(ctx).Warnf("peer %s with meta hash %d is blocked from login", peerKey.String(), metahashed)
log.WithContext(ctx).Tracef("peer %s with meta hash %d is blocked from login", peerKey.String(), metahashed)
}
if s.appMetrics != nil {
s.appMetrics.GRPCMetrics().CountLoginRequestBlocked()
@@ -561,16 +553,12 @@ func (s *Server) Login(ctx context.Context, req *proto.EncryptedMessage) (*proto
//nolint
ctx = context.WithValue(ctx, nbContext.AccountIDKey, accountID)
log.WithContext(ctx).Debugf("Login: GetAccountIDForPeerKey since start %v", time.Since(reqStart))
log.WithContext(ctx).Debugf("Login request from peer [%s] [%s]", req.WgPubKey, sRealIP)
defer func() {
if s.appMetrics != nil {
s.appMetrics.GRPCMetrics().CountLoginRequestDuration(time.Since(reqStart), accountID)
}
took := time.Since(reqStart)
if took > 7*time.Second {
log.WithContext(ctx).Debugf("Login: took %v", time.Since(reqStart))
}
}()
if loginReq.GetMeta() == nil {
@@ -604,16 +592,12 @@ func (s *Server) Login(ctx context.Context, req *proto.EncryptedMessage) (*proto
return nil, mapError(ctx, err)
}
log.WithContext(ctx).Debugf("Login: LoginPeer since start %v", time.Since(reqStart))
loginResp, err := s.prepareLoginResponse(ctx, peer, netMap, postureChecks)
if err != nil {
log.WithContext(ctx).Warnf("failed preparing login response for peer %s: %s", peerKey, err)
return nil, status.Errorf(codes.Internal, "failed logging in peer")
}
log.WithContext(ctx).Debugf("Login: prepareLoginResponse since start %v", time.Since(reqStart))
key, err := s.secretsManager.GetWGKey()
if err != nil {
log.WithContext(ctx).Warnf("failed getting server's WireGuard private key: %s", err)
@@ -730,12 +714,10 @@ func (s *Server) sendInitialSync(ctx context.Context, peerKey wgtypes.Key, peer
return status.Errorf(codes.Internal, "error handling request")
}
sendStart := time.Now()
err = srv.Send(&proto.EncryptedMessage{
WgPubKey: key.PublicKey().String(),
Body: encryptedResp,
})
log.WithContext(ctx).Debugf("sendInitialSync: sending response took %s", time.Since(sendStart))
if err != nil {
log.WithContext(ctx).Errorf("failed sending SyncResponse %v", err)
@@ -750,10 +732,6 @@ func (s *Server) sendInitialSync(ctx context.Context, peerKey wgtypes.Key, peer
// which will be used by our clients to Login
func (s *Server) GetDeviceAuthorizationFlow(ctx context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error) {
log.WithContext(ctx).Tracef("GetDeviceAuthorizationFlow request for pubKey: %s", req.WgPubKey)
start := time.Now()
defer func() {
log.WithContext(ctx).Tracef("GetDeviceAuthorizationFlow for pubKey: %s took %v", req.WgPubKey, time.Since(start))
}()
peerKey, err := wgtypes.ParseKey(req.GetWgPubKey())
if err != nil {
@@ -813,10 +791,6 @@ func (s *Server) GetDeviceAuthorizationFlow(ctx context.Context, req *proto.Encr
// which will be used by our clients to Login
func (s *Server) GetPKCEAuthorizationFlow(ctx context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error) {
log.WithContext(ctx).Tracef("GetPKCEAuthorizationFlow request for pubKey: %s", req.WgPubKey)
start := time.Now()
defer func() {
log.WithContext(ctx).Tracef("GetPKCEAuthorizationFlow for pubKey %s took %v", req.WgPubKey, time.Since(start))
}()
peerKey, err := wgtypes.ParseKey(req.GetWgPubKey())
if err != nil {

View File

@@ -167,7 +167,7 @@ func (m *TimeBasedAuthSecretsManager) SetupRefresh(ctx context.Context, accountI
relayCancel := make(chan struct{}, 1)
m.relayCancelMap[peerID] = relayCancel
go m.refreshRelayTokens(ctx, accountID, peerID, relayCancel)
log.WithContext(ctx).Debugf("starting relay refresh for %s", peerID)
log.WithContext(ctx).Tracef("starting relay refresh for %s", peerID)
}
}
@@ -178,7 +178,7 @@ func (m *TimeBasedAuthSecretsManager) refreshTURNTokens(ctx context.Context, acc
for {
select {
case <-cancel:
log.WithContext(ctx).Debugf("stopping TURN refresh for %s", peerID)
log.WithContext(ctx).Tracef("stopping TURN refresh for %s", peerID)
return
case <-ticker.C:
m.pushNewTURNAndRelayTokens(ctx, accountID, peerID)
@@ -193,7 +193,7 @@ func (m *TimeBasedAuthSecretsManager) refreshRelayTokens(ctx context.Context, ac
for {
select {
case <-cancel:
log.WithContext(ctx).Debugf("stopping relay refresh for %s", peerID)
log.WithContext(ctx).Tracef("stopping relay refresh for %s", peerID)
return
case <-ticker.C:
m.pushNewRelayTokens(ctx, accountID, peerID)

View File

@@ -295,10 +295,23 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
return err
}
if err = am.validateSettingsUpdate(ctx, transaction, newSettings, oldSettings, userID, accountID); err != nil {
if err = am.validateSettingsUpdate(ctx, newSettings, oldSettings, userID, accountID); err != nil {
return err
}
if oldSettings.Extra != nil && newSettings.Extra != nil &&
oldSettings.Extra.PeerApprovalEnabled && !newSettings.Extra.PeerApprovalEnabled {
approvedCount, err := transaction.ApproveAccountPeers(ctx, accountID)
if err != nil {
return fmt.Errorf("failed to approve pending peers: %w", err)
}
if approvedCount > 0 {
log.WithContext(ctx).Debugf("approved %d pending peers in account %s", approvedCount, accountID)
updateAccountPeers = true
}
}
if oldSettings.NetworkRange != newSettings.NetworkRange {
if err = am.reallocateAccountPeerIPs(ctx, transaction, accountID, newSettings.NetworkRange); err != nil {
return err
@@ -372,7 +385,7 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
return newSettings, nil
}
func (am *DefaultAccountManager) validateSettingsUpdate(ctx context.Context, transaction store.Store, newSettings, oldSettings *types.Settings, userID, accountID string) error {
func (am *DefaultAccountManager) validateSettingsUpdate(ctx context.Context, newSettings, oldSettings *types.Settings, userID, accountID string) error {
halfYearLimit := 180 * 24 * time.Hour
if newSettings.PeerLoginExpiration > halfYearLimit {
return status.Errorf(status.InvalidArgument, "peer login expiration can't be larger than 180 days")
@@ -386,17 +399,7 @@ func (am *DefaultAccountManager) validateSettingsUpdate(ctx context.Context, tra
return status.Errorf(status.InvalidArgument, "invalid domain \"%s\" provided for DNS domain", newSettings.DNSDomain)
}
peers, err := transaction.GetAccountPeers(ctx, store.LockingStrengthNone, accountID, "", "")
if err != nil {
return err
}
peersMap := make(map[string]*nbpeer.Peer, len(peers))
for _, peer := range peers {
peersMap[peer.ID] = peer
}
return am.integratedPeerValidator.ValidateExtraSettings(ctx, newSettings.Extra, oldSettings.Extra, peersMap, userID, accountID)
return am.integratedPeerValidator.ValidateExtraSettings(ctx, newSettings.Extra, oldSettings.Extra, userID, accountID)
}
func (am *DefaultAccountManager) handleRoutingPeerDNSResolutionSettings(ctx context.Context, oldSettings, newSettings *types.Settings, userID, accountID string) {
@@ -787,6 +790,13 @@ func (am *DefaultAccountManager) loadAccount(ctx context.Context, accountID any)
log.WithContext(ctx).Debugf("account %s not found in cache, reloading", accountID)
accountIDString := fmt.Sprintf("%v", accountID)
if ctx == nil {
ctx = context.Background()
}
// nolint:staticcheck
ctx = context.WithValue(ctx, nbcontext.AccountIDKey, accountID)
accountUsers, err := am.Store.GetAccountUsers(ctx, store.LockingStrengthNone, accountIDString)
if err != nil {
return nil, nil, err

View File

@@ -7,5 +7,6 @@ import (
)
type RequestBuffer interface {
GetAccountWithBackpressure(ctx context.Context, accountID string) (*types.Account, error)
// GetAccountLightWithBackpressure returns account without users, setup keys, and onboarding data with request buffering
GetAccountLightWithBackpressure(ctx context.Context, accountID string) (*types.Account, error)
}

View File

@@ -25,11 +25,13 @@ type AccountResult struct {
}
type AccountRequestBuffer struct {
store store.Store
getAccountRequests map[string][]*AccountRequest
mu sync.Mutex
getAccountRequestCh chan *AccountRequest
bufferInterval time.Duration
store store.Store
getAccountRequests map[string][]*AccountRequest
getAccountLightRequests map[string][]*AccountRequest
mu sync.Mutex
getAccountRequestCh chan *AccountRequest
getAccountLightRequestCh chan *AccountRequest
bufferInterval time.Duration
}
func NewAccountRequestBuffer(ctx context.Context, store store.Store) *AccountRequestBuffer {
@@ -45,13 +47,16 @@ func NewAccountRequestBuffer(ctx context.Context, store store.Store) *AccountReq
log.WithContext(ctx).Infof("set account request buffer interval to %s", bufferInterval)
ac := AccountRequestBuffer{
store: store,
getAccountRequests: make(map[string][]*AccountRequest),
getAccountRequestCh: make(chan *AccountRequest),
bufferInterval: bufferInterval,
store: store,
getAccountRequests: make(map[string][]*AccountRequest),
getAccountLightRequests: make(map[string][]*AccountRequest),
getAccountRequestCh: make(chan *AccountRequest),
getAccountLightRequestCh: make(chan *AccountRequest),
bufferInterval: bufferInterval,
}
go ac.processGetAccountRequests(ctx)
go ac.processGetAccountLightRequests(ctx)
return &ac
}
@@ -70,6 +75,22 @@ func (ac *AccountRequestBuffer) GetAccountWithBackpressure(ctx context.Context,
return result.Account, result.Err
}
// GetAccountLightWithBackpressure returns account without users, setup keys, and onboarding data with request buffering.
//
// The request is handed to the light-account request channel and the call
// blocks until a result is delivered on the request's result channel.
func (ac *AccountRequestBuffer) GetAccountLightWithBackpressure(ctx context.Context, accountID string) (*types.Account, error) {
	// Buffered with capacity 1 so the result can be delivered without the
	// sender waiting for this goroutine to be ready to receive.
	resultCh := make(chan *AccountResult, 1)

	log.WithContext(ctx).Tracef("requesting account light %s with backpressure", accountID)
	begin := time.Now()

	ac.getAccountLightRequestCh <- &AccountRequest{
		AccountID:  accountID,
		ResultChan: resultCh,
	}
	res := <-resultCh

	log.WithContext(ctx).Tracef("got account light with backpressure after %s", time.Since(begin))
	return res.Account, res.Err
}
func (ac *AccountRequestBuffer) processGetAccountBatch(ctx context.Context, accountID string) {
ac.mu.Lock()
requests := ac.getAccountRequests[accountID]
@@ -109,3 +130,43 @@ func (ac *AccountRequestBuffer) processGetAccountRequests(ctx context.Context) {
}
}
}
// processGetAccountLightBatch takes every queued light-account request for
// accountID, performs a single store.GetAccountLight call, and delivers the
// same result to each waiting request before closing its result channel.
func (ac *AccountRequestBuffer) processGetAccountLightBatch(ctx context.Context, accountID string) {
	// Detach the pending queue under the lock so new requests start a fresh batch.
	ac.mu.Lock()
	pending := ac.getAccountLightRequests[accountID]
	delete(ac.getAccountLightRequests, accountID)
	ac.mu.Unlock()

	if len(pending) < 1 {
		return
	}

	fetchStart := time.Now()
	account, err := ac.store.GetAccountLight(ctx, accountID)
	log.WithContext(ctx).Tracef("getting account light %s in batch took %s", accountID, time.Since(fetchStart))

	// All waiters share one result value.
	shared := &AccountResult{Account: account, Err: err}
	for i := range pending {
		ch := pending[i].ResultChan
		ch <- shared
		close(ch)
	}
}
// processGetAccountLightRequests is the receive loop for light-account requests.
//
// Each incoming request is appended to the per-account queue. The first request
// queued for an account starts a goroutine that sleeps for bufferInterval and
// then processes the whole batch for that account. The loop exits when ctx is done.
func (ac *AccountRequestBuffer) processGetAccountLightRequests(ctx context.Context) {
	for {
		select {
		case <-ctx.Done():
			return
		case req := <-ac.getAccountLightRequestCh:
			id := req.AccountID

			ac.mu.Lock()
			queue := append(ac.getAccountLightRequests[id], req)
			ac.getAccountLightRequests[id] = queue
			// Only the request that opens a new queue schedules the batch run.
			if len(queue) == 1 {
				go func() {
					time.Sleep(ac.bufferInterval)
					ac.processGetAccountLightBatch(ctx, id)
				}()
			}
			ac.mu.Unlock()
		}
	}
}

View File

@@ -2058,6 +2058,43 @@ func TestDefaultAccountManager_UpdateAccountSettings(t *testing.T) {
require.Error(t, err, "expecting to fail when providing PeerLoginExpiration more than 180 days")
}
// TestDefaultAccountManager_UpdateAccountSettings_PeerApproval checks that
// switching PeerApprovalEnabled from true to false through UpdateAccountSettings
// leaves no peer in the account with Status.RequiresApproval set.
func TestDefaultAccountManager_UpdateAccountSettings_PeerApproval(t *testing.T) {
	manager, _, account, peer1, peer2, peer3 := setupNetworkMapTest(t)

	ctx := context.Background()
	accountID, userID := account.Id, account.Users[account.CreatedBy].Id

	// Phase 1: enable peer approval.
	withApproval := account.Settings.Copy()
	withApproval.Extra = &types.ExtraSettings{PeerApprovalEnabled: true}

	_, err := manager.UpdateAccountSettings(ctx, accountID, userID, withApproval)
	require.NoError(t, err)

	// Phase 2: store two peers as pending approval and one as already approved.
	peer1.Status.RequiresApproval = true
	peer2.Status.RequiresApproval = true
	peer3.Status.RequiresApproval = false

	require.NoError(t, manager.Store.SavePeer(ctx, accountID, peer1))
	require.NoError(t, manager.Store.SavePeer(ctx, accountID, peer2))
	require.NoError(t, manager.Store.SavePeer(ctx, accountID, peer3))

	// Phase 3: disable peer approval again.
	withoutApproval := account.Settings.Copy()
	withoutApproval.Extra = &types.ExtraSettings{PeerApprovalEnabled: false}

	_, err = manager.UpdateAccountSettings(ctx, accountID, userID, withoutApproval)
	require.NoError(t, err)

	// Phase 4: every stored peer must now be approved.
	accountPeers, err := manager.Store.GetAccountPeers(ctx, store.LockingStrengthNone, accountID, "", "")
	require.NoError(t, err)

	for i := range accountPeers {
		p := accountPeers[i]
		assert.False(t, p.Status.RequiresApproval, "peer %s should not require approval after disabling peer approval", p.ID)
	}
}
func TestAccount_GetExpiredPeers(t *testing.T) {
type test struct {
name string

View File

@@ -141,7 +141,7 @@ func (m *AuthMiddleware) checkJWTFromRequest(r *http.Request, authHeaderParts []
}
if userAuth.AccountId != accountId {
log.WithContext(ctx).Debugf("Auth middleware sets accountId from ensure, before %s, now %s", userAuth.AccountId, accountId)
log.WithContext(ctx).Tracef("Auth middleware sets accountId from ensure, before %s, now %s", userAuth.AccountId, accountId)
userAuth.AccountId = accountId
}

View File

@@ -127,7 +127,7 @@ type MockIntegratedValidator struct {
ValidatePeerFunc func(_ context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *types.ExtraSettings) (*nbpeer.Peer, bool, error)
}
func (a MockIntegratedValidator) ValidateExtraSettings(_ context.Context, newExtraSettings *types.ExtraSettings, oldExtraSettings *types.ExtraSettings, peers map[string]*nbpeer.Peer, userID string, accountID string) error {
func (a MockIntegratedValidator) ValidateExtraSettings(_ context.Context, newExtraSettings *types.ExtraSettings, oldExtraSettings *types.ExtraSettings, userID string, accountID string) error {
return nil
}

View File

@@ -10,7 +10,7 @@ import (
// IntegratedValidator interface exists to avoid the circle dependencies
type IntegratedValidator interface {
ValidateExtraSettings(ctx context.Context, newExtraSettings *types.ExtraSettings, oldExtraSettings *types.ExtraSettings, peers map[string]*nbpeer.Peer, userID string, accountID string) error
ValidateExtraSettings(ctx context.Context, newExtraSettings *types.ExtraSettings, oldExtraSettings *types.ExtraSettings, userID string, accountID string) error
ValidatePeer(ctx context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *types.ExtraSettings) (*nbpeer.Peer, bool, error)
PreparePeer(ctx context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *types.ExtraSettings, temporary bool) *nbpeer.Peer
IsNotValidPeer(ctx context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *types.ExtraSettings) (bool, bool, error)

View File

@@ -172,7 +172,7 @@ func updatePeerStatusAndLocation(ctx context.Context, geo geolocation.Geolocatio
}
}
log.WithContext(ctx).Tracef("saving peer status for peer %s is connected: %t", peer.ID, connected)
log.WithContext(ctx).Debugf("saving peer status for peer %s is connected: %t", peer.ID, connected)
err := transaction.SavePeerStatus(ctx, accountID, peer.ID, *newStatus)
if err != nil {
@@ -783,7 +783,6 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login types.Peer
return nil, nil, nil, err
}
startTransaction := time.Now()
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
peer, err = transaction.GetPeerByPeerPubKey(ctx, store.LockingStrengthUpdate, login.WireGuardPubKey)
if err != nil {
@@ -853,8 +852,6 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login types.Peer
return nil, nil, nil, err
}
log.WithContext(ctx).Debugf("LoginPeer: transaction took %v", time.Since(startTransaction))
if updateRemotePeers || isStatusChanged || (isPeerUpdated && len(postureChecks) > 0) {
err = am.networkMapController.OnPeersUpdated(ctx, accountID, []string{peer.ID})
if err != nil {

View File

@@ -82,7 +82,7 @@ func (c *OSVersionCheck) Validate() error {
func checkMinVersion(ctx context.Context, peerGoOS, peerVersion string, check *MinVersionCheck) (bool, error) {
if check == nil {
log.WithContext(ctx).Debugf("peer %s OS is not allowed in the check", peerGoOS)
log.WithContext(ctx).Tracef("peer %s OS is not allowed in the check", peerGoOS)
return false, nil
}
@@ -107,7 +107,7 @@ func checkMinVersion(ctx context.Context, peerGoOS, peerVersion string, check *M
func checkMinKernelVersion(ctx context.Context, peerGoOS, peerVersion string, check *MinKernelVersionCheck) (bool, error) {
if check == nil {
log.WithContext(ctx).Debugf("peer %s OS is not allowed in the check", peerGoOS)
log.WithContext(ctx).Tracef("peer %s OS is not allowed in the check", peerGoOS)
return false, nil
}

View File

@@ -5,9 +5,6 @@ package settings
import (
"context"
"fmt"
"time"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/integrations/extra_settings"
@@ -48,11 +45,6 @@ func (m *managerImpl) GetExtraSettingsManager() extra_settings.Manager {
}
func (m *managerImpl) GetSettings(ctx context.Context, accountID, userID string) (*types.Settings, error) {
start := time.Now()
defer func() {
log.WithContext(ctx).Debugf("GetSettings took %s", time.Since(start))
}()
if userID != activity.SystemInitiator {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Settings, operations.Read)
if err != nil {

View File

@@ -27,7 +27,6 @@ import (
"gorm.io/gorm/logger"
nbdns "github.com/netbirdio/netbird/dns"
nbcontext "github.com/netbirdio/netbird/management/server/context"
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
@@ -288,7 +287,7 @@ func (s *SqlStore) DeleteAccount(ctx context.Context, account *types.Account) er
if s.metrics != nil {
s.metrics.StoreMetrics().CountPersistenceDuration(took)
}
log.WithContext(ctx).Debugf("took %d ms to delete an account to the store", took.Milliseconds())
log.WithContext(ctx).Tracef("took %d ms to delete an account to the store", took.Milliseconds())
return err
}
@@ -413,6 +412,18 @@ func (s *SqlStore) SavePeerLocation(ctx context.Context, accountID string, peerW
return nil
}
// ApproveAccountPeers clears the requires-approval flag on every peer of the
// given account that currently has it set.
//
// It returns the number of rows the update affected, or an Internal status
// error when the update fails.
func (s *SqlStore) ApproveAccountPeers(ctx context.Context, accountID string) (int, error) {
	// Scope the update to peers of this account that are still pending approval.
	pendingPeers := s.db.Model(&nbpeer.Peer{}).
		Where("account_id = ? AND peer_status_requires_approval = ?", accountID, true)

	res := pendingPeers.Update("peer_status_requires_approval", false)
	if err := res.Error; err != nil {
		return 0, status.Errorf(status.Internal, "failed to approve pending account peers: %v", err)
	}

	return int(res.RowsAffected), nil
}
// SaveUsers saves the given list of users to the database.
func (s *SqlStore) SaveUsers(ctx context.Context, users []*types.User) error {
if len(users) == 0 {
@@ -583,16 +594,13 @@ func (s *SqlStore) GetUserByPATID(ctx context.Context, lockStrength LockingStren
}
func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*types.User, error) {
ctx, cancel := getDebuggingCtx(ctx)
defer cancel()
tx := s.db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var user types.User
result := tx.WithContext(ctx).Take(&user, idQueryCondition, userID)
result := tx.Take(&user, idQueryCondition, userID)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.NewUserNotFoundError(userID)
@@ -788,6 +796,14 @@ func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*types.Acc
return s.getAccountGorm(ctx, accountID)
}
// GetAccountLight returns the account without users, setup keys, and onboarding data.
// When a connection pool is configured the pgx-based loader is used; otherwise
// the account is loaded through gorm.
func (s *SqlStore) GetAccountLight(ctx context.Context, accountID string) (*types.Account, error) {
	if s.pool == nil {
		return s.getAccountLightGorm(ctx, accountID)
	}
	return s.getAccountLightPgx(ctx, accountID)
}
func (s *SqlStore) getAccountGorm(ctx context.Context, accountID string) (*types.Account, error) {
start := time.Now()
defer func() {
@@ -889,6 +905,82 @@ func (s *SqlStore) getAccountGorm(ctx context.Context, accountID string) (*types
return &account, nil
}
// getAccountLightGorm loads an account through gorm, preloading policies (with
// their rules), peers, groups (with their group peers), routes, name server
// groups, posture checks, networks, network routers and network resources.
// Users and setup keys are not loaded; the corresponding maps are returned
// empty.
//
// It returns an account-not-found error when no row matches accountID and a
// generic "get account from store" error for any other query failure.
func (s *SqlStore) getAccountLightGorm(ctx context.Context, accountID string) (*types.Account, error) {
	start := time.Now()
	defer func() {
		elapsed := time.Since(start)
		if elapsed > 1*time.Second {
			log.WithContext(ctx).Tracef("GetAccountLight for account %s exceeded 1s, took: %v", accountID, elapsed)
		}
	}()

	var account types.Account
	result := s.db.Model(&account).
		Preload("Policies.Rules").
		Preload("PeersG").
		Preload("GroupsG.GroupPeers").
		Preload("RoutesG").
		Preload("NameServerGroupsG").
		Preload("PostureChecks").
		Preload("Networks").
		Preload("NetworkRouters").
		Preload("NetworkResources").
		Take(&account, idQueryCondition, accountID)
	if result.Error != nil {
		log.WithContext(ctx).Errorf("error when getting account %s from the store: %s", accountID, result.Error)
		if errors.Is(result.Error, gorm.ErrRecordNotFound) {
			return nil, status.NewAccountNotFoundError(accountID)
		}
		return nil, status.NewGetAccountFromStoreError(result.Error)
	}

	// The light variant intentionally exposes empty user and setup key maps.
	account.SetupKeys = make(map[string]*types.SetupKey)
	account.Users = make(map[string]*types.User)

	// Index into the slices instead of taking the address of the range
	// variable: with pre-Go-1.22 loop semantics every map entry would alias
	// the same variable, and ranging by value copies each struct needlessly.
	account.Peers = make(map[string]*nbpeer.Peer, len(account.PeersG))
	for i := range account.PeersG {
		peer := &account.PeersG[i]
		account.Peers[peer.ID] = peer
	}
	account.PeersG = nil

	account.Groups = make(map[string]*types.Group, len(account.GroupsG))
	for _, group := range account.GroupsG {
		// Flatten the preloaded join rows into the plain peer ID list.
		group.Peers = make([]string, len(group.GroupPeers))
		for i, gp := range group.GroupPeers {
			group.Peers[i] = gp.PeerID
		}
		if group.Resources == nil {
			group.Resources = []types.Resource{}
		}
		account.Groups[group.ID] = group
	}
	account.GroupsG = nil

	account.Routes = make(map[route.ID]*route.Route, len(account.RoutesG))
	for i := range account.RoutesG {
		r := &account.RoutesG[i]
		account.Routes[r.ID] = r
	}
	account.RoutesG = nil

	account.NameServerGroups = make(map[string]*nbdns.NameServerGroup, len(account.NameServerGroupsG))
	for i := range account.NameServerGroupsG {
		ns := &account.NameServerGroupsG[i]
		ns.AccountID = ""
		// Normalize nil slices to empty ones.
		if ns.NameServers == nil {
			ns.NameServers = []nbdns.NameServer{}
		}
		if ns.Groups == nil {
			ns.Groups = []string{}
		}
		if ns.Domains == nil {
			ns.Domains = []string{}
		}
		account.NameServerGroups[ns.ID] = ns
	}
	account.NameServerGroupsG = nil

	account.InitOnce()
	return &account, nil
}
func (s *SqlStore) getAccountPgx(ctx context.Context, accountID string) (*types.Account, error) {
account, err := s.getAccount(ctx, accountID)
if err != nil {
@@ -1172,6 +1264,221 @@ func (s *SqlStore) getAccountPgx(ctx context.Context, accountID string) (*types.
return account, nil
}
// getAccountLightPgx assembles a light account from the pgx-backed helper
// queries. After loading the base account it fetches peers, groups, policies,
// routes, name server groups, posture checks, networks, network routers and
// network resources concurrently, then fetches policy rules and group peers
// concurrently, and finally indexes the results into the account maps.
// Users and setup keys are not loaded; their maps are returned empty.
//
// If any loader fails, one of the reported errors is returned and the account
// is discarded.
func (s *SqlStore) getAccountLightPgx(ctx context.Context, accountID string) (*types.Account, error) {
	account, err := s.getAccount(ctx, accountID)
	if err != nil {
		return nil, err
	}

	// runAll executes the given loaders concurrently, waits for all of them to
	// finish and returns one of the errors they produced (nil if none failed).
	runAll := func(loaders ...func() error) error {
		var wg sync.WaitGroup
		failures := make(chan error, len(loaders))
		for _, load := range loaders {
			wg.Add(1)
			go func(load func() error) {
				defer wg.Done()
				if loadErr := load(); loadErr != nil {
					failures <- loadErr
				}
			}(load)
		}
		wg.Wait()
		close(failures)
		for failure := range failures {
			if failure != nil {
				return failure
			}
		}
		return nil
	}

	// Stage 1: every loader writes to a distinct account field.
	err = runAll(
		func() error {
			peers, loadErr := s.getPeers(ctx, accountID)
			if loadErr != nil {
				return loadErr
			}
			account.PeersG = peers
			return nil
		},
		func() error {
			groups, loadErr := s.getGroups(ctx, accountID)
			if loadErr != nil {
				return loadErr
			}
			account.GroupsG = groups
			return nil
		},
		func() error {
			policies, loadErr := s.getPolicies(ctx, accountID)
			if loadErr != nil {
				return loadErr
			}
			account.Policies = policies
			return nil
		},
		func() error {
			routes, loadErr := s.getRoutes(ctx, accountID)
			if loadErr != nil {
				return loadErr
			}
			account.RoutesG = routes
			return nil
		},
		func() error {
			nsgs, loadErr := s.getNameServerGroups(ctx, accountID)
			if loadErr != nil {
				return loadErr
			}
			account.NameServerGroupsG = nsgs
			return nil
		},
		func() error {
			checks, loadErr := s.getPostureChecks(ctx, accountID)
			if loadErr != nil {
				return loadErr
			}
			account.PostureChecks = checks
			return nil
		},
		func() error {
			networks, loadErr := s.getNetworks(ctx, accountID)
			if loadErr != nil {
				return loadErr
			}
			account.Networks = networks
			return nil
		},
		func() error {
			routers, loadErr := s.getNetworkRouters(ctx, accountID)
			if loadErr != nil {
				return loadErr
			}
			account.NetworkRouters = routers
			return nil
		},
		func() error {
			resources, loadErr := s.getNetworkResources(ctx, accountID)
			if loadErr != nil {
				return loadErr
			}
			account.NetworkResources = resources
			return nil
		},
	)
	if err != nil {
		return nil, err
	}

	// Stage 2 needs the IDs gathered in stage 1.
	var policyIDs []string
	for _, policy := range account.Policies {
		policyIDs = append(policyIDs, policy.ID)
	}
	var groupIDs []string
	for _, group := range account.GroupsG {
		groupIDs = append(groupIDs, group.ID)
	}

	var (
		rules      []*types.PolicyRule
		groupPeers []types.GroupPeer
	)
	err = runAll(
		func() error {
			var loadErr error
			rules, loadErr = s.getPolicyRules(ctx, policyIDs)
			return loadErr
		},
		func() error {
			var loadErr error
			groupPeers, loadErr = s.getGroupPeers(ctx, groupIDs)
			return loadErr
		},
	)
	if err != nil {
		return nil, err
	}

	// Bucket the child rows by their parent ID.
	rulesByPolicyID := make(map[string][]*types.PolicyRule)
	for _, rule := range rules {
		rulesByPolicyID[rule.PolicyID] = append(rulesByPolicyID[rule.PolicyID], rule)
	}
	peersByGroupID := make(map[string][]string)
	for _, gp := range groupPeers {
		peersByGroupID[gp.GroupID] = append(peersByGroupID[gp.GroupID], gp.PeerID)
	}

	// The light variant intentionally exposes empty user and setup key maps.
	account.SetupKeys = make(map[string]*types.SetupKey)
	account.Users = make(map[string]*types.User)

	account.Peers = make(map[string]*nbpeer.Peer, len(account.PeersG))
	for i := range account.PeersG {
		account.Peers[account.PeersG[i].ID] = &account.PeersG[i]
	}

	for _, policy := range account.Policies {
		if policyRules, found := rulesByPolicyID[policy.ID]; found {
			policy.Rules = policyRules
		}
	}

	account.Groups = make(map[string]*types.Group, len(account.GroupsG))
	for _, group := range account.GroupsG {
		if peerIDs, found := peersByGroupID[group.ID]; found {
			group.Peers = peerIDs
		}
		account.Groups[group.ID] = group
	}

	account.Routes = make(map[route.ID]*route.Route, len(account.RoutesG))
	for i := range account.RoutesG {
		account.Routes[account.RoutesG[i].ID] = &account.RoutesG[i]
	}

	account.NameServerGroups = make(map[string]*nbdns.NameServerGroup, len(account.NameServerGroupsG))
	for i := range account.NameServerGroupsG {
		account.NameServerGroupsG[i].AccountID = ""
		account.NameServerGroups[account.NameServerGroupsG[i].ID] = &account.NameServerGroupsG[i]
	}

	// Drop the flat slices now that everything is indexed in the maps.
	account.SetupKeysG = nil
	account.PeersG = nil
	account.UsersG = nil
	account.GroupsG = nil
	account.RoutesG = nil
	account.NameServerGroupsG = nil

	return account, nil
}
func (s *SqlStore) getAccount(ctx context.Context, accountID string) (*types.Account, error) {
var account types.Account
account.Network = &types.Network{}
@@ -2152,16 +2459,13 @@ func (s *SqlStore) GetPeerLabelsInAccount(ctx context.Context, lockStrength Lock
}
func (s *SqlStore) GetAccountNetwork(ctx context.Context, lockStrength LockingStrength, accountID string) (*types.Network, error) {
ctx, cancel := getDebuggingCtx(ctx)
defer cancel()
tx := s.db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var accountNetwork types.AccountNetwork
if err := tx.WithContext(ctx).Model(&types.Account{}).Where(idQueryCondition, accountID).Take(&accountNetwork).Error; err != nil {
if err := tx.Model(&types.Account{}).Where(idQueryCondition, accountID).Take(&accountNetwork).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, status.NewAccountNotFoundError(accountID)
}
@@ -2171,16 +2475,13 @@ func (s *SqlStore) GetAccountNetwork(ctx context.Context, lockStrength LockingSt
}
func (s *SqlStore) GetPeerByPeerPubKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (*nbpeer.Peer, error) {
ctx, cancel := getDebuggingCtx(ctx)
defer cancel()
tx := s.db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var peer nbpeer.Peer
result := tx.WithContext(ctx).Take(&peer, GetKeyQueryCondition(s), peerKey)
result := tx.Take(&peer, GetKeyQueryCondition(s), peerKey)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
@@ -2229,11 +2530,8 @@ func (s *SqlStore) GetAccountCreatedBy(ctx context.Context, lockStrength Locking
// SaveUserLastLogin stores the last login time for a user in DB.
func (s *SqlStore) SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error {
ctx, cancel := getDebuggingCtx(ctx)
defer cancel()
var user types.User
result := s.db.WithContext(ctx).Take(&user, accountAndIDQueryCondition, accountID, userID)
result := s.db.Take(&user, accountAndIDQueryCondition, accountID, userID)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return status.NewUserNotFoundError(userID)
@@ -2491,16 +2789,13 @@ func NewMysqlStoreFromSqlStore(ctx context.Context, sqliteStore *SqlStore, dsn s
}
func (s *SqlStore) GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*types.SetupKey, error) {
ctx, cancel := getDebuggingCtx(ctx)
defer cancel()
tx := s.db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var setupKey types.SetupKey
result := tx.WithContext(ctx).
result := tx.
Take(&setupKey, GetKeyQueryCondition(s), key)
if result.Error != nil {
@@ -2514,10 +2809,7 @@ func (s *SqlStore) GetSetupKeyBySecret(ctx context.Context, lockStrength Locking
}
func (s *SqlStore) IncrementSetupKeyUsage(ctx context.Context, setupKeyID string) error {
ctx, cancel := getDebuggingCtx(ctx)
defer cancel()
result := s.db.WithContext(ctx).Model(&types.SetupKey{}).
result := s.db.Model(&types.SetupKey{}).
Where(idQueryCondition, setupKeyID).
Updates(map[string]interface{}{
"used_times": gorm.Expr("used_times + 1"),
@@ -2537,11 +2829,8 @@ func (s *SqlStore) IncrementSetupKeyUsage(ctx context.Context, setupKeyID string
// AddPeerToAllGroup adds a peer to the 'All' group. Method always needs to run in a transaction
func (s *SqlStore) AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error {
ctx, cancel := getDebuggingCtx(ctx)
defer cancel()
var groupID string
_ = s.db.WithContext(ctx).Model(types.Group{}).
_ = s.db.Model(types.Group{}).
Select("id").
Where("account_id = ? AND name = ?", accountID, "All").
Limit(1).
@@ -2569,9 +2858,6 @@ func (s *SqlStore) AddPeerToAllGroup(ctx context.Context, accountID string, peer
// AddPeerToGroup adds a peer to a group
func (s *SqlStore) AddPeerToGroup(ctx context.Context, accountID, peerID, groupID string) error {
ctx, cancel := getDebuggingCtx(ctx)
defer cancel()
peer := &types.GroupPeer{
AccountID: accountID,
GroupID: groupID,
@@ -2768,10 +3054,7 @@ func (s *SqlStore) GetUserPeers(ctx context.Context, lockStrength LockingStrengt
}
func (s *SqlStore) AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error {
ctx, cancel := getDebuggingCtx(ctx)
defer cancel()
if err := s.db.WithContext(ctx).Create(peer).Error; err != nil {
if err := s.db.Create(peer).Error; err != nil {
return status.Errorf(status.Internal, "issue adding peer to account: %s", err)
}
@@ -2897,10 +3180,7 @@ func (s *SqlStore) DeletePeer(ctx context.Context, accountID string, peerID stri
}
func (s *SqlStore) IncrementNetworkSerial(ctx context.Context, accountId string) error {
ctx, cancel := getDebuggingCtx(ctx)
defer cancel()
result := s.db.WithContext(ctx).Model(&types.Account{}).Where(idQueryCondition, accountId).Update("network_serial", gorm.Expr("network_serial + 1"))
result := s.db.Model(&types.Account{}).Where(idQueryCondition, accountId).Update("network_serial", gorm.Expr("network_serial + 1"))
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to increment network serial count in store: %v", result.Error)
return status.Errorf(status.Internal, "failed to increment network serial count in store")
@@ -4022,36 +4302,6 @@ func (s *SqlStore) GetAccountGroupPeers(ctx context.Context, lockStrength Lockin
return groupPeers, nil
}
// getDebuggingCtx builds a context that is detached from grpcCtx: it derives
// from context.Background() with a one-minute timeout and carries over only the
// user, request and account IDs from grpcCtx (each only when present as a
// string). A background goroutine logs a warning if grpcCtx finishes before the
// returned context does; it exits as soon as either context is done, so the
// caller should invoke the returned cancel function when finished.
func getDebuggingCtx(grpcCtx context.Context) (context.Context, context.CancelFunc) {
	detached, cancel := context.WithTimeout(context.Background(), time.Minute)

	// Copy the identifiers in a fixed order: user, request, account.
	carriedKeys := []interface{}{
		nbcontext.UserIDKey,
		nbcontext.RequestIDKey,
		nbcontext.AccountIDKey,
	}
	for _, key := range carriedKeys {
		value, isString := grpcCtx.Value(key).(string)
		if !isString {
			continue
		}
		//nolint
		detached = context.WithValue(detached, key, value)
	}

	go func() {
		select {
		case <-grpcCtx.Done():
			log.WithContext(grpcCtx).Warnf("grpc context ended early, error: %v", grpcCtx.Err())
		case <-detached.Done():
		}
	}()

	return detached, cancel
}
func (s *SqlStore) IsPrimaryAccount(ctx context.Context, accountID string) (bool, string, error) {
var info types.PrimaryAccountInfo
result := s.db.Model(&types.Account{}).
@@ -4091,7 +4341,7 @@ func (s *SqlStore) UpdateAccountNetwork(ctx context.Context, accountID string, i
Network: &types.Network{Net: ipNet},
}
result := s.db.WithContext(ctx).
result := s.db.
Model(&types.Account{}).
Where(idQueryCondition, accountID).
Updates(&patch)

View File

@@ -3717,3 +3717,80 @@ func TestSqlStore_GetPeersByGroupIDs(t *testing.T) {
})
}
}
// TestSqlStore_ApproveAccountPeers runs against every engine supplied by
// runTestForAllEngines. It seeds an account with two peers pending approval and
// one already approved, then checks that ApproveAccountPeers approves exactly
// the pending ones, reports zero on a second call, and reports zero without
// error for an unknown account.
func TestSqlStore_ApproveAccountPeers(t *testing.T) {
	runTestForAllEngines(t, "", func(t *testing.T, store Store) {
		ctx := context.Background()
		accountID := "test-account"

		require.NoError(t, store.SaveAccount(ctx, newAccountWithId(ctx, accountID, "testuser", "example.com")))

		fixtures := []struct {
			id               string
			ip               string
			requiresApproval bool
		}{
			{id: "peer1", ip: "100.64.0.1", requiresApproval: true},
			{id: "peer2", ip: "100.64.0.2", requiresApproval: true},
			{id: "peer3", ip: "100.64.0.3", requiresApproval: false},
		}

		for _, fixture := range fixtures {
			peer := &nbpeer.Peer{
				ID:        fixture.id,
				AccountID: accountID,
				DNSLabel:  fixture.id + ".netbird.cloud",
				Key:       fixture.id + "-key",
				IP:        net.ParseIP(fixture.ip),
				Status: &nbpeer.PeerStatus{
					RequiresApproval: fixture.requiresApproval,
					LastSeen:         time.Now().UTC(),
				},
			}
			require.NoError(t, store.AddPeerToAccount(ctx, peer))
		}

		t.Run("approve all pending peers", func(t *testing.T) {
			approved, err := store.ApproveAccountPeers(ctx, accountID)
			require.NoError(t, err)
			assert.Equal(t, 2, approved)

			stored, err := store.GetAccountPeers(ctx, LockingStrengthNone, accountID, "", "")
			require.NoError(t, err)
			for _, p := range stored {
				assert.False(t, p.Status.RequiresApproval, "peer %s should not require approval", p.ID)
			}
		})

		t.Run("no peers to approve", func(t *testing.T) {
			approved, err := store.ApproveAccountPeers(ctx, accountID)
			require.NoError(t, err)
			assert.Equal(t, 0, approved)
		})

		t.Run("non-existent account", func(t *testing.T) {
			approved, err := store.ApproveAccountPeers(ctx, "non-existent")
			require.NoError(t, err)
			assert.Equal(t, 0, approved)
		})
	})
}

View File

@@ -51,6 +51,8 @@ type Store interface {
GetAccountsCounter(ctx context.Context) (int64, error)
GetAllAccounts(ctx context.Context) []*types.Account
GetAccount(ctx context.Context, accountID string) (*types.Account, error)
// GetAccountLight returns account without users, setup keys, and onboarding data
GetAccountLight(ctx context.Context, accountID string) (*types.Account, error)
GetAccountMeta(ctx context.Context, lockStrength LockingStrength, accountID string) (*types.AccountMeta, error)
GetAccountOnboarding(ctx context.Context, accountID string) (*types.AccountOnboarding, error)
AccountExists(ctx context.Context, lockStrength LockingStrength, id string) (bool, error)
@@ -143,6 +145,7 @@ type Store interface {
SavePeer(ctx context.Context, accountID string, peer *nbpeer.Peer) error
SavePeerStatus(ctx context.Context, accountID, peerID string, status nbpeer.PeerStatus) error
SavePeerLocation(ctx context.Context, accountID string, peer *nbpeer.Peer) error
ApproveAccountPeers(ctx context.Context, accountID string) (int, error)
DeletePeer(ctx context.Context, accountID string, peerID string) error
GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*types.SetupKey, error)

View File

@@ -16,7 +16,6 @@ type GRPCMetrics struct {
meter metric.Meter
syncRequestsCounter metric.Int64Counter
syncRequestsBlockedCounter metric.Int64Counter
syncRequestHighLatencyCounter metric.Int64Counter
loginRequestsCounter metric.Int64Counter
loginRequestsBlockedCounter metric.Int64Counter
loginRequestHighLatencyCounter metric.Int64Counter
@@ -46,14 +45,6 @@ func NewGRPCMetrics(ctx context.Context, meter metric.Meter) (*GRPCMetrics, erro
return nil, err
}
syncRequestHighLatencyCounter, err := meter.Int64Counter("management.grpc.sync.request.high.latency.counter",
metric.WithUnit("1"),
metric.WithDescription("Number of sync gRPC requests from the peers that took longer than the threshold to establish a connection and receive network map updates (update channel)"),
)
if err != nil {
return nil, err
}
loginRequestsCounter, err := meter.Int64Counter("management.grpc.login.request.counter",
metric.WithUnit("1"),
metric.WithDescription("Number of login gRPC requests from the peers to authenticate and receive initial configuration and relay credentials"),
@@ -126,7 +117,6 @@ func NewGRPCMetrics(ctx context.Context, meter metric.Meter) (*GRPCMetrics, erro
meter: meter,
syncRequestsCounter: syncRequestsCounter,
syncRequestsBlockedCounter: syncRequestsBlockedCounter,
syncRequestHighLatencyCounter: syncRequestHighLatencyCounter,
loginRequestsCounter: loginRequestsCounter,
loginRequestsBlockedCounter: loginRequestsBlockedCounter,
loginRequestHighLatencyCounter: loginRequestHighLatencyCounter,
@@ -175,9 +165,6 @@ func (grpcMetrics *GRPCMetrics) CountLoginRequestDuration(duration time.Duration
// CountSyncRequestDuration records the duration of a sync gRPC request in
// milliseconds. When the duration exceeds HighLatencyThreshold it also
// increments the sync high-latency counter, labelled with the account ID.
func (grpcMetrics *GRPCMetrics) CountSyncRequestDuration(duration time.Duration, accountID string) {
	grpcMetrics.syncRequestDuration.Record(grpcMetrics.ctx, duration.Milliseconds())

	if duration <= HighLatencyThreshold {
		return
	}

	accountAttr := attribute.String(AccountIDLabel, accountID)
	grpcMetrics.syncRequestHighLatencyCounter.Add(grpcMetrics.ctx, 1, metric.WithAttributes(accountAttr))
}
// RegisterConnectedStreams registers a function that collects number of active streams and feeds it to the metrics gauge.

View File

@@ -7,8 +7,8 @@ import (
"strings"
"time"
"github.com/google/uuid"
"github.com/gorilla/mux"
"github.com/rs/xid"
log "github.com/sirupsen/logrus"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/metric"
@@ -169,7 +169,7 @@ func (m *HTTPMiddleware) Handler(h http.Handler) http.Handler {
//nolint
ctx := context.WithValue(r.Context(), hook.ExecutionContextKey, hook.HTTPSource)
reqID := uuid.New().String()
reqID := xid.New().String()
//nolint
ctx = context.WithValue(ctx, nbContext.RequestIDKey, reqID)
@@ -185,6 +185,18 @@ func (m *HTTPMiddleware) Handler(h http.Handler) http.Handler {
h.ServeHTTP(w, r.WithContext(ctx))
userAuth, err := nbContext.GetUserAuthFromContext(r.Context())
if err == nil {
if userAuth.AccountId != "" {
//nolint
ctx = context.WithValue(ctx, nbContext.AccountIDKey, userAuth.AccountId)
}
if userAuth.UserId != "" {
//nolint
ctx = context.WithValue(ctx, nbContext.UserIDKey, userAuth.UserId)
}
}
if w.Status() > 399 {
log.WithContext(ctx).Errorf("HTTP response %v: %v %v status %v", reqID, r.Method, r.URL, w.Status())
} else {

View File

@@ -0,0 +1,31 @@
package peerid
import (
"crypto/sha256"
v2 "github.com/netbirdio/netbird/shared/relay/auth/hmac/v2"
"github.com/netbirdio/netbird/shared/relay/messages"
)
var (
// HealthCheckPeerID is the hashed peer ID for health check connections.
// It is derived by hashing the fixed name "healthcheck-agent" with messages.HashID.
HealthCheckPeerID = messages.HashID("healthcheck-agent")
// DummyAuthToken is a structurally valid auth token for health check.
// The signature is not valid but the format is correct (1 byte algo + 32 bytes signature + payload).
// It is built once, at package initialization, by createDummyToken.
DummyAuthToken = createDummyToken()
)
// createDummyToken marshals a v2 token that carries the HMAC-SHA256 algorithm
// identifier, an all-zero signature of sha256.Size bytes and the payload
// "healthcheck".
func createDummyToken() []byte {
	zeroSignature := make([]byte, sha256.Size)

	dummy := v2.Token{
		AuthAlgo:  v2.AuthAlgoHMACSHA256,
		Signature: zeroSignature,
		Payload:   []byte("healthcheck"),
	}

	return dummy.Marshal()
}
// IsHealthCheck reports whether peerID is non-nil and equal to
// HealthCheckPeerID, i.e. whether it identifies the health check agent.
func IsHealthCheck(peerID *messages.PeerID) bool {
	if peerID == nil {
		return false
	}
	return *peerID == HealthCheckPeerID
}

View File

@@ -7,8 +7,10 @@ import (
"github.com/coder/websocket"
"github.com/netbirdio/netbird/relay/healthcheck/peerid"
"github.com/netbirdio/netbird/relay/server"
"github.com/netbirdio/netbird/shared/relay"
"github.com/netbirdio/netbird/shared/relay/messages"
)
func dialWS(ctx context.Context, address url.URL) error {
@@ -30,7 +32,18 @@ func dialWS(ctx context.Context, address url.URL) error {
if err != nil {
return fmt.Errorf("failed to connect to websocket: %w", err)
}
defer func() {
_ = conn.CloseNow()
}()
authMsg, err := messages.MarshalAuthMsg(peerid.HealthCheckPeerID, peerid.DummyAuthToken)
if err != nil {
return fmt.Errorf("failed to marshal auth message: %w", err)
}
if err := conn.Write(ctx, websocket.MessageBinary, authMsg); err != nil {
return fmt.Errorf("failed to write auth message: %w", err)
}
_ = conn.Close(websocket.StatusNormalClosure, "availability check complete")
return nil
}

View File

@@ -97,7 +97,7 @@ func (h *handshake) handshakeReceive() (*messages.PeerID, error) {
return nil, fmt.Errorf("invalid message type %d from %s", msgType, h.conn.RemoteAddr())
}
if err != nil {
return nil, err
return peerID, err
}
h.peerID = peerID
return peerID, nil
@@ -147,7 +147,7 @@ func (h *handshake) handleAuthMsg(buf []byte) (*messages.PeerID, error) {
}
if err := h.validator.Validate(authPayload); err != nil {
return nil, fmt.Errorf("validate %s (%s): %w", rawPeerID.String(), h.conn.RemoteAddr(), err)
return rawPeerID, fmt.Errorf("validate %s (%s): %w", rawPeerID.String(), h.conn.RemoteAddr(), err)
}
return rawPeerID, nil

View File

@@ -12,6 +12,7 @@ import (
"go.opentelemetry.io/otel"
"go.opentelemetry.io/otel/metric"
"github.com/netbirdio/netbird/relay/healthcheck/peerid"
//nolint:staticcheck
"github.com/netbirdio/netbird/relay/metrics"
"github.com/netbirdio/netbird/relay/server/store"
@@ -123,7 +124,11 @@ func (r *Relay) Accept(conn net.Conn) {
}
peerID, err := h.handshakeReceive()
if err != nil {
log.Errorf("failed to handshake: %s", err)
if peerid.IsHealthCheck(peerID) {
log.Debugf("health check connection from %s", conn.RemoteAddr())
} else {
log.Errorf("failed to handshake: %s", err)
}
if cErr := conn.Close(); cErr != nil {
log.Errorf("failed to close connection, %s: %s", conn.RemoteAddr(), cErr)
}