Compare commits

..

8 Commits

Author SHA1 Message Date
pascal
6b4d4076f4 extend tests 2026-05-04 15:16:59 +02:00
pascal
63d2217d8a Merge main into feature/affected-peers
Resolve conflicts keeping affected-peers logic while adopting
UpdateReason parameter from main for UpdateAccountPeers and
BufferUpdateAccountPeers signatures.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-05-04 12:22:13 +02:00
pascal
0bfccd65d2 add to networks modules 2026-04-30 16:20:41 +02:00
pascal
26d778374b clean comments 2026-04-28 16:56:34 +02:00
pascal
5ec8bebfa5 add tests 2026-04-28 16:27:44 +02:00
pascal
cefb37e920 affected filtering on peers update 2026-04-28 13:48:01 +02:00
pascal
5a16c812fd use buffering affected peers 2026-04-27 18:18:29 +02:00
pascal
285bbc5ffb calculate affected peers 2026-04-27 17:49:12 +02:00
64 changed files with 3652 additions and 2123 deletions

View File

@@ -1,130 +0,0 @@
body:
- type: markdown
attributes:
value: |
## Ideas & Feature Requests
Use this category for feature requests, enhancements, integrations, and product ideas.
NetBird uses community traction in discussions — upvotes, replies, affected users, and use-case detail — as an input when deciding what should become a maintainer-curated issue or roadmap item. A clear problem statement is more useful than a solution-only request.
Please search first and add your use case to an existing discussion when one already exists.
- type: checkboxes
id: preflight
attributes:
label: Before posting
options:
- label: I searched existing discussions and issues for similar requests.
required: true
- label: I checked the documentation to confirm this is not already supported.
required: true
- label: This is a product idea or enhancement request, not a support question.
required: true
- label: I removed or anonymized sensitive details from examples and screenshots.
required: true
- type: dropdown
id: area
attributes:
label: Product area
description: Select every area this request touches.
multiple: true
options:
- Client / Agent
- CLI
- Desktop UI
- Mobile app
- Dashboard / Admin UI
- Management service / API
- Signal service
- Relay
- DNS
- Routes / Exit nodes
- NetBird SSH
- Access control policies
- Posture checks
- Identity provider / SSO
- Self-hosting / Deployment
- Kubernetes / Operator
- Terraform / Automation
- Documentation
- Other / not sure
validations:
required: true
- type: textarea
id: problem
attributes:
label: Problem or use case
description: What are you trying to accomplish, and what is difficult or impossible today?
placeholder: |
As a ...
I want to ...
Because ...
validations:
required: true
- type: textarea
id: proposal
attributes:
label: Proposed solution
description: Describe the behavior, workflow, API, UI, or integration you would like to see.
validations:
required: true
- type: textarea
id: alternatives
attributes:
label: Alternatives or workarounds considered
description: What have you tried today? Why is the current workaround not enough?
- type: textarea
id: impact
attributes:
label: Community impact and priority
description: Help us understand who benefits and how urgent this is.
placeholder: |
- Number of users/teams/peers affected:
- Deployment type: Cloud / self-hosted / both
- Frequency: daily / weekly / occasional
- Blocking production adoption? yes/no
- Related comments, discussions, or customer requests:
validations:
required: true
- type: textarea
id: examples
attributes:
label: Examples from other tools or products
description: If another tool solves this well, link or describe the behavior.
- type: textarea
id: security
attributes:
label: Security, privacy, and compatibility considerations
description: Note any access-control, audit, data retention, network, platform, or backward-compatibility concerns.
- type: textarea
id: implementation
attributes:
label: Implementation ideas
description: Optional. If you are familiar with the codebase or API, share possible implementation notes.
- type: dropdown
id: contribution
attributes:
label: Are you willing to help?
options:
- Yes, I can submit a PR if the approach is accepted.
- Yes, I can test or validate a proposed implementation.
- Yes, I can provide more use-case details.
- Not at this time.
validations:
required: true
- type: textarea
id: additional-context
attributes:
label: Additional context
description: Add screenshots, diagrams, links, or anything else that helps explain the request.

View File

@@ -1,237 +0,0 @@
body:
- type: markdown
attributes:
value: |
## Issue Triage
Use this category for reproducible bugs and regressions in NetBird.
The more context you include, the faster we can validate and act on your report. If you're not sure whether something is a bug, **Q&A / Support** is a good starting point — we can always move the conversation here once we've confirmed it's a product issue.
Intermittent issues are useful too. Include the trigger, frequency, timing, and any logs or debug evidence you have, and we'll work from there.
Please don't include secrets, tokens, private keys, internal hostnames, or public IPs. Security vulnerabilities should be reported through the repository security policy rather than a public discussion.
- type: checkboxes
id: preflight
attributes:
label: Before posting
options:
- label: I searched existing discussions and issues, including closed ones, and checked the relevant docs.
required: true
- label: I believe this is a product bug rather than a configuration or setup question.
required: true
- label: I can reproduce this issue, or for intermittent issues I've included trigger, frequency, and timing details below.
required: true
- label: I removed or anonymized sensitive data from logs, screenshots, and configuration.
required: true
- type: dropdown
id: area
attributes:
label: Affected area
description: Select every area this report touches.
multiple: true
options:
- Client / Agent
- Reverse Proxy
- CLI
- Desktop UI
- Mobile app
- Peer connectivity
- DNS
- Routes / Exit nodes
- NetBird SSH
- Relay / Signal / NAT traversal
- Login / Authentication / IdP
- Dashboard / Admin UI
- Management service / API
- Access control policies / Posture checks
- Self-hosting / Deployment
- Kubernetes / Operator
- Documentation
- Other / not sure
validations:
required: true
- type: dropdown
id: deployment
attributes:
label: Deployment type
options:
- NetBird Cloud
- Self-hosted - quickstart script
- Self-hosted - advanced/custom deployment
- Local development build
- Not sure / environment I do not fully control
validations:
required: true
- type: dropdown
id: platform
attributes:
label: Operating system or environment
description: Select every environment involved in the reproduction.
multiple: true
options:
- Linux
- macOS
- Windows
- Android
- iOS
- FreeBSD
- OpenWRT
- Docker
- Kubernetes
- Synology
- Browser
- Other / not sure
validations:
required: true
- type: textarea
id: version
attributes:
label: NetBird version and upgrade status
description: Run `netbird version` where applicable. For self-hosted deployments, include management, signal, relay, and dashboard versions if available. If you cannot test on a current/supported version, explain why.
placeholder: |
Example:
- Client: 0.30.2
- Management: 0.30.2
- Signal: 0.30.2
- Relay: 0.30.2
- Dashboard: 0.30.2
- Upgrade status: reproduced on current version / cannot upgrade because ...
validations:
required: true
- type: dropdown
id: regression
attributes:
label: Did this work before?
options:
- Yes, this worked before
- No, this never worked
- Not sure
validations:
required: true
- type: textarea
id: regression-details
attributes:
label: Regression details
description: If this worked before, include the last known working version, first known broken version, and any recent upgrade, configuration, network, or IdP changes.
placeholder: |
- Last known working version:
- First known broken version:
- Recent changes:
- type: textarea
id: summary
attributes:
label: Summary
description: Briefly describe the reproducible bug.
placeholder: What is broken?
validations:
required: true
- type: textarea
id: current-behavior
attributes:
label: Current behavior
description: What happens now? Include exact errors, timeouts, UI messages, or failed commands when possible.
validations:
required: true
- type: textarea
id: expected-behavior
attributes:
label: Expected behavior
description: What did you expect to happen instead?
validations:
required: true
- type: textarea
id: reproduction
attributes:
label: Steps to reproduce
description: Provide the smallest set of steps that reliably reproduces the bug. If the issue is intermittent, include the trigger, frequency, timing, and relevant timestamps.
placeholder: |
1. Configure ...
2. Run ...
3. Observe ...
For intermittent issues:
- Trigger:
- Frequency:
- Timing/timestamps:
validations:
required: true
- type: textarea
id: environment
attributes:
label: Environment and topology
description: Include the relevant topology and software involved in the reproduction. For UI/docs-only reports, write `N/A` if this does not apply. Use `None`, `Unknown`, or `N/A` where appropriate.
placeholder: |
- Peer A:
- Peer B:
- Same LAN or different networks:
- NAT/CGNAT/corporate firewall/mobile network:
- Other VPN software:
- Firewall, DNS, or endpoint security software:
- Routes, DNS, policies, posture checks, or SSH rules involved:
- IdP, reverse proxy, or browser involved:
validations:
required: true
- type: textarea
id: self-hosted-details
attributes:
label: Self-hosted details, if available
description: Optional. If you use self-hosting and have access to these details, include them. If you do not administer the environment, provide what you know and say what you cannot access.
placeholder: |
- Deployment method: quickstart / Docker Compose / Helm / operator / custom
- Management/signal/relay/dashboard versions:
- Reverse proxy:
- IdP/provider:
- STUN/TURN/coturn/relay details:
- Relevant component logs:
- type: textarea
id: logs
attributes:
label: Logs, status output, or debug evidence
description: |
For client, connectivity, DNS, route, relay/signal, or self-hosted reports, logs are essential — please include anonymized output from `netbird status -dA`, or a debug bundle via `netbird debug for 1m -AS -U`. Debug bundles are automatically deleted after 30 days.
For UI, dashboard, or documentation reports, leave the pre-filled `N/A`.
value: "N/A"
render: shell
validations:
required: true
- type: textarea
id: related-reports
attributes:
label: Related issues or discussions
description: Optional. Link similar reports you found while searching, if any.
placeholder: |
- Related issue/discussion:
- Why this may be the same or different:
- type: textarea
id: impact
attributes:
label: Impact
description: Optional. Help us understand priority. How many users, peers, environments, or workflows are affected? Is there a workaround?
placeholder: |
- Affected users/peers:
- Business or production impact:
- Workaround available:
- type: textarea
id: additional-context
attributes:
label: Additional context
description: Add links to related discussions, issues, docs, screenshots, recordings, or anything else that may help validation.

View File

@@ -1,146 +0,0 @@
body:
- type: markdown
attributes:
value: |
## Q&A / Support
Use this category for questions about configuration, setup, self-hosted deployments, troubleshooting, and general NetBird usage.
This is community support and does not provide an SLA. For NetBird Cloud support, use the official support channel linked from the issue creation page. Please do not post secrets, tokens, private keys, internal hostnames, or public IPs unless you intentionally want them public.
If your question turns into a reproducible product defect, DevRel or a maintainer may ask you to open or move the conversation to Issue Triage.
- type: checkboxes
id: preflight
attributes:
label: Before posting
options:
- label: I searched existing discussions and issues for similar questions.
required: true
- label: I reviewed the relevant NetBird documentation or troubleshooting guide.
required: true
- label: I removed or anonymized sensitive data from logs, screenshots, and configuration.
required: true
- type: dropdown
id: topic
attributes:
label: Topic
multiple: true
options:
- Getting started
- Self-hosting
- Client / Agent
- CLI
- Desktop UI
- Mobile app
- Dashboard / Admin UI
- DNS
- Routes / Exit nodes
- NetBird SSH
- Relay
- Access control policies
- Posture checks
- Identity provider / SSO
- API
- Kubernetes / Operator
- Terraform / Automation
- Documentation
- Other / not sure
validations:
required: true
- type: dropdown
id: deployment
attributes:
label: Deployment type
options:
- NetBird Cloud
- Self-hosted - quickstart script
- Self-hosted - advanced/custom deployment
- Local development build
- Not sure
validations:
required: true
- type: dropdown
id: platform
attributes:
label: Operating system or environment
multiple: true
options:
- Linux
- macOS
- Windows
- Android
- iOS
- FreeBSD
- OpenWRT
- Docker
- Kubernetes
- Synology
- Browser
- Other / not sure
validations:
required: true
- type: input
id: version
attributes:
label: NetBird version
description: Run `netbird version` where applicable. For self-hosted deployments, include component versions if relevant.
placeholder: "Example: client 0.30.2, management 0.30.2"
- type: textarea
id: question
attributes:
label: Question
description: What are you trying to understand or accomplish?
placeholder: Describe your question clearly.
validations:
required: true
- type: textarea
id: goal
attributes:
label: Desired outcome
description: What would a successful answer help you do?
placeholder: |
I want to configure ...
I expected ...
I need help deciding ...
- type: textarea
id: attempted
attributes:
label: What have you tried?
description: Include commands, documentation links, configuration attempts, or troubleshooting steps already tried.
placeholder: |
- Read ...
- Ran ...
- Changed ...
- Observed ...
- type: textarea
id: environment
attributes:
label: Relevant environment details
description: Include redacted topology, IdP/provider, reverse proxy, firewall, DNS, route, policy, or self-hosted setup details that may affect the answer.
placeholder: |
- Deployment:
- Components involved:
- Network/topology:
- Related config:
- type: textarea
id: logs
attributes:
label: Logs or output
description: Optional. Include anonymized logs, command output, screenshots, or `netbird status -dA` if relevant.
render: shell
- type: textarea
id: additional-context
attributes:
label: Additional context
description: Add links, diagrams, screenshots, or other details that may help the community answer.

View File

@@ -0,0 +1,71 @@
---
name: Bug/Issue report
about: Create a report to help us improve
title: ''
labels: ['triage-needed']
assignees: ''
---
**Describe the problem**
A clear and concise description of what the problem is.
**To Reproduce**
Steps to reproduce the behavior:
1. Go to '...'
2. Click on '....'
3. Scroll down to '....'
4. See error
**Expected behavior**
A clear and concise description of what you expected to happen.
**Are you using NetBird Cloud?**
Please specify whether you use NetBird Cloud or self-host NetBird's control plane.
**NetBird version**
`netbird version`
**Is any other VPN software installed?**
If yes, which one?
**Debug output**
To help us resolve the problem, please attach the following anonymized status output
netbird status -dA
Create and upload a debug bundle, and share the returned file key:
netbird debug for 1m -AS -U
*Uploaded files are automatically deleted after 30 days.*
Alternatively, create the file only and attach it here manually:
netbird debug for 1m -AS
**Screenshots**
If applicable, add screenshots to help explain your problem.
**Additional context**
Add any other context about the problem here.
**Have you tried these troubleshooting steps?**
- [ ] Reviewed [client troubleshooting](https://docs.netbird.io/how-to/troubleshooting-client) (if applicable)
- [ ] Checked for newer NetBird versions
- [ ] Searched for similar issues on GitHub (including closed ones)
- [ ] Restarted the NetBird client
- [ ] Disabled other VPN software
- [ ] Checked firewall settings

View File

@@ -1,26 +1,14 @@
blank_issues_enabled: false
blank_issues_enabled: true
contact_links:
- name: Start an Issue Triage discussion
url: https://github.com/netbirdio/netbird/discussions/new?category=issue-triage
about: Report a bug, regression, or unexpected behavior so DevRel can validate it before it becomes an issue.
- name: Propose an idea or feature request
url: https://github.com/netbirdio/netbird/discussions/new?category=ideas-feature-requests
about: Share feature requests, enhancements, and integration ideas for community feedback and prioritization.
- name: Ask a Q&A / Support question
url: https://github.com/netbirdio/netbird/discussions/new?category=q-a-support
about: Get help with setup, configuration, self-hosting, troubleshooting, and general usage.
- name: Security vulnerability disclosure
url: https://github.com/netbirdio/netbird/security/policy
about: Please do not report security vulnerabilities in public issues or discussions.
- name: Community Support Forum
- name: Community Support
url: https://forum.netbird.io/
about: Community support forum.
about: Community support forum
- name: Cloud Support
url: https://docs.netbird.io/help/report-bug-issues
about: Contact NetBird for Cloud support.
- name: Client / Connection Troubleshooting
about: Contact us for support
- name: Client/Connection Troubleshooting
url: https://docs.netbird.io/help/troubleshooting-client
about: See the client troubleshooting guide for common connectivity issues.
about: See our client troubleshooting guide for help addressing common issues
- name: Self-host Troubleshooting
url: https://docs.netbird.io/selfhosted/troubleshooting
about: See the self-host troubleshooting guide for common deployment issues.
about: See our self-host troubleshooting guide for help addressing common issues

View File

@@ -0,0 +1,20 @@
---
name: Feature request
about: Suggest an idea for this project
title: ''
labels: ['feature-request']
assignees: ''
---
**Is your feature request related to a problem? Please describe.**
A clear and concise description of what the problem is. Ex. I'm always frustrated when [...]
**Describe the solution you'd like**
A clear and concise description of what you want to happen.
**Describe alternatives you've considered**
A clear and concise description of any alternative solutions or features you've considered.
**Additional context**
Add any other context or screenshots about the feature request here.

View File

@@ -1,128 +0,0 @@
name: Validated issue
description: Maintainer/DevRel only. Create an issue after a discussion has been validated or for internally validated work.
title: "[Validated]: "
body:
- type: markdown
attributes:
value: |
## Discussion-first issue policy
Issues are maintainer-curated work items. Community reports and feature requests should start in [Discussions](https://github.com/netbirdio/netbird/discussions) so DevRel can validate, reproduce, and route them before engineering time is committed.
Use this form when:
- A discussion has been validated and should become actionable work.
- A maintainer is opening internally validated work that can bypass the discussion-first flow.
Issues opened without a relevant validated discussion or maintainer context may be closed and redirected to Discussions.
- type: checkboxes
id: validation-checks
attributes:
label: Validation checklist
options:
- label: This issue is linked to a validated discussion, or it is being opened directly by a maintainer.
required: true
- label: The report has enough context for engineering to act on it without re-triaging from scratch.
required: true
- label: Sensitive data, secrets, private keys, internal hostnames, and public IPs have been removed or intentionally disclosed.
required: true
- type: dropdown
id: issue-type
attributes:
label: Issue type
options:
- Bug / Regression
- Feature / Enhancement
- Documentation
- Maintenance / Refactor
- Cross-repository coordination
- Other
validations:
required: true
- type: input
id: source-discussion
attributes:
label: Source discussion
description: Link the GitHub Discussion that was validated. Maintainers bypassing the flow can write "Maintainer-created" and explain why below.
placeholder: https://github.com/netbirdio/netbird/discussions/1234
validations:
required: true
- type: input
id: validation-owner
attributes:
label: Validation owner
description: GitHub handle of the DevRel team member or maintainer who validated this work.
placeholder: "@username"
validations:
required: true
- type: dropdown
id: target-repository
attributes:
label: Target repository
description: Where should the implementation work happen?
options:
- netbirdio/netbird
- netbirdio/dashboard
- netbirdio/kubernetes-operator
- netbirdio/docs
- Multiple repositories
- Unknown / needs routing
validations:
required: true
- type: textarea
id: summary
attributes:
label: Summary
description: Concise description of the validated work.
placeholder: What needs to be fixed, changed, documented, or built?
validations:
required: true
- type: textarea
id: evidence
attributes:
label: Validation evidence
description: For bugs, include reproduction status, affected versions, logs, and environment. For features, include community traction, affected users, and alignment notes.
placeholder: |
- Reproduced by:
- Affected versions / platforms:
- Community signal:
- Related logs or screenshots:
validations:
required: true
- type: textarea
id: scope
attributes:
label: Proposed scope
description: Describe what is in scope and, if helpful, what is explicitly out of scope.
placeholder: |
In scope:
- ...
Out of scope:
- ...
validations:
required: true
- type: textarea
id: acceptance-criteria
attributes:
label: Acceptance criteria
description: What must be true for this issue to be closed?
placeholder: |
- [ ] ...
- [ ] ...
validations:
required: true
- type: textarea
id: additional-context
attributes:
label: Additional context
description: Links to related PRs, docs, issues in other repositories, roadmap items, or implementation notes.

View File

@@ -19,7 +19,7 @@ jobs:
- name: codespell
uses: codespell-project/actions-codespell@v2
with:
ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe,cros,ans,deriver,te,userA,ede,additionals
ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe,cros,ans,deriver,te,userA
skip: go.mod,go.sum,**/proxy/web/**
golangci:
strategy:

View File

@@ -58,11 +58,6 @@ linters:
govet:
enable:
- nilness
disable:
# The inline analyzer flags x/exp/maps Clone/Clear with //go:fix inline
# directives but cannot perform the rewrite due to generic type
# parameter inference limitations in the Go inliner.
- inline
enable-all: false
revive:
rules:

View File

@@ -10,7 +10,6 @@ import (
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
"golang.org/x/term"
"google.golang.org/grpc/codes"
gstatus "google.golang.org/grpc/status"
@@ -24,7 +23,6 @@ import (
func init() {
loginCmd.PersistentFlags().BoolVar(&noBrowser, noBrowserFlag, false, noBrowserDesc)
loginCmd.PersistentFlags().BoolVar(&showQR, showQRFlag, false, showQRDesc)
loginCmd.PersistentFlags().StringVar(&profileName, profileNameFlag, "", profileNameDesc)
loginCmd.PersistentFlags().StringVarP(&configPath, "config", "c", "", "(DEPRECATED) Netbird config file location")
}
@@ -258,7 +256,7 @@ func doForegroundLogin(ctx context.Context, cmd *cobra.Command, setupKey string,
}
func handleSSOLogin(ctx context.Context, cmd *cobra.Command, loginResp *proto.LoginResponse, client proto.DaemonServiceClient, pm *profilemanager.ProfileManager) error {
openURL(cmd, loginResp.VerificationURIComplete, loginResp.UserCode, noBrowser, showQR)
openURL(cmd, loginResp.VerificationURIComplete, loginResp.UserCode, noBrowser)
resp, err := client.WaitSSOLogin(ctx, &proto.WaitSSOLoginRequest{UserCode: loginResp.UserCode, Hostname: hostName})
if err != nil {
@@ -326,7 +324,7 @@ func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *pro
return nil, fmt.Errorf("getting a request OAuth flow info failed: %v", err)
}
openURL(cmd, flowInfo.VerificationURIComplete, flowInfo.UserCode, noBrowser, showQR)
openURL(cmd, flowInfo.VerificationURIComplete, flowInfo.UserCode, noBrowser)
tokenInfo, err := oAuthFlow.WaitToken(context.TODO(), flowInfo)
if err != nil {
@@ -336,7 +334,7 @@ func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *pro
return &tokenInfo, nil
}
func openURL(cmd *cobra.Command, verificationURIComplete, userCode string, noBrowser, showQR bool) {
func openURL(cmd *cobra.Command, verificationURIComplete, userCode string, noBrowser bool) {
var codeMsg string
if userCode != "" && !strings.Contains(verificationURIComplete, userCode) {
codeMsg = fmt.Sprintf("and enter the code %s to authenticate.", userCode)
@@ -350,12 +348,6 @@ func openURL(cmd *cobra.Command, verificationURIComplete, userCode string, noBro
verificationURIComplete + " " + codeMsg)
}
if showQR {
if f, ok := cmd.OutOrStdout().(*os.File); ok && term.IsTerminal(int(f.Fd())) {
printQRCode(f, verificationURIComplete)
}
}
cmd.Println("")
if !noBrowser {

View File

@@ -1,25 +0,0 @@
package cmd
import (
"io"
"github.com/mdp/qrterminal/v3"
)
// printQRCode prints a QR code for the given URL to the writer.
// Called only when the user explicitly requests QR output via --qr.
func printQRCode(w io.Writer, url string) {
if url == "" {
return
}
qrterminal.GenerateWithConfig(url, qrterminal.Config{
Level: qrterminal.M,
Writer: w,
HalfBlocks: true,
BlackChar: qrterminal.BLACK_BLACK,
WhiteChar: qrterminal.WHITE_WHITE,
BlackWhiteChar: qrterminal.BLACK_WHITE,
WhiteBlackChar: qrterminal.WHITE_BLACK,
QuietZone: qrterminal.QUIET_ZONE,
})
}

View File

@@ -1,26 +0,0 @@
package cmd
import (
"bytes"
"testing"
)
func TestPrintQRCode_EmptyURL(t *testing.T) {
var buf bytes.Buffer
printQRCode(&buf, "")
if buf.Len() != 0 {
t.Error("expected no output for empty URL")
}
}
func TestPrintQRCode_WritesOutput(t *testing.T) {
var buf bytes.Buffer
printQRCode(&buf, "https://example.com/auth")
if buf.Len() == 0 {
t.Error("expected QR code output for non-empty URL")
}
}

View File

@@ -39,9 +39,6 @@ const (
noBrowserFlag = "no-browser"
noBrowserDesc = "do not open the browser for SSO login"
showQRFlag = "qr"
showQRDesc = "show QR code for the SSO login URL (useful for headless machines without browser access)"
profileNameFlag = "profile"
profileNameDesc = "profile name to use for the login. If not specified, the last used profile will be used."
)
@@ -51,7 +48,6 @@ var (
dnsLabels []string
dnsLabelsValidated domain.List
noBrowser bool
showQR bool
profileName string
configPath string
@@ -84,7 +80,6 @@ func init() {
)
upCmd.PersistentFlags().BoolVar(&noBrowser, noBrowserFlag, false, noBrowserDesc)
upCmd.PersistentFlags().BoolVar(&showQR, showQRFlag, false, showQRDesc)
upCmd.PersistentFlags().StringVar(&profileName, profileNameFlag, "", profileNameDesc)
upCmd.PersistentFlags().StringVarP(&configPath, "config", "c", "", "(DEPRECATED) NetBird config file location. ")

View File

@@ -29,27 +29,6 @@ import (
var currentMTU uint16 = iface.DefaultMTU
// nonRetryableEDECodes lists EDE info codes (RFC 8914) for which a SERVFAIL
// from one upstream means another upstream would return the same answer:
// DNSSEC validation outcomes and policy-based blocks. Transient errors
// (network, cached, not ready) are not included.
var nonRetryableEDECodes = map[uint16]struct{}{
dns.ExtendedErrorCodeUnsupportedDNSKEYAlgorithm: {},
dns.ExtendedErrorCodeUnsupportedDSDigestType: {},
dns.ExtendedErrorCodeDNSSECIndeterminate: {},
dns.ExtendedErrorCodeDNSBogus: {},
dns.ExtendedErrorCodeSignatureExpired: {},
dns.ExtendedErrorCodeSignatureNotYetValid: {},
dns.ExtendedErrorCodeDNSKEYMissing: {},
dns.ExtendedErrorCodeRRSIGsMissing: {},
dns.ExtendedErrorCodeNoZoneKeyBitSet: {},
dns.ExtendedErrorCodeNSECMissing: {},
dns.ExtendedErrorCodeBlocked: {},
dns.ExtendedErrorCodeCensored: {},
dns.ExtendedErrorCodeFiltered: {},
dns.ExtendedErrorCodeProhibited: {},
}
func SetCurrentMTU(mtu uint16) {
currentMTU = mtu
}
@@ -264,18 +243,6 @@ func (u *upstreamResolverBase) queryUpstream(parentCtx context.Context, w dns.Re
var t time.Duration
var err error
// Advertise EDNS0 so the upstream may include Extended DNS Errors
// (RFC 8914) in failure responses; we use those to short-circuit
// failover for definitive answers like DNSSEC validation failures.
// Operate on a copy so the inbound request is unchanged: a client that
// did not advertise EDNS0 must not see an OPT in the response.
hadEdns := r.IsEdns0() != nil
reqUp := r
if !hadEdns {
reqUp = r.Copy()
reqUp.SetEdns0(upstreamUDPSize(), false)
}
var startTime time.Time
var upstreamProto *upstreamProtocolResult
func() {
@@ -283,7 +250,7 @@ func (u *upstreamResolverBase) queryUpstream(parentCtx context.Context, w dns.Re
defer cancel()
ctx, upstreamProto = contextWithupstreamProtocolResult(ctx)
startTime = time.Now()
rm, t, err = u.upstreamClient.exchange(ctx, upstream.String(), reqUp)
rm, t, err = u.upstreamClient.exchange(ctx, upstream.String(), r)
}()
if err != nil {
@@ -295,49 +262,13 @@ func (u *upstreamResolverBase) queryUpstream(parentCtx context.Context, w dns.Re
}
if rm.Rcode == dns.RcodeServerFailure || rm.Rcode == dns.RcodeRefused {
if code, ok := nonRetryableEDE(rm); ok {
resutil.SetMeta(w, "ede", edeName(code))
if !hadEdns {
stripOPT(rm)
}
u.writeSuccessResponse(w, rm, upstream, r.Question[0].Name, t, upstreamProto, logger)
return nil
}
return &upstreamFailure{upstream: upstream, reason: dns.RcodeToString[rm.Rcode]}
}
if !hadEdns {
stripOPT(rm)
}
u.writeSuccessResponse(w, rm, upstream, r.Question[0].Name, t, upstreamProto, logger)
return nil
}
// upstreamUDPSize returns the EDNS0 UDP buffer size we advertise to upstreams,
// derived from the tunnel MTU and bounded against underflow.
func upstreamUDPSize() uint16 {
if currentMTU > ipUDPHeaderSize {
return currentMTU - ipUDPHeaderSize
}
return dns.MinMsgSize
}
// stripOPT removes any OPT pseudo-RRs from the response's Extra section so
// the response complies with RFC 6891 when the client did not advertise EDNS0.
func stripOPT(rm *dns.Msg) {
if len(rm.Extra) == 0 {
return
}
out := rm.Extra[:0]
for _, rr := range rm.Extra {
if _, ok := rr.(*dns.OPT); ok {
continue
}
out = append(out, rr)
}
rm.Extra = out
}
func (u *upstreamResolverBase) handleUpstreamError(err error, upstream netip.AddrPort, startTime time.Time) *upstreamFailure {
if !errors.Is(err, context.DeadlineExceeded) && !isTimeout(err) {
return &upstreamFailure{upstream: upstream, reason: err.Error()}
@@ -399,34 +330,6 @@ func formatFailures(failures []upstreamFailure) string {
return strings.Join(parts, ", ")
}
// nonRetryableEDE returns the first non-retryable EDE code carried in the
// response, if any.
func nonRetryableEDE(rm *dns.Msg) (uint16, bool) {
opt := rm.IsEdns0()
if opt == nil {
return 0, false
}
for _, o := range opt.Option {
ede, ok := o.(*dns.EDNS0_EDE)
if !ok {
continue
}
if _, ok := nonRetryableEDECodes[ede.InfoCode]; ok {
return ede.InfoCode, true
}
}
return 0, false
}
// edeName returns a human-readable name for an EDE code, falling back to
// the numeric code when unknown.
func edeName(code uint16) string {
if name, ok := dns.ExtendedErrorCodeToString[code]; ok {
return name
}
return fmt.Sprintf("EDE %d", code)
}
// ProbeAvailability tests all upstream servers simultaneously and
// disables the resolver if none work
func (u *upstreamResolverBase) ProbeAvailability(ctx context.Context) {

View File

@@ -770,132 +770,3 @@ func TestExchangeWithFallback_TCPTruncatesToClientSize(t *testing.T) {
assert.Less(t, len(rm2.Answer), 20, "small EDNS0 client should get fewer records")
assert.True(t, rm2.Truncated, "response should be truncated for small buffer client")
}
func msgWithEDE(rcode int, codes ...uint16) *dns.Msg {
m := new(dns.Msg)
m.Response = true
m.Rcode = rcode
if len(codes) == 0 {
return m
}
opt := &dns.OPT{Hdr: dns.RR_Header{Name: ".", Rrtype: dns.TypeOPT}}
opt.SetUDPSize(dns.MinMsgSize)
for _, c := range codes {
opt.Option = append(opt.Option, &dns.EDNS0_EDE{InfoCode: c})
}
m.Extra = append(m.Extra, opt)
return m
}
func TestNonRetryableEDE(t *testing.T) {
tests := []struct {
name string
msg *dns.Msg
wantOK bool
wantCode uint16
}{
{name: "no edns0", msg: msgWithEDE(dns.RcodeServerFailure)},
{
name: "opt without ede",
msg: func() *dns.Msg {
m := msgWithEDE(dns.RcodeServerFailure)
opt := &dns.OPT{Hdr: dns.RR_Header{Name: ".", Rrtype: dns.TypeOPT}}
opt.Option = append(opt.Option, &dns.EDNS0_NSID{Code: dns.EDNS0NSID})
m.Extra = []dns.RR{opt}
return m
}(),
},
{name: "ede dnsbogus", msg: msgWithEDE(dns.RcodeServerFailure, dns.ExtendedErrorCodeDNSBogus), wantOK: true, wantCode: dns.ExtendedErrorCodeDNSBogus},
{name: "ede signature expired", msg: msgWithEDE(dns.RcodeServerFailure, dns.ExtendedErrorCodeSignatureExpired), wantOK: true, wantCode: dns.ExtendedErrorCodeSignatureExpired},
{name: "ede blocked", msg: msgWithEDE(dns.RcodeServerFailure, dns.ExtendedErrorCodeBlocked), wantOK: true, wantCode: dns.ExtendedErrorCodeBlocked},
{name: "ede prohibited", msg: msgWithEDE(dns.RcodeRefused, dns.ExtendedErrorCodeProhibited), wantOK: true, wantCode: dns.ExtendedErrorCodeProhibited},
{name: "ede cached error retryable", msg: msgWithEDE(dns.RcodeServerFailure, dns.ExtendedErrorCodeCachedError)},
{name: "ede network error retryable", msg: msgWithEDE(dns.RcodeServerFailure, dns.ExtendedErrorCodeNetworkError)},
{name: "ede not ready retryable", msg: msgWithEDE(dns.RcodeServerFailure, dns.ExtendedErrorCodeNotReady)},
{
name: "first non-retryable wins",
msg: msgWithEDE(dns.RcodeServerFailure, dns.ExtendedErrorCodeNetworkError, dns.ExtendedErrorCodeDNSBogus),
wantOK: true,
wantCode: dns.ExtendedErrorCodeDNSBogus,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
code, ok := nonRetryableEDE(tc.msg)
assert.Equal(t, tc.wantOK, ok, "ok should match")
if tc.wantOK {
assert.Equal(t, tc.wantCode, code, "code should match")
}
})
}
}
func TestEDEName(t *testing.T) {
assert.Equal(t, "DNSSEC Bogus", edeName(dns.ExtendedErrorCodeDNSBogus))
assert.Equal(t, "Signature Expired", edeName(dns.ExtendedErrorCodeSignatureExpired))
assert.Equal(t, "EDE 9999", edeName(9999), "unknown code falls back to numeric")
}
func TestStripOPT(t *testing.T) {
rm := &dns.Msg{
Extra: []dns.RR{
&dns.OPT{Hdr: dns.RR_Header{Name: ".", Rrtype: dns.TypeOPT}},
&dns.A{Hdr: dns.RR_Header{Name: "x.", Rrtype: dns.TypeA}, A: net.IPv4(1, 2, 3, 4)},
},
}
stripOPT(rm)
assert.Len(t, rm.Extra, 1, "OPT should be removed, A kept")
_, isOPT := rm.Extra[0].(*dns.OPT)
assert.False(t, isOPT, "remaining record must not be OPT")
}
func TestUpstreamResolver_NonRetryableEDEShortCircuits(t *testing.T) {
upstream1 := netip.MustParseAddrPort("192.0.2.1:53")
upstream2 := netip.MustParseAddrPort("192.0.2.2:53")
servfailWithEDE := msgWithEDE(dns.RcodeServerFailure, dns.ExtendedErrorCodeDNSBogus)
successResp := buildMockResponse(dns.RcodeSuccess, "192.0.2.100")
var queried []string
tracking := &trackingMockClient{
inner: &mockUpstreamResolverPerServer{
responses: map[string]mockUpstreamResponse{
upstream1.String(): {msg: servfailWithEDE},
upstream2.String(): {msg: successResp},
},
rtt: time.Millisecond,
},
queriedUpstreams: &queried,
}
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
resolver := &upstreamResolverBase{
ctx: ctx,
upstreamClient: tracking,
upstreamServers: []netip.AddrPort{upstream1, upstream2},
upstreamTimeout: UpstreamTimeout,
}
var written *dns.Msg
w := &test.MockResponseWriter{
WriteMsgFunc: func(m *dns.Msg) error {
written = m
return nil
},
}
// Client query without EDNS0 must not see an OPT in the response.
q := new(dns.Msg).SetQuestion("example.com.", dns.TypeA)
resolver.ServeDNS(w, q)
require.NotNil(t, written, "response must be written")
assert.Equal(t, dns.RcodeServerFailure, written.Rcode, "SERVFAIL must propagate")
assert.Len(t, queried, 1, "only first upstream should be queried")
assert.Equal(t, upstream1.String(), queried[0])
for _, rr := range written.Extra {
_, isOPT := rr.(*dns.OPT)
assert.False(t, isOPT, "synthetic OPT must not leak to a non-EDNS0 client")
}
}

View File

@@ -8,7 +8,6 @@ import (
type mocListener struct {
lastState int
wg sync.WaitGroup
peersWg sync.WaitGroup
peers int
}
@@ -34,7 +33,6 @@ func (l *mocListener) OnAddressChanged(host, addr string) {
}
func (l *mocListener) OnPeersListChanged(size int) {
l.peers = size
l.peersWg.Done()
}
func (l *mocListener) setWaiter() {
@@ -45,14 +43,6 @@ func (l *mocListener) wait() {
l.wg.Wait()
}
func (l *mocListener) setPeersWaiter() {
l.peersWg.Add(1)
}
func (l *mocListener) waitPeers() {
l.peersWg.Wait()
}
func Test_notifier_serverState(t *testing.T) {
type scenario struct {
@@ -82,13 +72,11 @@ func Test_notifier_serverState(t *testing.T) {
func Test_notifier_SetListener(t *testing.T) {
listener := &mocListener{}
listener.setWaiter()
listener.setPeersWaiter()
n := newNotifier()
n.lastNotification = stateConnecting
n.setListener(listener)
listener.wait()
listener.waitPeers()
if listener.lastState != n.lastNotification {
t.Errorf("invalid state: %d, expected: %d", listener.lastState, n.lastNotification)
}
@@ -97,14 +85,9 @@ func Test_notifier_SetListener(t *testing.T) {
func Test_notifier_RemoveListener(t *testing.T) {
listener := &mocListener{}
listener.setWaiter()
listener.setPeersWaiter()
n := newNotifier()
n.lastNotification = stateConnecting
n.setListener(listener)
// setListener replays cached state on a goroutine; wait for both the state
// and peers callbacks to finish so we don't race on listener.peers.
listener.wait()
listener.waitPeers()
n.removeListener()
n.peerListChanged(1)

View File

@@ -320,10 +320,10 @@ func (d *Status) RemovePeer(peerPubKey string) error {
// UpdatePeerState updates peer status
func (d *Status) UpdatePeerState(receivedState State) error {
d.mux.Lock()
defer d.mux.Unlock()
peerState, ok := d.peers[receivedState.PubKey]
if !ok {
d.mux.Unlock()
return errors.New("peer doesn't exist")
}
@@ -343,29 +343,23 @@ func (d *Status) UpdatePeerState(receivedState State) error {
d.peers[receivedState.PubKey] = peerState
notifyList := hasConnStatusChanged(oldState, receivedState.ConnStatus)
// when we close the connection we will not notify the router manager
notifyRouter := receivedState.ConnStatus == StatusIdle
routerSnapshot := d.snapshotRouterPeersLocked(receivedState.PubKey, notifyRouter)
numPeers := d.numOfPeers()
d.mux.Unlock()
if notifyList {
d.notifier.peerListChanged(numPeers)
if hasConnStatusChanged(oldState, receivedState.ConnStatus) {
d.notifyPeerListChanged()
}
if notifyRouter {
d.dispatchRouterPeers(receivedState.PubKey, routerSnapshot)
// when we close the connection we will not notify the router manager
if receivedState.ConnStatus == StatusIdle {
d.notifyPeerStateChangeListeners(receivedState.PubKey)
}
return nil
}
func (d *Status) AddPeerStateRoute(peer string, route string, resourceId route.ResID) error {
d.mux.Lock()
defer d.mux.Unlock()
peerState, ok := d.peers[peer]
if !ok {
d.mux.Unlock()
return errors.New("peer doesn't exist")
}
@@ -377,20 +371,17 @@ func (d *Status) AddPeerStateRoute(peer string, route string, resourceId route.R
d.routeIDLookup.AddRemoteRouteID(resourceId, pref)
}
numPeers := d.numOfPeers()
d.mux.Unlock()
// todo: consider to make sense of this notification or not
d.notifier.peerListChanged(numPeers)
d.notifyPeerListChanged()
return nil
}
func (d *Status) RemovePeerStateRoute(peer string, route string) error {
d.mux.Lock()
defer d.mux.Unlock()
peerState, ok := d.peers[peer]
if !ok {
d.mux.Unlock()
return errors.New("peer doesn't exist")
}
@@ -402,11 +393,8 @@ func (d *Status) RemovePeerStateRoute(peer string, route string) error {
d.routeIDLookup.RemoveRemoteRouteID(pref)
}
numPeers := d.numOfPeers()
d.mux.Unlock()
// todo: consider to make sense of this notification or not
d.notifier.peerListChanged(numPeers)
d.notifyPeerListChanged()
return nil
}
@@ -422,10 +410,10 @@ func (d *Status) CheckRoutes(ip netip.Addr) ([]byte, bool) {
func (d *Status) UpdatePeerICEState(receivedState State) error {
d.mux.Lock()
defer d.mux.Unlock()
peerState, ok := d.peers[receivedState.PubKey]
if !ok {
d.mux.Unlock()
return errors.New("peer doesn't exist")
}
@@ -443,28 +431,22 @@ func (d *Status) UpdatePeerICEState(receivedState State) error {
d.peers[receivedState.PubKey] = peerState
notifyList := hasConnStatusChanged(oldState, receivedState.ConnStatus)
notifyRouter := hasStatusOrRelayedChange(oldState, receivedState.ConnStatus, oldIsRelayed, receivedState.Relayed)
routerSnapshot := d.snapshotRouterPeersLocked(receivedState.PubKey, notifyRouter)
numPeers := d.numOfPeers()
d.mux.Unlock()
if notifyList {
d.notifier.peerListChanged(numPeers)
if hasConnStatusChanged(oldState, receivedState.ConnStatus) {
d.notifyPeerListChanged()
}
if notifyRouter {
d.dispatchRouterPeers(receivedState.PubKey, routerSnapshot)
if hasStatusOrRelayedChange(oldState, receivedState.ConnStatus, oldIsRelayed, receivedState.Relayed) {
d.notifyPeerStateChangeListeners(receivedState.PubKey)
}
return nil
}
func (d *Status) UpdatePeerRelayedState(receivedState State) error {
d.mux.Lock()
defer d.mux.Unlock()
peerState, ok := d.peers[receivedState.PubKey]
if !ok {
d.mux.Unlock()
return errors.New("peer doesn't exist")
}
@@ -479,28 +461,22 @@ func (d *Status) UpdatePeerRelayedState(receivedState State) error {
d.peers[receivedState.PubKey] = peerState
notifyList := hasConnStatusChanged(oldState, receivedState.ConnStatus)
notifyRouter := hasStatusOrRelayedChange(oldState, receivedState.ConnStatus, oldIsRelayed, receivedState.Relayed)
routerSnapshot := d.snapshotRouterPeersLocked(receivedState.PubKey, notifyRouter)
numPeers := d.numOfPeers()
d.mux.Unlock()
if notifyList {
d.notifier.peerListChanged(numPeers)
if hasConnStatusChanged(oldState, receivedState.ConnStatus) {
d.notifyPeerListChanged()
}
if notifyRouter {
d.dispatchRouterPeers(receivedState.PubKey, routerSnapshot)
if hasStatusOrRelayedChange(oldState, receivedState.ConnStatus, oldIsRelayed, receivedState.Relayed) {
d.notifyPeerStateChangeListeners(receivedState.PubKey)
}
return nil
}
func (d *Status) UpdatePeerRelayedStateToDisconnected(receivedState State) error {
d.mux.Lock()
defer d.mux.Unlock()
peerState, ok := d.peers[receivedState.PubKey]
if !ok {
d.mux.Unlock()
return errors.New("peer doesn't exist")
}
@@ -514,28 +490,22 @@ func (d *Status) UpdatePeerRelayedStateToDisconnected(receivedState State) error
d.peers[receivedState.PubKey] = peerState
notifyList := hasConnStatusChanged(oldState, receivedState.ConnStatus)
notifyRouter := hasStatusOrRelayedChange(oldState, receivedState.ConnStatus, oldIsRelayed, receivedState.Relayed)
routerSnapshot := d.snapshotRouterPeersLocked(receivedState.PubKey, notifyRouter)
numPeers := d.numOfPeers()
d.mux.Unlock()
if notifyList {
d.notifier.peerListChanged(numPeers)
if hasConnStatusChanged(oldState, receivedState.ConnStatus) {
d.notifyPeerListChanged()
}
if notifyRouter {
d.dispatchRouterPeers(receivedState.PubKey, routerSnapshot)
if hasStatusOrRelayedChange(oldState, receivedState.ConnStatus, oldIsRelayed, receivedState.Relayed) {
d.notifyPeerStateChangeListeners(receivedState.PubKey)
}
return nil
}
func (d *Status) UpdatePeerICEStateToDisconnected(receivedState State) error {
d.mux.Lock()
defer d.mux.Unlock()
peerState, ok := d.peers[receivedState.PubKey]
if !ok {
d.mux.Unlock()
return errors.New("peer doesn't exist")
}
@@ -552,18 +522,12 @@ func (d *Status) UpdatePeerICEStateToDisconnected(receivedState State) error {
d.peers[receivedState.PubKey] = peerState
notifyList := hasConnStatusChanged(oldState, receivedState.ConnStatus)
notifyRouter := hasStatusOrRelayedChange(oldState, receivedState.ConnStatus, oldIsRelayed, receivedState.Relayed)
routerSnapshot := d.snapshotRouterPeersLocked(receivedState.PubKey, notifyRouter)
numPeers := d.numOfPeers()
d.mux.Unlock()
if notifyList {
d.notifier.peerListChanged(numPeers)
if hasConnStatusChanged(oldState, receivedState.ConnStatus) {
d.notifyPeerListChanged()
}
if notifyRouter {
d.dispatchRouterPeers(receivedState.PubKey, routerSnapshot)
if hasStatusOrRelayedChange(oldState, receivedState.ConnStatus, oldIsRelayed, receivedState.Relayed) {
d.notifyPeerStateChangeListeners(receivedState.PubKey)
}
return nil
}
@@ -630,33 +594,17 @@ func (d *Status) UpdatePeerSSHHostKey(peerPubKey string, sshHostKey []byte) erro
// FinishPeerListModifications this event invoke the notification
func (d *Status) FinishPeerListModifications() {
d.mux.Lock()
defer d.mux.Unlock()
if !d.peerListChangedForNotification {
d.mux.Unlock()
return
}
d.peerListChangedForNotification = false
numPeers := d.numOfPeers()
d.notifyPeerListChanged()
// snapshot per-peer router state to deliver after the lock is released
type routerDispatch struct {
peerID string
snapshot map[string]RouterState
}
dispatches := make([]routerDispatch, 0, len(d.peers))
for key := range d.peers {
snapshot := d.snapshotRouterPeersLocked(key, true)
if snapshot != nil {
dispatches = append(dispatches, routerDispatch{peerID: key, snapshot: snapshot})
}
}
d.mux.Unlock()
d.notifier.peerListChanged(numPeers)
for _, rd := range dispatches {
d.dispatchRouterPeers(rd.peerID, rd.snapshot)
d.notifyPeerStateChangeListeners(key)
}
}
@@ -707,12 +655,10 @@ func (d *Status) GetLocalPeerState() LocalPeerState {
// UpdateLocalPeerState updates local peer status
func (d *Status) UpdateLocalPeerState(localPeerState LocalPeerState) {
d.mux.Lock()
d.localPeer = localPeerState
fqdn := d.localPeer.FQDN
ip := d.localPeer.IP
d.mux.Unlock()
defer d.mux.Unlock()
d.notifier.localAddressChanged(fqdn, ip)
d.localPeer = localPeerState
d.notifyAddressChanged()
}
// AddLocalPeerStateRoute adds a route to the local peer state
@@ -775,36 +721,30 @@ func (d *Status) CleanLocalPeerStateRoutes() {
// CleanLocalPeerState cleans local peer status
func (d *Status) CleanLocalPeerState() {
d.mux.Lock()
d.localPeer = LocalPeerState{}
fqdn := d.localPeer.FQDN
ip := d.localPeer.IP
d.mux.Unlock()
defer d.mux.Unlock()
d.notifier.localAddressChanged(fqdn, ip)
d.localPeer = LocalPeerState{}
d.notifyAddressChanged()
}
// MarkManagementDisconnected sets ManagementState to disconnected
func (d *Status) MarkManagementDisconnected(err error) {
d.mux.Lock()
defer d.mux.Unlock()
defer d.onConnectionChanged()
d.managementState = false
d.managementError = err
mgm := d.managementState
sig := d.signalState
d.mux.Unlock()
d.notifier.updateServerStates(mgm, sig)
}
// MarkManagementConnected sets ManagementState to connected
func (d *Status) MarkManagementConnected() {
d.mux.Lock()
defer d.mux.Unlock()
defer d.onConnectionChanged()
d.managementState = true
d.managementError = nil
mgm := d.managementState
sig := d.signalState
d.mux.Unlock()
d.notifier.updateServerStates(mgm, sig)
}
// UpdateSignalAddress update the address of the signal server
@@ -838,25 +778,21 @@ func (d *Status) UpdateLazyConnection(enabled bool) {
// MarkSignalDisconnected sets SignalState to disconnected
func (d *Status) MarkSignalDisconnected(err error) {
d.mux.Lock()
defer d.mux.Unlock()
defer d.onConnectionChanged()
d.signalState = false
d.signalError = err
mgm := d.managementState
sig := d.signalState
d.mux.Unlock()
d.notifier.updateServerStates(mgm, sig)
}
// MarkSignalConnected sets SignalState to connected
func (d *Status) MarkSignalConnected() {
d.mux.Lock()
defer d.mux.Unlock()
defer d.onConnectionChanged()
d.signalState = true
d.signalError = nil
mgm := d.managementState
sig := d.signalState
d.mux.Unlock()
d.notifier.updateServerStates(mgm, sig)
}
func (d *Status) UpdateRelayStates(relayResults []relay.ProbeResult) {
@@ -1076,17 +1012,18 @@ func (d *Status) RemoveConnectionListener() {
d.notifier.removeListener()
}
// snapshotRouterPeersLocked builds the RouterState map for a peer's subscribers.
// Caller MUST hold d.mux. Returns nil when there are no subscribers for peerID
// or when notify is false. The snapshot is consumed later by dispatchRouterPeers
// outside the lock so the channel send cannot stall any d.mux holder.
func (d *Status) snapshotRouterPeersLocked(peerID string, notify bool) map[string]RouterState {
if !notify {
return nil
}
if _, ok := d.changeNotify[peerID]; !ok {
return nil
func (d *Status) onConnectionChanged() {
d.notifier.updateServerStates(d.managementState, d.signalState)
}
// notifyPeerStateChangeListeners notifies route manager about the change in peer state
func (d *Status) notifyPeerStateChangeListeners(peerID string) {
subs, ok := d.changeNotify[peerID]
if !ok {
return
}
// collect the relevant data for router peers
routerPeers := make(map[string]RouterState, len(d.changeNotify))
for pid := range d.changeNotify {
s, ok := d.peers[pid]
@@ -1094,35 +1031,13 @@ func (d *Status) snapshotRouterPeersLocked(peerID string, notify bool) map[strin
log.Warnf("router peer not found in peers list: %s", pid)
continue
}
routerPeers[pid] = RouterState{
Status: s.ConnStatus,
Relayed: s.Relayed,
Latency: s.Latency,
}
}
return routerPeers
}
// dispatchRouterPeers delivers a previously snapshotted router-state map to
// the peer's subscribers. Caller MUST NOT hold d.mux. The method takes a
// fresh, short read of d.changeNotify under the lock to grab subscriber
// channels, then sends outside the lock so a slow consumer cannot block other
// d.mux holders. The send itself stays blocking (only short-circuited by the
// subscriber's context) so peer state transitions are not silently dropped.
func (d *Status) dispatchRouterPeers(peerID string, routerPeers map[string]RouterState) {
if routerPeers == nil {
return
}
d.mux.Lock()
subsMap, ok := d.changeNotify[peerID]
subs := make([]*StatusChangeSubscription, 0, len(subsMap))
if ok {
for _, sub := range subsMap {
subs = append(subs, sub)
}
}
d.mux.Unlock()
for _, sub := range subs {
select {
@@ -1132,6 +1047,14 @@ func (d *Status) dispatchRouterPeers(peerID string, routerPeers map[string]Route
}
}
func (d *Status) notifyPeerListChanged() {
d.notifier.peerListChanged(d.numOfPeers())
}
func (d *Status) notifyAddressChanged() {
d.notifier.localAddressChanged(d.localPeer.FQDN, d.localPeer.IP)
}
func (d *Status) numOfPeers() int {
return len(d.peers) + len(d.offlinePeers)
}

View File

@@ -5,7 +5,6 @@ import (
"fmt"
"net"
"net/netip"
"runtime"
"sync"
"time"
@@ -178,12 +177,7 @@ func getDefaultGateway() (gateway net.IP, localIP net.IP, err error) {
return nil, nil, err
}
dst := net.IPv4zero
if runtime.GOOS == "linux" {
// go-netroute v0.4.0 rejects unspecified destinations client-side on Linux.
dst = net.IPv4(0, 0, 0, 1)
}
_, gateway, localIP, err = router.Route(dst)
_, gateway, localIP, err = router.Route(net.IPv4zero)
if err != nil {
return nil, nil, err
}
@@ -202,12 +196,7 @@ func getDefaultGateway6() (gateway net.IP, localIP net.IP, err error) {
return nil, nil, err
}
dst := net.IPv6zero
if runtime.GOOS == "linux" {
// ::2
dst = net.IP{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2}
}
_, gateway, localIP, err = router.Route(dst)
_, gateway, localIP, err = router.Route(net.IPv6zero)
if err != nil {
return nil, nil, err
}

View File

@@ -342,22 +342,6 @@ func GetNextHop(ip netip.Addr) (Nexthop, error) {
if err != nil {
return Nexthop{}, fmt.Errorf("new netroute: %w", err)
}
// go-netroute v0.4.0 rejects unspecified destinations on Linux with a hard
// client-side check. Substitute the lowest non-loopback address so the
// lookup falls through to the default route (::1 / 127.0.0.1 would match
// loopback, ::/0.0.0.0 are unspec). BSD/Windows pass the query straight to
// the kernel and need no substitution.
if runtime.GOOS == "linux" && ip.IsUnspecified() {
if ip.Is6() {
// ::2
ip = netip.AddrFrom16([16]byte{15: 2})
} else {
// 0.0.0.1
ip = netip.AddrFrom4([4]byte{0, 0, 0, 1})
}
}
intf, gateway, preferredSrc, err := r.Route(ip.AsSlice())
if err != nil {
log.Debugf("Failed to get route for %s: %v", ip, err)

View File

@@ -354,13 +354,9 @@ func TestAddRouteToNonVPNIntf(t *testing.T) {
require.NoError(t, err, "Should be able to get IPv4 default route")
t.Logf("Initial IPv4 next hop: %s", initialNextHopV4)
if testCase.prefix.Addr().Is6() && !testCase.expectError {
ensureIPv6DefaultRoute(t)
}
initialNextHopV6, err := GetNextHop(netip.IPv6Unspecified())
if testCase.prefix.Addr().Is6() &&
initialNextHopV6.Intf != nil && strings.HasPrefix(initialNextHopV6.Intf.Name, "utun") {
(errors.Is(err, vars.ErrRouteNotFound) || initialNextHopV6.Intf != nil && strings.HasPrefix(initialNextHopV6.Intf.Name, "utun")) {
t.Skip("Skipping test as no ipv6 default route is available")
}
if err != nil && !errors.Is(err, vars.ErrRouteNotFound) {

View File

@@ -1,30 +0,0 @@
//go:build darwin || dragonfly || freebsd || netbsd || openbsd
package systemops
import (
"bytes"
"os/exec"
"testing"
)
// ensureIPv6DefaultRoute installs an IPv6 default route via the loopback
// interface so route lookups for global IPv6 prefixes resolve in environments
// without v6 connectivity. If a default already exists it is left alone.
func ensureIPv6DefaultRoute(t *testing.T) {
t.Helper()
out, err := exec.Command("route", "-6", "add", "default", "-iface", "lo0").CombinedOutput()
if err != nil {
// Existing default; nothing to install or clean up.
if bytes.Contains(out, []byte("route already in table")) {
return
}
t.Skipf("install IPv6 fallback default route: %v: %s", err, out)
}
t.Cleanup(func() {
if out, err := exec.Command("route", "-6", "delete", "default").CombinedOutput(); err != nil {
t.Logf("delete IPv6 fallback default route: %v: %s", err, out)
}
})
}

View File

@@ -1,41 +0,0 @@
//go:build linux && !android
package systemops
import (
"errors"
"net"
"syscall"
"testing"
"github.com/stretchr/testify/require"
"github.com/vishvananda/netlink"
)
// ensureIPv6DefaultRoute installs a low-preference IPv6 default route via the
// loopback interface so route lookups for global IPv6 prefixes resolve in
// environments without v6 connectivity. Any pre-existing default route wins
// because of its lower metric.
func ensureIPv6DefaultRoute(t *testing.T) {
t.Helper()
lo, err := netlink.LinkByName("lo")
require.NoError(t, err, "find loopback interface")
route := &netlink.Route{
Dst: &net.IPNet{IP: net.IPv6zero, Mask: net.CIDRMask(0, 128)},
LinkIndex: lo.Attrs().Index,
Priority: 1 << 20,
}
if err := netlink.RouteAdd(route); err != nil {
if errors.Is(err, syscall.EEXIST) {
return
}
t.Skipf("install IPv6 fallback default route: %v", err)
}
t.Cleanup(func() {
if err := netlink.RouteDel(route); err != nil && !errors.Is(err, syscall.ESRCH) {
t.Logf("delete IPv6 fallback default route: %v", err)
}
})
}

View File

@@ -1,34 +0,0 @@
//go:build windows
package systemops
import (
"bytes"
"os/exec"
"testing"
)
const loopbackIfaceWindows = "Loopback Pseudo-Interface 1"
// ensureIPv6DefaultRoute installs an IPv6 default route via the loopback
// interface so route lookups for global IPv6 prefixes resolve in environments
// without v6 connectivity. If a default already exists it is left alone.
func ensureIPv6DefaultRoute(t *testing.T) {
t.Helper()
script := `New-NetRoute -DestinationPrefix "::/0" -InterfaceAlias "` + loopbackIfaceWindows + `" -RouteMetric 9999 -PolicyStore ActiveStore -ErrorAction Stop`
out, err := exec.Command("powershell", "-Command", script).CombinedOutput()
if err != nil {
// Existing default; nothing to install or clean up.
if bytes.Contains(out, []byte("already exists")) {
return
}
t.Skipf("install IPv6 fallback default route: %v: %s", err, out)
}
t.Cleanup(func() {
script := `Remove-NetRoute -DestinationPrefix "::/0" -InterfaceAlias "` + loopbackIfaceWindows + `" -Confirm:$false -ErrorAction Stop`
if out, err := exec.Command("powershell", "-Command", script).CombinedOutput(); err != nil {
t.Logf("delete IPv6 fallback default route: %v: %s", err, out)
}
})
}

View File

@@ -6,6 +6,7 @@ import (
"fmt"
"net"
"runtime"
"time"
log "github.com/sirupsen/logrus"
@@ -27,10 +28,6 @@ func NewWGIfaceMonitor() *WGIfaceMonitor {
// Start begins monitoring the WireGuard interface.
// It relies on the provided context cancellation to stop.
//
// On Linux the watcher is event-driven (RTNLGRP_LINK netlink subscription)
// to avoid the allocation churn of repeatedly dumping the kernel link
// table; on other platforms it falls back to a low-frequency poll.
func (m *WGIfaceMonitor) Start(ctx context.Context, ifaceName string) (shouldRestart bool, err error) {
defer close(m.done)
@@ -59,7 +56,31 @@ func (m *WGIfaceMonitor) Start(ctx context.Context, ifaceName string) (shouldRes
log.Infof("Interface monitor: watching %s (index: %d)", ifaceName, expectedIndex)
return watchInterface(ctx, ifaceName, expectedIndex)
ticker := time.NewTicker(2 * time.Second)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
log.Infof("Interface monitor: stopped for %s", ifaceName)
return false, fmt.Errorf("wg interface monitor stopped: %v", ctx.Err())
case <-ticker.C:
currentIndex, err := getInterfaceIndex(ifaceName)
if err != nil {
// Interface was deleted
log.Infof("Interface monitor: %s deleted", ifaceName)
return true, fmt.Errorf("interface %s deleted: %w", ifaceName, err)
}
// Check if interface index changed (interface was recreated)
if currentIndex != expectedIndex {
log.Infof("Interface monitor: %s recreated (index changed from %d to %d), restarting engine",
ifaceName, expectedIndex, currentIndex)
return true, nil
}
}
}
}
// getInterfaceIndex returns the index of a network interface by name.

View File

@@ -1,134 +0,0 @@
//go:build linux
package internal
import (
"context"
"fmt"
"syscall"
log "github.com/sirupsen/logrus"
"github.com/vishvananda/netlink"
)
// watchInterface uses an RTNLGRP_LINK netlink subscription to detect
// deletion or recreation of the WireGuard interface.
//
// The previous implementation polled net.InterfaceByName every 2 s, which
// on Linux issues syscall.NetlinkRIB(RTM_GETLINK, ...) and dumps the
// entire kernel link table on every call. On hosts with many veth
// interfaces (containers, bridges) the resulting allocation churn was on
// the order of ~1 GB/day from this single ticker, which on small ARM
// hosts manifested as a slow RSS climb (see netbirdio/netbird#3678).
//
// The event-driven version below allocates only when the kernel actually
// publishes a link event for the tracked interface — typically zero
// allocations between events.
func watchInterface(ctx context.Context, ifaceName string, expectedIndex int) (bool, error) {
done := make(chan struct{})
defer close(done)
// Buffer the channel to absorb event bursts (e.g. when many veth
// pairs are created/destroyed at once by container runtimes).
linkChan := make(chan netlink.LinkUpdate, 32)
if err := netlink.LinkSubscribe(linkChan, done); err != nil {
// Return shouldRestart=true so the engine recovers monitoring
// via triggerClientRestart instead of silently losing it for
// the rest of the process lifetime.
return true, fmt.Errorf("subscribe to link updates: %w", err)
}
// Race window: the interface could have been deleted (or recreated)
// between the initial getInterfaceIndex() in Start and LinkSubscribe
// completing its handshake with the kernel. Re-check explicitly so we
// do not block forever waiting for an event that already fired.
if currentIndex, err := getInterfaceIndex(ifaceName); err != nil {
log.Infof("Interface monitor: %s deleted before subscription completed", ifaceName)
return true, fmt.Errorf("interface %s deleted: %w", ifaceName, err)
} else if currentIndex != expectedIndex {
log.Infof("Interface monitor: %s recreated (index changed from %d to %d) before subscription completed",
ifaceName, expectedIndex, currentIndex)
return true, nil
}
for {
select {
case <-ctx.Done():
log.Infof("Interface monitor: stopped for %s", ifaceName)
return false, fmt.Errorf("wg interface monitor stopped: %w", ctx.Err())
case update, ok := <-linkChan:
if !ok {
// The vishvananda/netlink subscription goroutine closes
// the channel on receive errors. Signal the engine to
// restart so monitoring is re-established instead of
// silently ending.
log.Warnf("Interface monitor: link subscription channel closed unexpectedly for %s", ifaceName)
return true, fmt.Errorf("link subscription channel closed unexpectedly")
}
if restart, err := inspectLinkEvent(update, ifaceName, expectedIndex); restart {
return true, err
}
}
}
}
// inspectLinkEvent classifies a single netlink link update against the
// tracked WireGuard interface. It returns (true, err) when the engine
// should restart monitoring; (false, nil) means the event is unrelated
// and the caller should keep waiting.
//
// The error component, when non-nil, describes the kernel-side reason
// (deletion or rename); the recreation case returns (true, nil) since
// no error condition is reported.
func inspectLinkEvent(update netlink.LinkUpdate, ifaceName string, expectedIndex int) (bool, error) {
eventIndex := int(update.Index)
eventName := ""
if attrs := update.Attrs(); attrs != nil {
eventName = attrs.Name
}
switch update.Header.Type {
case syscall.RTM_DELLINK:
return inspectDelLink(eventIndex, ifaceName, expectedIndex)
case syscall.RTM_NEWLINK:
return inspectNewLink(eventIndex, eventName, ifaceName, expectedIndex)
}
return false, nil
}
// inspectDelLink reports a restart when an RTM_DELLINK arrives for the
// tracked interface index.
func inspectDelLink(eventIndex int, ifaceName string, expectedIndex int) (bool, error) {
if eventIndex != expectedIndex {
return false, nil
}
log.Infof("Interface monitor: %s deleted", ifaceName)
return true, fmt.Errorf("interface %s deleted", ifaceName)
}
// inspectNewLink reports a restart when an RTM_NEWLINK either:
//
// 1. Introduces a link with our name at a different index (recreation
// after a delete), or
//
// 2. Reports a link still at our index but with a different name
// (in-place rename). The previous polling implementation caught
// this implicitly because net.InterfaceByName(ifaceName) would
// start failing; the event-driven version has to test it.
//
// Same name + same index is just a flag/state change on the existing
// interface and is ignored.
func inspectNewLink(eventIndex int, eventName, ifaceName string, expectedIndex int) (bool, error) {
if eventName == ifaceName && eventIndex != expectedIndex {
log.Infof("Interface monitor: %s recreated (index changed from %d to %d), restarting engine",
ifaceName, expectedIndex, eventIndex)
return true, nil
}
if eventIndex == expectedIndex && eventName != "" && eventName != ifaceName {
log.Infof("Interface monitor: %s renamed to %s (index %d), restarting engine",
ifaceName, eventName, expectedIndex)
return true, fmt.Errorf("interface %s renamed to %s", ifaceName, eventName)
}
return false, nil
}

View File

@@ -1,56 +0,0 @@
//go:build !linux
package internal
import (
"context"
"fmt"
"time"
log "github.com/sirupsen/logrus"
)
// watchInterface polls net.InterfaceByName at a fixed interval to detect
// deletion or recreation of the WireGuard interface.
//
// This is the fallback used on non-Linux desktop and server platforms
// (darwin, windows, freebsd). It is also compiled on android and ios so
// the package builds on every supported GOOS, but it is never reached
// at runtime there because Start() in wg_iface_monitor.go exits early
// on mobile platforms.
//
// The Linux build (see wg_iface_monitor_linux.go) uses an event-driven
// RTNLGRP_LINK netlink subscription instead, because on Linux
// net.InterfaceByName issues syscall.NetlinkRIB(RTM_GETLINK, ...) which
// dumps the entire kernel link table on every call and produces
// significant allocation churn (netbirdio/netbird#3678).
//
// Windows is also reported in #3678 as affected by RSS climb. A future
// follow-up could implement an event-driven watcher there using
// NotifyIpInterfaceChange from iphlpapi.
func watchInterface(ctx context.Context, ifaceName string, expectedIndex int) (bool, error) {
ticker := time.NewTicker(2 * time.Second)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
log.Infof("Interface monitor: stopped for %s", ifaceName)
return false, fmt.Errorf("wg interface monitor stopped: %w", ctx.Err())
case <-ticker.C:
currentIndex, err := getInterfaceIndex(ifaceName)
if err != nil {
// Interface was deleted
log.Infof("Interface monitor: %s deleted", ifaceName)
return true, fmt.Errorf("interface %s deleted: %w", ifaceName, err)
}
// Check if interface index changed (interface was recreated)
if currentIndex != expectedIndex {
log.Infof("Interface monitor: %s recreated (index changed from %d to %d), restarting engine",
ifaceName, expectedIndex, currentIndex)
return true, nil
}
}
}
}

View File

@@ -224,20 +224,15 @@ func (m *Manager) buildHostPatterns(peer PeerSSHInfo) []string {
func (m *Manager) writeSSHConfig(sshConfig string) error {
sshConfigPath := filepath.Join(m.sshConfigDir, m.sshConfigFile)
sshConfigPathTmp := sshConfigPath + ".tmp"
if err := os.MkdirAll(m.sshConfigDir, 0755); err != nil {
return fmt.Errorf("create SSH config directory %s: %w", m.sshConfigDir, err)
}
if err := writeFileWithTimeout(sshConfigPathTmp, []byte(sshConfig), 0644); err != nil {
if err := writeFileWithTimeout(sshConfigPath, []byte(sshConfig), 0644); err != nil {
return fmt.Errorf("write SSH config file %s: %w", sshConfigPath, err)
}
if err := os.Rename(sshConfigPathTmp, sshConfigPath); err != nil {
return fmt.Errorf("rename ssh config %s -> %s: %w", sshConfigPathTmp, sshConfigPath, err)
}
log.Infof("Created NetBird SSH client config: %s", sshConfigPath)
return nil
}

View File

@@ -13,9 +13,11 @@ import (
"github.com/cenkalti/backoff/v4"
log "github.com/sirupsen/logrus"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/keepalive"
"google.golang.org/grpc/status"
nbgrpc "github.com/netbirdio/netbird/client/grpc"
"github.com/netbirdio/netbird/flow/proto"
@@ -299,11 +301,12 @@ func defaultBackoff(ctx context.Context, interval time.Duration) backoff.BackOff
}, ctx)
}
// isContextDone reports whether the local context has been canceled or has
// exceeded its deadline. It deliberately does not inspect gRPC status codes:
// a server- or proxy-sent codes.Canceled / codes.DeadlineExceeded must not
// short-circuit our retry loop, since retrying is the correct response when
// the local context is still alive.
func isContextDone(err error) bool {
return errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded)
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
return true
}
if s, ok := status.FromError(err); ok {
return s.Code() == codes.Canceled || s.Code() == codes.DeadlineExceeded
}
return false
}

22
go.mod
View File

@@ -17,8 +17,8 @@ require (
github.com/spf13/cobra v1.10.1
github.com/spf13/pflag v1.0.9
github.com/vishvananda/netlink v1.3.1
golang.org/x/crypto v0.50.0
golang.org/x/sys v0.43.0
golang.org/x/crypto v0.49.0
golang.org/x/sys v0.42.0
golang.zx2c4.com/wireguard v0.0.0-20230704135630-469159ecf7d1
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6
golang.zx2c4.com/wireguard/windows v0.5.3
@@ -68,11 +68,10 @@ require (
github.com/jackc/pgx/v5 v5.5.5
github.com/libdns/route53 v1.5.0
github.com/libp2p/go-nat v0.2.0
github.com/libp2p/go-netroute v0.4.0
github.com/libp2p/go-netroute v0.2.1
github.com/lrh3321/ipset-go v0.0.0-20250619021614-54a0a98ace81
github.com/mdlayher/socket v0.5.1
github.com/mdp/qrterminal/v3 v3.2.1
github.com/miekg/dns v1.1.72
github.com/miekg/dns v1.1.59
github.com/mitchellh/hashstructure/v2 v2.0.2
github.com/netbirdio/management-integrations/integrations v0.0.0-20260416123949-2355d972be42
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45
@@ -118,11 +117,11 @@ require (
goauthentik.io/api/v3 v3.2023051.3
golang.org/x/exp v0.0.0-20250620022241-b7579e27df2b
golang.org/x/mobile v0.0.0-20251113184115-a159579294ab
golang.org/x/mod v0.34.0
golang.org/x/net v0.53.0
golang.org/x/mod v0.33.0
golang.org/x/net v0.52.0
golang.org/x/oauth2 v0.36.0
golang.org/x/sync v0.20.0
golang.org/x/term v0.42.0
golang.org/x/term v0.41.0
golang.org/x/time v0.15.0
google.golang.org/api v0.276.0
gopkg.in/yaml.v3 v3.0.1
@@ -303,14 +302,13 @@ require (
go.uber.org/multierr v1.11.0 // indirect
go.yaml.in/yaml/v2 v2.4.3 // indirect
golang.org/x/image v0.33.0 // indirect
golang.org/x/text v0.36.0 // indirect
golang.org/x/tools v0.43.0 // indirect
golang.org/x/text v0.35.0 // indirect
golang.org/x/tools v0.42.0 // indirect
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20260401024825-9d38bb4040a9 // indirect
gopkg.in/square/go-jose.v2 v2.6.0 // indirect
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 // indirect
gopkg.in/yaml.v2 v2.4.0 // indirect
rsc.io/qr v0.2.0 // indirect
)
replace github.com/kardianos/service => github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502
@@ -323,6 +321,8 @@ replace github.com/cloudflare/circl => github.com/cunicu/circl v0.0.0-2023080111
replace github.com/pion/ice/v4 => github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51
replace github.com/libp2p/go-netroute => github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944
replace github.com/dexidp/dex => github.com/netbirdio/dex v0.244.0
replace github.com/mailru/easyjson => github.com/netbirdio/easyjson v0.9.0

40
go.sum
View File

@@ -395,8 +395,6 @@ github.com/libdns/route53 v1.5.0 h1:2SKdpPFl/qgWsXQvsLNJJAoX7rSxlk7zgoL4jnWdXVA=
github.com/libdns/route53 v1.5.0/go.mod h1:joT4hKmaTNKHEwb7GmZ65eoDz1whTu7KKYPS8ZqIh6Q=
github.com/libp2p/go-nat v0.2.0 h1:Tyz+bUFAYqGyJ/ppPPymMGbIgNRH+WqC5QrT5fKrrGk=
github.com/libp2p/go-nat v0.2.0/go.mod h1:3MJr+GRpRkyT65EpVPBstXLvOlAPzUVlG6Pwg9ohLJk=
github.com/libp2p/go-netroute v0.4.0 h1:sZZx9hyANYUx9PZyqcgE/E1GUG3iEtTZHUEvdtXT7/Q=
github.com/libp2p/go-netroute v0.4.0/go.mod h1:Nkd5ShYgSMS5MUKy/MU2T57xFoOKvvLR92Lic48LEyA=
github.com/lrh3321/ipset-go v0.0.0-20250619021614-54a0a98ace81 h1:J56rFEfUTFT9j9CiRXhi1r8lUJ4W5idG3CiaBZGojNU=
github.com/lrh3321/ipset-go v0.0.0-20250619021614-54a0a98ace81/go.mod h1:RD8ML/YdXctQ7qbcizZkw5mZ6l8Ogrl1dodBzVJduwI=
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0/go.mod h1:zJYVVT2jmtg6P3p1VtQj7WsuWi/y4VnjVBn7F8KPB3I=
@@ -417,12 +415,10 @@ github.com/mdlayher/netlink v1.7.3-0.20250113171957-fbb4dce95f42 h1:A1Cq6Ysb0GM0
github.com/mdlayher/netlink v1.7.3-0.20250113171957-fbb4dce95f42/go.mod h1:BB4YCPDOzfy7FniQ/lxuYQ3dgmM2cZumHbK8RpTjN2o=
github.com/mdlayher/socket v0.5.1 h1:VZaqt6RkGkt2OE9l3GcC6nZkqD3xKeQLyfleW/uBcos=
github.com/mdlayher/socket v0.5.1/go.mod h1:TjPLHI1UgwEv5J1B5q0zTZq12A/6H7nKmtTanQE37IQ=
github.com/mdp/qrterminal/v3 v3.2.1 h1:6+yQjiiOsSuXT5n9/m60E54vdgFsw0zhADHhHLrFet4=
github.com/mdp/qrterminal/v3 v3.2.1/go.mod h1:jOTmXvnBsMy5xqLniO0R++Jmjs2sTm9dFSuQ5kpz/SU=
github.com/mholt/acmez/v2 v2.0.1 h1:3/3N0u1pLjMK4sNEAFSI+bcvzbPhRpY383sy1kLHJ6k=
github.com/mholt/acmez/v2 v2.0.1/go.mod h1:fX4c9r5jYwMyMsC+7tkYRxHibkOTgta5DIFGoe67e1U=
github.com/miekg/dns v1.1.72 h1:vhmr+TF2A3tuoGNkLDFK9zi36F2LS+hKTRW0Uf8kbzI=
github.com/miekg/dns v1.1.72/go.mod h1:+EuEPhdHOsfk6Wk5TT2CzssZdqkmFhf8r+aVyDEToIs=
github.com/miekg/dns v1.1.59 h1:C9EXc/UToRwKLhK5wKU/I4QVsBUc8kE6MkHBkeypWZs=
github.com/miekg/dns v1.1.59/go.mod h1:nZpewl5p6IvctfgrckopVx2OlSEHPRO/U4SYkRklrEk=
github.com/mikioh/ipaddr v0.0.0-20190404000644-d465c8ab6721 h1:RlZweED6sbSArvlE924+mUcZuXKLBHA35U7LN621Bws=
github.com/mikioh/ipaddr v0.0.0-20190404000644-d465c8ab6721/go.mod h1:Ickgr2WtCLZ2MDGd4Gr0geeCH5HybhRJbonOgQpvSxc=
github.com/mitchellh/copystructure v1.2.0 h1:vpKXTN4ewci03Vljg/q9QvCGUDttBOGBIa15WveJJGw=
@@ -455,6 +451,8 @@ github.com/netbirdio/dex v0.244.0 h1:1GOvi8wnXYassnKGildzNqRHq0RbcfEUw7LKYpKIN7U
github.com/netbirdio/dex v0.244.0/go.mod h1:STGInJhPcAflrHmDO7vyit2kSq03PdL+8zQPoGALtcU=
github.com/netbirdio/easyjson v0.9.0 h1:6Nw2lghSVuy8RSkAYDhDv1thBVEmfVbKZnV7T7Z6Aus=
github.com/netbirdio/easyjson v0.9.0/go.mod h1:1+xMtQp2MRNVL/V1bOzuP3aP8VNwRW55fQUto+XFtTU=
github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944 h1:TDtJKmM6Sf8uYFx/dMeqNOL90KUoRscdfpFZ3Im89uk=
github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944/go.mod h1:sHA6TRxjQ6RLbnI+3R4DZo2Eseg/iKiPRfNmcuNySVQ=
github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51 h1:Ov4qdafATOgGMB1wbSuh+0aAHcwz9hdvB6VZjh1mVMI=
github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51/go.mod h1:ZSIbPdBn5hePO8CpF1PekH2SfpTxg1PDhEwtbqZS7R8=
github.com/netbirdio/management-integrations/integrations v0.0.0-20260416123949-2355d972be42 h1:F3zS5fT9xzD1OFLfcdAE+3FfyiwjGukF1hvj0jErgs8=
@@ -711,8 +709,8 @@ golang.org/x/crypto v0.18.0/go.mod h1:R0j02AL6hcrfOiy9T4ZYp/rcWeMxM3L6QYxlOuEG1m
golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU=
golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8=
golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk=
golang.org/x/crypto v0.50.0 h1:zO47/JPrL6vsNkINmLoo/PH1gcxpls50DNogFvB5ZGI=
golang.org/x/crypto v0.50.0/go.mod h1:3muZ7vA7PBCE6xgPX7nkzzjiUq87kRItoJQM1Yo8S+Q=
golang.org/x/crypto v0.49.0 h1:+Ng2ULVvLHnJ/ZFEq4KdcDd/cfjrrjjNSXNzxg0Y4U4=
golang.org/x/crypto v0.49.0/go.mod h1:ErX4dUh2UM+CFYiXZRTcMpEcN8b/1gxEuv3nODoYtCA=
golang.org/x/exp v0.0.0-20250620022241-b7579e27df2b h1:M2rDM6z3Fhozi9O7NWsxAkg/yqS/lQJ6PmkyIV3YP+o=
golang.org/x/exp v0.0.0-20250620022241-b7579e27df2b/go.mod h1:3//PLf8L/X+8b4vuAfHzxeRUl04Adcb341+IGKfnqS8=
golang.org/x/image v0.33.0 h1:LXRZRnv1+zGd5XBUVRFmYEphyyKJjQjCRiOuAP3sZfQ=
@@ -729,8 +727,8 @@ golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
golang.org/x/mod v0.12.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
golang.org/x/mod v0.15.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
golang.org/x/mod v0.17.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
golang.org/x/mod v0.34.0 h1:xIHgNUUnW6sYkcM5Jleh05DvLOtwc6RitGHbDk4akRI=
golang.org/x/mod v0.34.0/go.mod h1:ykgH52iCZe79kzLLMhyCUzhMci+nQj+0XkbXpNYtVjY=
golang.org/x/mod v0.33.0 h1:tHFzIWbBifEmbwtGz65eaWyGiGZatSrT9prnU8DbVL8=
golang.org/x/mod v0.33.0/go.mod h1:swjeQEj+6r7fODbD2cqrnje9PnziFuw4bmLbBZFrQ5w=
golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190603091049-60506f45cf65/go.mod h1:HSz+uSET+XFnRR8LxR5pz3Of3rY3CfYBVs4xY44aLks=
@@ -749,8 +747,8 @@ golang.org/x/net v0.15.0/go.mod h1:idbUs1IY1+zTqbi8yxTbhexhEEk5ur9LInksu6HrEpk=
golang.org/x/net v0.20.0/go.mod h1:z8BVo6PvndSri0LbOE3hAn0apkU+1YvI6E70E9jsnvY=
golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44=
golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM=
golang.org/x/net v0.53.0 h1:d+qAbo5L0orcWAr0a9JweQpjXF19LMXJE8Ey7hwOdUA=
golang.org/x/net v0.53.0/go.mod h1:JvMuJH7rrdiCfbeHoo3fCQU24Lf5JJwT9W3sJFulfgs=
golang.org/x/net v0.52.0 h1:He/TN1l0e4mmR3QqHMT2Xab3Aj3L9qjbhRm78/6jrW0=
golang.org/x/net v0.52.0/go.mod h1:R1MAz7uMZxVMualyPXb+VaqGSa3LIaUqk0eEt3w36Sw=
golang.org/x/oauth2 v0.8.0/go.mod h1:yr7u4HXZRm1R1kBWqr/xKNqewf0plRYoB7sla+BCIXE=
golang.org/x/oauth2 v0.36.0 h1:peZ/1z27fi9hUOFCAZaHyrpWG5lwe0RJEEEeH0ThlIs=
golang.org/x/oauth2 v0.36.0/go.mod h1:YDBUJMTkDnJS+A4BP4eZBjCqtokkg1hODuPjwiGPO7Q=
@@ -801,8 +799,8 @@ golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.19.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.43.0 h1:Rlag2XtaFTxp19wS8MXlJwTvoh8ArU6ezoyFsMyCTNI=
golang.org/x/sys v0.43.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw=
golang.org/x/sys v0.42.0 h1:omrd2nAlyT5ESRdCLYdm3+fMfNFE/+Rf4bDIQImRJeo=
golang.org/x/sys v0.42.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw=
golang.org/x/telemetry v0.0.0-20240228155512-f48c80bd79b2/go.mod h1:TeRTkGYfJXctD9OcfyVLyj2J3IxLnKwHJR8f4D8a3YE=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
@@ -815,8 +813,8 @@ golang.org/x/term v0.16.0/go.mod h1:yn7UURbUtPyrVJPGPq404EukNFxcm/foM+bV/bfcDsY=
golang.org/x/term v0.17.0/go.mod h1:lLRBjIVuehSbZlaOtGMbcMncT+aqLLLmKrsjNrUguwk=
golang.org/x/term v0.20.0/go.mod h1:8UkIAJTvZgivsXaD6/pH6U9ecQzZ45awqEOzuCvwpFY=
golang.org/x/term v0.27.0/go.mod h1:iMsnZpn0cago0GOrHO2+Y7u7JPn5AylBrcoWkElMTSM=
golang.org/x/term v0.42.0 h1:UiKe+zDFmJobeJ5ggPwOshJIVt6/Ft0rcfrXZDLWAWY=
golang.org/x/term v0.42.0/go.mod h1:Dq/D+snpsbazcBG5+F9Q1n2rXV8Ma+71xEjTRufARgY=
golang.org/x/term v0.41.0 h1:QCgPso/Q3RTJx2Th4bDLqML4W6iJiaXFq2/ftQF13YU=
golang.org/x/term v0.41.0/go.mod h1:3pfBgksrReYfZ5lvYM0kSO0LIkAl4Yl2bXOkKP7Ec2A=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
@@ -828,8 +826,8 @@ golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ=
golang.org/x/text v0.36.0 h1:JfKh3XmcRPqZPKevfXVpI1wXPTqbkE5f7JA92a55Yxg=
golang.org/x/text v0.36.0/go.mod h1:NIdBknypM8iqVmPiuco0Dh6P5Jcdk8lJL0CUebqK164=
golang.org/x/text v0.35.0 h1:JOVx6vVDFokkpaq1AEptVzLTpDe9KGpj5tR4/X+ybL8=
golang.org/x/text v0.35.0/go.mod h1:khi/HExzZJ2pGnjenulevKNX1W67CUy0AsXcNubPGCA=
golang.org/x/time v0.15.0 h1:bbrp8t3bGUeFOx08pvsMYRTCVSMk89u4tKbNOZbp88U=
golang.org/x/time v0.15.0/go.mod h1:Y4YMaQmXwGQZoFaVFk4YpCt4FLQMYKZe9oeV/f4MSno=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
@@ -843,8 +841,8 @@ golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc
golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU=
golang.org/x/tools v0.13.0/go.mod h1:HvlwmtVNQAhOuCjW7xxvovg8wbNq7LwfXh/k7wXUl58=
golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d/go.mod h1:aiJjzUbINMkxbQROHiO6hDPo2LHcIPhhQsa9DLh0yGk=
golang.org/x/tools v0.43.0 h1:12BdW9CeB3Z+J/I/wj34VMl8X+fEXBxVR90JeMX5E7s=
golang.org/x/tools v0.43.0/go.mod h1:uHkMso649BX2cZK6+RpuIPXS3ho2hZo4FVwfoy1vIk0=
golang.org/x/tools v0.42.0 h1:uNgphsn75Tdz5Ji2q36v/nsFSfR/9BRFvqhGBaJGd5k=
golang.org/x/tools v0.42.0/go.mod h1:Ma6lCIwGZvHK6XtgbswSoWroEkhugApmsXyrUmBhfr0=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
@@ -917,5 +915,3 @@ gotest.tools/v3 v3.5.1 h1:EENdUnS3pdur5nybKYIh2Vfgc8IUNBjxDPSjtiJcOzU=
gotest.tools/v3 v3.5.1/go.mod h1:isy3WKz7GK6uNw/sbHzfKBLvlvXwUyV06n6brMxxopU=
gvisor.dev/gvisor v0.0.0-20260219192049-0f2374377e89 h1:mGJaeA61P8dEHTqdvAgc70ZIV3QoUoJcXCRyyjO26OA=
gvisor.dev/gvisor v0.0.0-20260219192049-0f2374377e89/go.mod h1:QkHjoMIBaYtpVufgwv3keYAbln78mBoCuShZrPrer1Q=
rsc.io/qr v0.2.0 h1:6vBLea5/NRMVTz8V66gipeLycZMl/+UlFmk8DvqQ6WY=
rsc.io/qr v0.2.0/go.mod h1:IF+uZjkb9fqyeF/4tlBoynqmQxUoPfWEKh921coOuXs=

View File

@@ -45,6 +45,7 @@ type Controller struct {
accountUpdateLocks sync.Map
sendAccountUpdateLocks sync.Map
affectedPeerUpdateLocks sync.Map
updateAccountPeersBufferInterval atomic.Int64
// dnsDomain is used for peer resolution. This is appended to the peer's name
dnsDomain string
@@ -63,6 +64,13 @@ type bufferUpdate struct {
update atomic.Bool
}
type bufferAffectedUpdate struct {
sendMu sync.Mutex
dataMu sync.Mutex
next *time.Timer
peerIDs map[string]struct{}
}
var _ network_map.Controller = (*Controller)(nil)
func NewController(ctx context.Context, store store.Store, metrics telemetry.AppMetrics, peersUpdateManager network_map.PeersUpdateManager, requestBuffer account.RequestBuffer, integratedPeerValidator integrated_validator.IntegratedValidator, settingsManager settings.Manager, dnsDomain string, proxyController port_forwarding.Controller, ephemeralPeersManager ephemeral.Manager, config *config.Config) *Controller {
@@ -264,6 +272,133 @@ func (c *Controller) UpdateAccountPeers(ctx context.Context, accountID string, r
return c.sendUpdateAccountPeers(ctx, accountID)
}
// UpdateAffectedPeers updates only the specified peers that belong to an account.
func (c *Controller) UpdateAffectedPeers(ctx context.Context, accountID string, peerIDs []string) error {
if len(peerIDs) == 0 {
return nil
}
return c.sendUpdateForAffectedPeers(ctx, accountID, peerIDs)
}
func (c *Controller) sendUpdateForAffectedPeers(ctx context.Context, accountID string, peerIDs []string) error {
log.WithContext(ctx).Tracef("updating %d affected peers for account %s from %s", len(peerIDs), accountID, util.GetCallerName())
affected := make(map[string]struct{}, len(peerIDs))
for _, id := range peerIDs {
affected[id] = struct{}{}
}
hasConnected := false
for _, id := range peerIDs {
if c.peersUpdateManager.HasChannel(id) {
hasConnected = true
break
}
}
if !hasConnected {
return nil
}
account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
if err != nil {
return fmt.Errorf("failed to get account: %v", err)
}
globalStart := time.Now()
var peersToUpdate []*nbpeer.Peer
for _, peer := range account.Peers {
if _, ok := affected[peer.ID]; ok && c.peersUpdateManager.HasChannel(peer.ID) {
peersToUpdate = append(peersToUpdate, peer)
}
}
if len(peersToUpdate) == 0 {
return nil
}
approvedPeersMap, err := c.integratedPeerValidator.GetValidatedPeers(ctx, account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra)
if err != nil {
return fmt.Errorf("failed to get validate peers: %v", err)
}
var wg sync.WaitGroup
semaphore := make(chan struct{}, 10)
account.InjectProxyPolicies(ctx)
dnsCache := &cache.DNSConfigCache{}
dnsDomain := c.GetDNSDomain(account.Settings)
peersCustomZone := account.GetPeersCustomZone(ctx, dnsDomain)
resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap()
groupIDToUserIDs := account.GetActiveGroupUsers()
proxyNetworkMaps, err := c.proxyController.GetProxyNetworkMapsAll(ctx, accountID, account.Peers)
if err != nil {
log.WithContext(ctx).Errorf("failed to get proxy network maps: %v", err)
return fmt.Errorf("failed to get proxy network maps: %v", err)
}
extraSetting, err := c.settingsManager.GetExtraSettings(ctx, accountID)
if err != nil {
return fmt.Errorf("failed to get flow enabled status: %v", err)
}
dnsFwdPort := computeForwarderPort(maps.Values(account.Peers), network_map.DnsForwarderPortMinVersion)
accountZones, err := c.repo.GetAccountZones(ctx, account.Id)
if err != nil {
log.WithContext(ctx).Errorf("failed to get account zones: %v", err)
return fmt.Errorf("failed to get account zones: %v", err)
}
for _, peer := range peersToUpdate {
wg.Add(1)
semaphore <- struct{}{}
go func(p *nbpeer.Peer) {
defer wg.Done()
defer func() { <-semaphore }()
start := time.Now()
postureChecks, err := c.getPeerPostureChecks(account, p.ID)
if err != nil {
log.WithContext(ctx).Debugf("failed to get posture checks for peer %s: %v", p.ID, err)
return
}
c.metrics.CountCalcPostureChecksDuration(time.Since(start))
start = time.Now()
remotePeerNetworkMap := account.GetPeerNetworkMapFromComponents(ctx, p.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
c.metrics.CountCalcPeerNetworkMapDuration(time.Since(start))
proxyNetworkMap, ok := proxyNetworkMaps[p.ID]
if ok {
remotePeerNetworkMap.Merge(proxyNetworkMap)
}
peerGroups := account.GetPeerGroups(p.ID)
start = time.Now()
update := grpc.ToSyncResponse(ctx, nil, c.config.HttpConfig, c.config.DeviceAuthorizationFlow, p, nil, nil, remotePeerNetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSetting, maps.Keys(peerGroups), dnsFwdPort)
c.metrics.CountToSyncResponseDuration(time.Since(start))
c.peersUpdateManager.SendUpdate(ctx, p.ID, &network_map.UpdateMessage{
Update: update,
MessageType: network_map.MessageTypeNetworkMap,
})
}(peer)
}
wg.Wait()
if c.accountManagerMetrics != nil {
c.accountManagerMetrics.CountUpdateAccountPeersDuration(time.Since(globalStart))
}
return nil
}
func (c *Controller) UpdateAccountPeer(ctx context.Context, accountId string, peerId string) error {
if !c.peersUpdateManager.HasChannel(peerId) {
return fmt.Errorf("peer %s doesn't have a channel, skipping network map update", peerId)
@@ -372,6 +507,96 @@ func (c *Controller) BufferUpdateAccountPeers(ctx context.Context, accountID str
return nil
}
// BufferUpdateAffectedPeers accumulates peer IDs and flushes them after the buffer interval.
func (c *Controller) BufferUpdateAffectedPeers(ctx context.Context, accountID string, peerIDs []string) error {
if len(peerIDs) == 0 {
return nil
}
log.WithContext(ctx).Tracef("buffer updating %d affected peers for account %s from %s", len(peerIDs), accountID, util.GetCallerName())
bufUpd, _ := c.affectedPeerUpdateLocks.LoadOrStore(accountID, &bufferAffectedUpdate{
peerIDs: make(map[string]struct{}),
})
b := bufUpd.(*bufferAffectedUpdate)
b.addPeerIDs(peerIDs)
if !b.sendMu.TryLock() {
// Another goroutine is already sending; it will pick up our IDs on its next drain.
return nil
}
b.stopTimer()
collected := b.drainPeerIDs()
go func() {
defer b.sendMu.Unlock()
_ = c.sendUpdateForAffectedPeers(ctx, accountID, collected)
// Check if more peer IDs accumulated while we were sending.
if !b.hasPending() {
return
}
// Schedule a debounced flush for the newly accumulated IDs.
b.setTimer(time.Duration(c.updateAccountPeersBufferInterval.Load()), func() {
ids := b.drainPeerIDs()
if len(ids) > 0 {
_ = c.sendUpdateForAffectedPeers(ctx, accountID, ids)
}
})
}()
return nil
}
func (b *bufferAffectedUpdate) addPeerIDs(ids []string) {
b.dataMu.Lock()
for _, id := range ids {
b.peerIDs[id] = struct{}{}
}
b.dataMu.Unlock()
}
func (b *bufferAffectedUpdate) drainPeerIDs() []string {
b.dataMu.Lock()
defer b.dataMu.Unlock()
if len(b.peerIDs) == 0 {
return nil
}
ids := make([]string, 0, len(b.peerIDs))
for id := range b.peerIDs {
ids = append(ids, id)
}
b.peerIDs = make(map[string]struct{})
return ids
}
func (b *bufferAffectedUpdate) hasPending() bool {
b.dataMu.Lock()
defer b.dataMu.Unlock()
return len(b.peerIDs) > 0
}
func (b *bufferAffectedUpdate) stopTimer() {
b.dataMu.Lock()
defer b.dataMu.Unlock()
if b.next != nil {
b.next.Stop()
}
}
func (b *bufferAffectedUpdate) setTimer(d time.Duration, f func()) {
b.dataMu.Lock()
defer b.dataMu.Unlock()
if b.next == nil {
b.next = time.AfterFunc(d, f)
return
}
b.next.Reset(d)
}
func (c *Controller) GetValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, accountID string, peer *nbpeer.Peer) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) {
if isRequiresApproval {
network, err := c.repo.GetAccountNetwork(ctx, accountID)
@@ -569,21 +794,24 @@ func isPeerInPolicySourceGroups(account *types.Account, peerID string, policy *t
return false, nil
}
func (c *Controller) OnPeersUpdated(ctx context.Context, accountID string, peerIDs []string) error {
err := c.bufferSendUpdateAccountPeers(ctx, accountID)
if err != nil {
log.WithContext(ctx).Errorf("failed to buffer update account peers for peer update in account %s: %v", accountID, err)
func (c *Controller) OnPeersUpdated(ctx context.Context, accountID string, peerIDs []string, affectedPeerIDs []string) error {
if len(affectedPeerIDs) == 0 {
log.WithContext(ctx).Tracef("no affected peers for peer update in account %s, skipping", accountID)
return nil
}
return nil
return c.BufferUpdateAffectedPeers(ctx, accountID, affectedPeerIDs)
}
func (c *Controller) OnPeersAdded(ctx context.Context, accountID string, peerIDs []string) error {
func (c *Controller) OnPeersAdded(ctx context.Context, accountID string, peerIDs []string, affectedPeerIDs []string) error {
log.WithContext(ctx).Debugf("OnPeersAdded call to add peers: %v", peerIDs)
return c.bufferSendUpdateAccountPeers(ctx, accountID)
if len(affectedPeerIDs) == 0 {
log.WithContext(ctx).Tracef("no affected peers for peer add in account %s, skipping", accountID)
return nil
}
return c.BufferUpdateAffectedPeers(ctx, accountID, affectedPeerIDs)
}
func (c *Controller) OnPeersDeleted(ctx context.Context, accountID string, peerIDs []string) error {
func (c *Controller) OnPeersDeleted(ctx context.Context, accountID string, peerIDs []string, affectedPeerIDs []string) error {
network, err := c.repo.GetAccountNetwork(ctx, accountID)
if err != nil {
return err
@@ -616,7 +844,11 @@ func (c *Controller) OnPeersDeleted(ctx context.Context, accountID string, peerI
c.peersUpdateManager.CloseChannel(ctx, peerID)
}
return c.bufferSendUpdateAccountPeers(ctx, accountID)
if len(affectedPeerIDs) == 0 {
log.WithContext(ctx).Tracef("no affected peers for peer delete in account %s, skipping network map update", accountID)
return nil
}
return c.BufferUpdateAffectedPeers(ctx, accountID, affectedPeerIDs)
}
// GetNetworkMap returns Network map for a given peer (omits original peer from the Peers result)

View File

@@ -19,6 +19,8 @@ const (
type Controller interface {
UpdateAccountPeers(ctx context.Context, accountID string, reason types.UpdateReason) error
UpdateAffectedPeers(ctx context.Context, accountID string, peerIDs []string) error
BufferUpdateAffectedPeers(ctx context.Context, accountID string, peerIDs []string) error
UpdateAccountPeer(ctx context.Context, accountId string, peerId string) error
BufferUpdateAccountPeers(ctx context.Context, accountID string, reason types.UpdateReason) error
GetValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, accountID string, p *nbpeer.Peer) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error)
@@ -27,9 +29,9 @@ type Controller interface {
GetNetworkMap(ctx context.Context, peerID string) (*types.NetworkMap, error)
CountStreams() int
OnPeersUpdated(ctx context.Context, accountId string, peerIDs []string) error
OnPeersAdded(ctx context.Context, accountID string, peerIDs []string) error
OnPeersDeleted(ctx context.Context, accountID string, peerIDs []string) error
OnPeersUpdated(ctx context.Context, accountId string, peerIDs []string, affectedPeerIDs []string) error
OnPeersAdded(ctx context.Context, accountID string, peerIDs []string, affectedPeerIDs []string) error
OnPeersDeleted(ctx context.Context, accountID string, peerIDs []string, affectedPeerIDs []string) error
DisconnectPeers(ctx context.Context, accountId string, peerIDs []string)
OnPeerConnected(ctx context.Context, accountID string, peerID string) (chan *UpdateMessage, error)
OnPeerDisconnected(ctx context.Context, accountID string, peerID string)

View File

@@ -57,6 +57,20 @@ func (mr *MockControllerMockRecorder) BufferUpdateAccountPeers(ctx, accountID, r
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BufferUpdateAccountPeers", reflect.TypeOf((*MockController)(nil).BufferUpdateAccountPeers), ctx, accountID, reason)
}
// BufferUpdateAffectedPeers mocks base method.
func (m *MockController) BufferUpdateAffectedPeers(ctx context.Context, accountID string, peerIDs []string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "BufferUpdateAffectedPeers", ctx, accountID, peerIDs)
ret0, _ := ret[0].(error)
return ret0
}
// BufferUpdateAffectedPeers indicates an expected call of BufferUpdateAffectedPeers.
func (mr *MockControllerMockRecorder) BufferUpdateAffectedPeers(ctx, accountID, peerIDs any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BufferUpdateAffectedPeers", reflect.TypeOf((*MockController)(nil).BufferUpdateAffectedPeers), ctx, accountID, peerIDs)
}
// CountStreams mocks base method.
func (m *MockController) CountStreams() int {
m.ctrl.T.Helper()
@@ -158,45 +172,45 @@ func (mr *MockControllerMockRecorder) OnPeerDisconnected(ctx, accountID, peerID
}
// OnPeersAdded mocks base method.
func (m *MockController) OnPeersAdded(ctx context.Context, accountID string, peerIDs []string) error {
func (m *MockController) OnPeersAdded(ctx context.Context, accountID string, peerIDs []string, affectedPeerIDs []string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "OnPeersAdded", ctx, accountID, peerIDs)
ret := m.ctrl.Call(m, "OnPeersAdded", ctx, accountID, peerIDs, affectedPeerIDs)
ret0, _ := ret[0].(error)
return ret0
}
// OnPeersAdded indicates an expected call of OnPeersAdded.
func (mr *MockControllerMockRecorder) OnPeersAdded(ctx, accountID, peerIDs any) *gomock.Call {
func (mr *MockControllerMockRecorder) OnPeersAdded(ctx, accountID, peerIDs, affectedPeerIDs any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnPeersAdded", reflect.TypeOf((*MockController)(nil).OnPeersAdded), ctx, accountID, peerIDs)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnPeersAdded", reflect.TypeOf((*MockController)(nil).OnPeersAdded), ctx, accountID, peerIDs, affectedPeerIDs)
}
// OnPeersDeleted mocks base method.
func (m *MockController) OnPeersDeleted(ctx context.Context, accountID string, peerIDs []string) error {
func (m *MockController) OnPeersDeleted(ctx context.Context, accountID string, peerIDs []string, affectedPeerIDs []string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "OnPeersDeleted", ctx, accountID, peerIDs)
ret := m.ctrl.Call(m, "OnPeersDeleted", ctx, accountID, peerIDs, affectedPeerIDs)
ret0, _ := ret[0].(error)
return ret0
}
// OnPeersDeleted indicates an expected call of OnPeersDeleted.
func (mr *MockControllerMockRecorder) OnPeersDeleted(ctx, accountID, peerIDs any) *gomock.Call {
func (mr *MockControllerMockRecorder) OnPeersDeleted(ctx, accountID, peerIDs, affectedPeerIDs any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnPeersDeleted", reflect.TypeOf((*MockController)(nil).OnPeersDeleted), ctx, accountID, peerIDs)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnPeersDeleted", reflect.TypeOf((*MockController)(nil).OnPeersDeleted), ctx, accountID, peerIDs, affectedPeerIDs)
}
// OnPeersUpdated mocks base method.
func (m *MockController) OnPeersUpdated(ctx context.Context, accountId string, peerIDs []string) error {
func (m *MockController) OnPeersUpdated(ctx context.Context, accountId string, peerIDs []string, affectedPeerIDs []string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "OnPeersUpdated", ctx, accountId, peerIDs)
ret := m.ctrl.Call(m, "OnPeersUpdated", ctx, accountId, peerIDs, affectedPeerIDs)
ret0, _ := ret[0].(error)
return ret0
}
// OnPeersUpdated indicates an expected call of OnPeersUpdated.
func (mr *MockControllerMockRecorder) OnPeersUpdated(ctx, accountId, peerIDs any) *gomock.Call {
func (mr *MockControllerMockRecorder) OnPeersUpdated(ctx, accountId, peerIDs, affectedPeerIDs any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnPeersUpdated", reflect.TypeOf((*MockController)(nil).OnPeersUpdated), ctx, accountId, peerIDs)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnPeersUpdated", reflect.TypeOf((*MockController)(nil).OnPeersUpdated), ctx, accountId, peerIDs, affectedPeerIDs)
}
// StartWarmup mocks base method.
@@ -250,3 +264,17 @@ func (mr *MockControllerMockRecorder) UpdateAccountPeers(ctx, accountID, reason
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateAccountPeers", reflect.TypeOf((*MockController)(nil).UpdateAccountPeers), ctx, accountID, reason)
}
// UpdateAffectedPeers mocks base method.
func (m *MockController) UpdateAffectedPeers(ctx context.Context, accountID string, peerIDs []string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "UpdateAffectedPeers", ctx, accountID, peerIDs)
ret0, _ := ret[0].(error)
return ret0
}
// UpdateAffectedPeers indicates an expected call of UpdateAffectedPeers.
func (mr *MockControllerMockRecorder) UpdateAffectedPeers(ctx, accountID, peerIDs any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateAffectedPeers", reflect.TypeOf((*MockController)(nil).UpdateAffectedPeers), ctx, accountID, peerIDs)
}

View File

@@ -11,9 +11,9 @@ import (
// Manager defines the interface for proxy operations
type Manager interface {
Connect(ctx context.Context, proxyID, sessionID, clusterAddress, ipAddress string, capabilities *Capabilities) (*Proxy, error)
Disconnect(ctx context.Context, proxyID, sessionID string) error
Heartbeat(ctx context.Context, p *Proxy) error
Connect(ctx context.Context, proxyID, clusterAddress, ipAddress string, capabilities *Capabilities) error
Disconnect(ctx context.Context, proxyID string) error
Heartbeat(ctx context.Context, proxyID, clusterAddress, ipAddress string) error
GetActiveClusterAddresses(ctx context.Context) ([]string, error)
GetActiveClusters(ctx context.Context) ([]Cluster, error)
ClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool

View File

@@ -13,8 +13,7 @@ import (
// store defines the interface for proxy persistence operations
type store interface {
SaveProxy(ctx context.Context, p *proxy.Proxy) error
DisconnectProxy(ctx context.Context, proxyID, sessionID string) error
UpdateProxyHeartbeat(ctx context.Context, p *proxy.Proxy) error
UpdateProxyHeartbeat(ctx context.Context, proxyID, clusterAddress, ipAddress string) error
GetActiveProxyClusterAddresses(ctx context.Context) ([]string, error)
GetActiveProxyClusters(ctx context.Context) ([]proxy.Cluster, error)
GetClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool
@@ -44,7 +43,7 @@ func NewManager(store store, meter metric.Meter) (*Manager, error) {
// Connect registers a new proxy connection in the database.
// capabilities may be nil for old proxies that do not report them.
func (m Manager) Connect(ctx context.Context, proxyID, sessionID, clusterAddress, ipAddress string, capabilities *proxy.Capabilities) (*proxy.Proxy, error) {
func (m Manager) Connect(ctx context.Context, proxyID, clusterAddress, ipAddress string, capabilities *proxy.Capabilities) error {
now := time.Now()
var caps proxy.Capabilities
if capabilities != nil {
@@ -52,7 +51,6 @@ func (m Manager) Connect(ctx context.Context, proxyID, sessionID, clusterAddress
}
p := &proxy.Proxy{
ID: proxyID,
SessionID: sessionID,
ClusterAddress: clusterAddress,
IPAddress: ipAddress,
LastSeen: now,
@@ -63,42 +61,48 @@ func (m Manager) Connect(ctx context.Context, proxyID, sessionID, clusterAddress
if err := m.store.SaveProxy(ctx, p); err != nil {
log.WithContext(ctx).Errorf("failed to register proxy %s: %v", proxyID, err)
return nil, err
}
log.WithContext(ctx).WithFields(log.Fields{
"proxyID": proxyID,
"sessionID": sessionID,
"clusterAddress": clusterAddress,
"ipAddress": ipAddress,
}).Info("proxy connected")
return p, nil
}
// Disconnect marks a proxy as disconnected in the database.
func (m Manager) Disconnect(ctx context.Context, proxyID, sessionID string) error {
if err := m.store.DisconnectProxy(ctx, proxyID, sessionID); err != nil {
log.WithContext(ctx).Errorf("failed to disconnect proxy %s session %s: %v", proxyID, sessionID, err)
return err
}
log.WithContext(ctx).WithFields(log.Fields{
"proxyID": proxyID,
"sessionID": sessionID,
"proxyID": proxyID,
"clusterAddress": clusterAddress,
"ipAddress": ipAddress,
}).Info("proxy connected")
return nil
}
// Disconnect marks a proxy as disconnected in the database
func (m Manager) Disconnect(ctx context.Context, proxyID string) error {
now := time.Now()
p := &proxy.Proxy{
ID: proxyID,
Status: "disconnected",
DisconnectedAt: &now,
LastSeen: now,
}
if err := m.store.SaveProxy(ctx, p); err != nil {
log.WithContext(ctx).Errorf("failed to disconnect proxy %s: %v", proxyID, err)
return err
}
log.WithContext(ctx).WithFields(log.Fields{
"proxyID": proxyID,
}).Info("proxy disconnected")
return nil
}
// Heartbeat updates the proxy's last seen timestamp.
func (m Manager) Heartbeat(ctx context.Context, p *proxy.Proxy) error {
if err := m.store.UpdateProxyHeartbeat(ctx, p); err != nil {
log.WithContext(ctx).Debugf("failed to update proxy %s heartbeat: %v", p.ID, err)
// Heartbeat updates the proxy's last seen timestamp
func (m Manager) Heartbeat(ctx context.Context, proxyID, clusterAddress, ipAddress string) error {
if err := m.store.UpdateProxyHeartbeat(ctx, proxyID, clusterAddress, ipAddress); err != nil {
log.WithContext(ctx).Debugf("failed to update proxy %s heartbeat: %v", proxyID, err)
return err
}
log.WithContext(ctx).Tracef("updated heartbeat for proxy %s session %s", p.ID, p.SessionID)
log.WithContext(ctx).Tracef("updated heartbeat for proxy %s", proxyID)
m.metrics.IncrementProxyHeartbeatCount()
return nil
}

View File

@@ -93,32 +93,31 @@ func (mr *MockManagerMockRecorder) ClusterSupportsCrowdSec(ctx, clusterAddr inte
}
// Connect mocks base method.
func (m *MockManager) Connect(ctx context.Context, proxyID, sessionID, clusterAddress, ipAddress string, capabilities *Capabilities) (*Proxy, error) {
func (m *MockManager) Connect(ctx context.Context, proxyID, clusterAddress, ipAddress string, capabilities *Capabilities) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Connect", ctx, proxyID, sessionID, clusterAddress, ipAddress, capabilities)
ret0, _ := ret[0].(*Proxy)
ret1, _ := ret[1].(error)
return ret0, ret1
ret := m.ctrl.Call(m, "Connect", ctx, proxyID, clusterAddress, ipAddress, capabilities)
ret0, _ := ret[0].(error)
return ret0
}
// Connect indicates an expected call of Connect.
func (mr *MockManagerMockRecorder) Connect(ctx, proxyID, sessionID, clusterAddress, ipAddress, capabilities interface{}) *gomock.Call {
func (mr *MockManagerMockRecorder) Connect(ctx, proxyID, clusterAddress, ipAddress, capabilities interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Connect", reflect.TypeOf((*MockManager)(nil).Connect), ctx, proxyID, sessionID, clusterAddress, ipAddress, capabilities)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Connect", reflect.TypeOf((*MockManager)(nil).Connect), ctx, proxyID, clusterAddress, ipAddress, capabilities)
}
// Disconnect mocks base method.
func (m *MockManager) Disconnect(ctx context.Context, proxyID, sessionID string) error {
func (m *MockManager) Disconnect(ctx context.Context, proxyID string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Disconnect", ctx, proxyID, sessionID)
ret := m.ctrl.Call(m, "Disconnect", ctx, proxyID)
ret0, _ := ret[0].(error)
return ret0
}
// Disconnect indicates an expected call of Disconnect.
func (mr *MockManagerMockRecorder) Disconnect(ctx, proxyID, sessionID interface{}) *gomock.Call {
func (mr *MockManagerMockRecorder) Disconnect(ctx, proxyID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Disconnect", reflect.TypeOf((*MockManager)(nil).Disconnect), ctx, proxyID, sessionID)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Disconnect", reflect.TypeOf((*MockManager)(nil).Disconnect), ctx, proxyID)
}
// GetActiveClusterAddresses mocks base method.
@@ -152,17 +151,17 @@ func (mr *MockManagerMockRecorder) GetActiveClusters(ctx interface{}) *gomock.Ca
}
// Heartbeat mocks base method.
func (m *MockManager) Heartbeat(ctx context.Context, p *Proxy) error {
func (m *MockManager) Heartbeat(ctx context.Context, proxyID, clusterAddress, ipAddress string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Heartbeat", ctx, p)
ret := m.ctrl.Call(m, "Heartbeat", ctx, proxyID, clusterAddress, ipAddress)
ret0, _ := ret[0].(error)
return ret0
}
// Heartbeat indicates an expected call of Heartbeat.
func (mr *MockManagerMockRecorder) Heartbeat(ctx, p interface{}) *gomock.Call {
func (mr *MockManagerMockRecorder) Heartbeat(ctx, proxyID, clusterAddress, ipAddress interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Heartbeat", reflect.TypeOf((*MockManager)(nil).Heartbeat), ctx, p)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Heartbeat", reflect.TypeOf((*MockManager)(nil).Heartbeat), ctx, proxyID, clusterAddress, ipAddress)
}
// MockController is a mock of Controller interface.

View File

@@ -18,13 +18,12 @@ type Capabilities struct {
// Proxy represents a reverse proxy instance
type Proxy struct {
ID string `gorm:"primaryKey;type:varchar(255)"`
SessionID string `gorm:"type:varchar(36)"`
ClusterAddress string `gorm:"type:varchar(255);not null;index:idx_proxy_cluster_status"`
IPAddress string `gorm:"type:varchar(45)"`
LastSeen time.Time `gorm:"not null;index:idx_proxy_last_seen"`
ConnectedAt *time.Time
DisconnectedAt *time.Time
Status string `gorm:"type:varchar(20);not null;index:idx_proxy_cluster_status"`
Status string `gorm:"type:varchar(20);not null;index:idx_proxy_cluster_status"`
Capabilities Capabilities `gorm:"embedded"`
CreatedAt time.Time
UpdatedAt time.Time

View File

@@ -16,7 +16,6 @@ import (
"time"
"github.com/coreos/go-oidc/v3/oidc"
"github.com/google/uuid"
log "github.com/sirupsen/logrus"
"golang.org/x/oauth2"
"google.golang.org/grpc/codes"
@@ -90,7 +89,6 @@ const pkceVerifierTTL = 10 * time.Minute
// proxyConnection represents a connected proxy
type proxyConnection struct {
proxyID string
sessionID string
address string
capabilities *proto.ProxyCapabilities
stream proto.ProxyService_GetMappingUpdateServer
@@ -168,22 +166,9 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest
return status.Errorf(codes.InvalidArgument, "proxy address is invalid")
}
sessionID := uuid.NewString()
if old, loaded := s.connectedProxies.Load(proxyID); loaded {
oldConn := old.(*proxyConnection)
log.WithFields(log.Fields{
"proxy_id": proxyID,
"old_session_id": oldConn.sessionID,
"new_session_id": sessionID,
}).Info("Superseding existing proxy connection")
oldConn.cancel()
}
connCtx, cancel := context.WithCancel(ctx)
conn := &proxyConnection{
proxyID: proxyID,
sessionID: sessionID,
address: proxyAddress,
capabilities: req.GetCapabilities(),
stream: stream,
@@ -203,13 +188,12 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest
caps = &proxy.Capabilities{
SupportsCustomPorts: c.SupportsCustomPorts,
RequireSubdomain: c.RequireSubdomain,
SupportsCrowdsec: c.SupportsCrowdsec,
SupportsCrowdsec: c.SupportsCrowdsec,
}
}
proxyRecord, err := s.proxyManager.Connect(ctx, proxyID, sessionID, proxyAddress, peerInfo, caps)
if err != nil {
if err := s.proxyManager.Connect(ctx, proxyID, proxyAddress, peerInfo, caps); err != nil {
log.WithContext(ctx).Warnf("failed to register proxy %s in database: %v", proxyID, err)
s.connectedProxies.CompareAndDelete(proxyID, conn)
s.connectedProxies.Delete(proxyID)
if unregErr := s.proxyController.UnregisterProxyFromCluster(ctx, conn.address, proxyID); unregErr != nil {
log.WithContext(ctx).Debugf("cleanup after Connect failure for proxy %s: %v", proxyID, unregErr)
}
@@ -218,27 +202,22 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest
log.WithFields(log.Fields{
"proxy_id": proxyID,
"session_id": sessionID,
"address": proxyAddress,
"cluster_addr": proxyAddress,
"total_proxies": len(s.GetConnectedProxies()),
}).Info("Proxy registered in cluster")
defer func() {
if !s.connectedProxies.CompareAndDelete(proxyID, conn) {
log.Infof("Proxy %s session %s: skipping cleanup, superseded by new connection", proxyID, sessionID)
cancel()
return
}
if err := s.proxyController.UnregisterProxyFromCluster(context.Background(), conn.address, proxyID); err != nil {
log.Warnf("Failed to unregister proxy %s from cluster: %v", proxyID, err)
}
if err := s.proxyManager.Disconnect(context.Background(), proxyID, sessionID); err != nil {
if err := s.proxyManager.Disconnect(context.Background(), proxyID); err != nil {
log.Warnf("Failed to mark proxy %s as disconnected: %v", proxyID, err)
}
s.connectedProxies.Delete(proxyID)
if err := s.proxyController.UnregisterProxyFromCluster(context.Background(), conn.address, proxyID); err != nil {
log.Warnf("Failed to unregister proxy %s from cluster: %v", proxyID, err)
}
cancel()
log.Infof("Proxy %s session %s disconnected", proxyID, sessionID)
log.Infof("Proxy %s disconnected", proxyID)
}()
if err := s.sendSnapshot(ctx, conn); err != nil {
@@ -248,31 +227,29 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest
errChan := make(chan error, 2)
go s.sender(conn, errChan)
go s.heartbeat(connCtx, proxyRecord)
// Start heartbeat goroutine
go s.heartbeat(connCtx, proxyID, proxyAddress, peerInfo)
select {
case err := <-errChan:
log.WithContext(ctx).Warnf("Failed to send update: %v", err)
return fmt.Errorf("send update to proxy %s: %w", proxyID, err)
case <-connCtx.Done():
log.WithContext(ctx).Infof("Proxy %s context canceled", proxyID)
return connCtx.Err()
}
}
// heartbeat updates the proxy's last_seen timestamp every minute
func (s *ProxyServiceServer) heartbeat(ctx context.Context, p *proxy.Proxy) {
func (s *ProxyServiceServer) heartbeat(ctx context.Context, proxyID, clusterAddress, ipAddress string) {
ticker := time.NewTicker(1 * time.Minute)
defer ticker.Stop()
for {
select {
case <-ticker.C:
if err := s.proxyManager.Heartbeat(ctx, p); err != nil {
log.WithContext(ctx).Debugf("Failed to update proxy %s heartbeat: %v", p.ID, err)
if err := s.proxyManager.Heartbeat(ctx, proxyID, clusterAddress, ipAddress); err != nil {
log.WithContext(ctx).Debugf("Failed to update proxy %s heartbeat: %v", proxyID, err)
}
case <-ctx.Done():
log.WithContext(ctx).Infof("proxy %s heartbeat stopped: context canceled", p.ID)
return
}
}

View File

@@ -2215,7 +2215,9 @@ func (am *DefaultAccountManager) UpdatePeerIP(ctx context.Context, accountID, us
if err != nil {
return err
}
err = am.networkMapController.OnPeersUpdated(ctx, peer.AccountID, []string{peerID})
changedPeerIDs := []string{peerID}
affectedPeerIDs := am.resolveAffectedPeersForPeerChanges(ctx, am.Store, accountID, changedPeerIDs)
err = am.networkMapController.OnPeersUpdated(ctx, peer.AccountID, changedPeerIDs, affectedPeerIDs)
if err != nil {
return fmt.Errorf("notify network map controller of peer update: %w", err)
}

View File

@@ -125,6 +125,8 @@ type Manager interface {
GetAccountSettings(ctx context.Context, accountID string, userID string) (*types.Settings, error)
DeleteSetupKey(ctx context.Context, accountID, userID, keyID string) error
UpdateAccountPeers(ctx context.Context, accountID string, reason types.UpdateReason)
UpdateAffectedPeers(ctx context.Context, accountID string, peerIDs []string)
BufferUpdateAffectedPeers(ctx context.Context, accountID string, peerIDs []string)
BufferUpdateAccountPeers(ctx context.Context, accountID string, reason types.UpdateReason)
BuildUserInfosForAccount(ctx context.Context, accountID, initiatorUserID string, accountUsers []*types.User) (map[string]*types.UserInfo, error)
SyncUserJWTGroups(ctx context.Context, userAuth auth.UserAuth) error

View File

@@ -122,6 +122,18 @@ func (mr *MockManagerMockRecorder) BufferUpdateAccountPeers(ctx, accountID, reas
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BufferUpdateAccountPeers", reflect.TypeOf((*MockManager)(nil).BufferUpdateAccountPeers), ctx, accountID, reason)
}
// BufferUpdateAffectedPeers mocks base method.
func (m *MockManager) BufferUpdateAffectedPeers(ctx context.Context, accountID string, peerIDs []string) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "BufferUpdateAffectedPeers", ctx, accountID, peerIDs)
}
// BufferUpdateAffectedPeers indicates an expected call of BufferUpdateAffectedPeers.
func (mr *MockManagerMockRecorder) BufferUpdateAffectedPeers(ctx, accountID, peerIDs interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BufferUpdateAffectedPeers", reflect.TypeOf((*MockManager)(nil).BufferUpdateAffectedPeers), ctx, accountID, peerIDs)
}
// BuildUserInfosForAccount mocks base method.
func (m *MockManager) BuildUserInfosForAccount(ctx context.Context, accountID, initiatorUserID string, accountUsers []*types.User) (map[string]*types.UserInfo, error) {
m.ctrl.T.Helper()
@@ -1608,6 +1620,18 @@ func (mr *MockManagerMockRecorder) UpdateAccountPeers(ctx, accountID, reason int
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateAccountPeers", reflect.TypeOf((*MockManager)(nil).UpdateAccountPeers), ctx, accountID, reason)
}
// UpdateAffectedPeers mocks base method.
func (m *MockManager) UpdateAffectedPeers(ctx context.Context, accountID string, peerIDs []string) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "UpdateAffectedPeers", ctx, accountID, peerIDs)
}
// UpdateAffectedPeers indicates an expected call of UpdateAffectedPeers.
func (mr *MockManagerMockRecorder) UpdateAffectedPeers(ctx, accountID, peerIDs interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateAffectedPeers", reflect.TypeOf((*MockManager)(nil).UpdateAffectedPeers), ctx, accountID, peerIDs)
}
// UpdateAccountSettings mocks base method.
func (m *MockManager) UpdateAccountSettings(ctx context.Context, accountID, userID string, newSettings *types.Settings) (*types.Settings, error) {
m.ctrl.T.Helper()

File diff suppressed because it is too large Load Diff

View File

@@ -47,8 +47,8 @@ func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID
return status.NewPermissionDeniedError()
}
var updateAccountPeers bool
var eventsToStore []func()
var affectedPeerIDs []string
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
if err = validateDNSSettings(ctx, transaction, accountID, dnsSettingsToSave); err != nil {
@@ -63,11 +63,6 @@ func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID
addedGroups := util.Difference(dnsSettingsToSave.DisabledManagementGroups, oldSettings.DisabledManagementGroups)
removedGroups := util.Difference(oldSettings.DisabledManagementGroups, dnsSettingsToSave.DisabledManagementGroups)
updateAccountPeers, err = areDNSSettingChangesAffectPeers(ctx, transaction, accountID, addedGroups, removedGroups)
if err != nil {
return err
}
events := am.prepareDNSSettingsEvents(ctx, transaction, accountID, userID, addedGroups, removedGroups)
eventsToStore = append(eventsToStore, events...)
@@ -75,6 +70,9 @@ func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID
return err
}
allGroups := slices.Concat(addedGroups, removedGroups)
affectedPeerIDs = am.resolvePeerIDs(ctx, transaction, accountID, allGroups, nil)
return transaction.IncrementNetworkSerial(ctx, accountID)
})
if err != nil {
@@ -85,8 +83,8 @@ func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID
storeEvent()
}
if updateAccountPeers {
am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceDNSSettings, Operation: types.UpdateOperationUpdate})
if len(affectedPeerIDs) > 0 {
am.UpdateAffectedPeers(ctx, accountID, affectedPeerIDs)
}
return nil
@@ -133,20 +131,6 @@ func (am *DefaultAccountManager) prepareDNSSettingsEvents(ctx context.Context, t
return eventsToStore
}
// areDNSSettingChangesAffectPeers checks if the DNS settings changes affect any peers.
func areDNSSettingChangesAffectPeers(ctx context.Context, transaction store.Store, accountID string, addedGroups, removedGroups []string) (bool, error) {
hasPeers, err := anyGroupHasPeersOrResources(ctx, transaction, accountID, addedGroups)
if err != nil {
return false, err
}
if hasPeers {
return true, nil
}
return anyGroupHasPeersOrResources(ctx, transaction, accountID, removedGroups)
}
// validateDNSSettings validates the DNS settings.
func validateDNSSettings(ctx context.Context, transaction store.Store, accountID string, settings *types.DNSSettings) error {
if len(settings.DisabledManagementGroups) == 0 {

View File

@@ -79,7 +79,7 @@ func (am *DefaultAccountManager) CreateGroup(ctx context.Context, accountID, use
}
var eventsToStore []func()
var updateAccountPeers bool
var affectedPeerIDs []string
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
if err = validateNewGroup(ctx, transaction, accountID, newGroup); err != nil {
@@ -91,11 +91,6 @@ func (am *DefaultAccountManager) CreateGroup(ctx context.Context, accountID, use
events := am.prepareGroupEvents(ctx, transaction, accountID, userID, newGroup)
eventsToStore = append(eventsToStore, events...)
updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, []string{newGroup.ID})
if err != nil {
return err
}
if err := transaction.CreateGroup(ctx, newGroup); err != nil {
return status.Errorf(status.Internal, "failed to create group: %v", err)
}
@@ -106,6 +101,9 @@ func (am *DefaultAccountManager) CreateGroup(ctx context.Context, accountID, use
}
}
groupIDs, directPeerIDs := collectGroupChangeAffectedGroups(ctx, transaction, accountID, []string{newGroup.ID})
affectedPeerIDs = am.resolvePeerIDs(ctx, transaction, accountID, groupIDs, directPeerIDs)
return transaction.IncrementNetworkSerial(ctx, accountID)
})
if err != nil {
@@ -116,8 +114,8 @@ func (am *DefaultAccountManager) CreateGroup(ctx context.Context, accountID, use
storeEvent()
}
if updateAccountPeers {
am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceGroup, Operation: types.UpdateOperationCreate})
if len(affectedPeerIDs) > 0 {
am.UpdateAffectedPeers(ctx, accountID, affectedPeerIDs)
}
return nil
@@ -134,7 +132,7 @@ func (am *DefaultAccountManager) UpdateGroup(ctx context.Context, accountID, use
}
var eventsToStore []func()
var updateAccountPeers bool
var affectedPeerIDs []string
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
if err = validateNewGroup(ctx, transaction, accountID, newGroup); err != nil {
@@ -165,15 +163,13 @@ func (am *DefaultAccountManager) UpdateGroup(ctx context.Context, accountID, use
}
}
updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, []string{newGroup.ID})
if err != nil {
return err
}
if err = transaction.UpdateGroup(ctx, newGroup); err != nil {
return err
}
groupIDs, directPeerIDs := collectGroupChangeAffectedGroups(ctx, transaction, accountID, []string{newGroup.ID})
affectedPeerIDs = am.resolvePeerIDs(ctx, transaction, accountID, groupIDs, directPeerIDs)
return transaction.IncrementNetworkSerial(ctx, accountID)
})
if err != nil {
@@ -184,8 +180,8 @@ func (am *DefaultAccountManager) UpdateGroup(ctx context.Context, accountID, use
storeEvent()
}
if updateAccountPeers {
am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceGroup, Operation: types.UpdateOperationUpdate})
if len(affectedPeerIDs) > 0 {
am.UpdateAffectedPeers(ctx, accountID, affectedPeerIDs)
}
return nil
@@ -205,7 +201,6 @@ func (am *DefaultAccountManager) CreateGroups(ctx context.Context, accountID, us
}
var eventsToStore []func()
var updateAccountPeers bool
var globalErr error
groupIDs := make([]string, 0, len(groups))
@@ -243,17 +238,14 @@ func (am *DefaultAccountManager) CreateGroups(ctx context.Context, accountID, us
}
}
updateAccountPeers, err = areGroupChangesAffectPeers(ctx, am.Store, accountID, groupIDs)
if err != nil {
return err
}
for _, storeEvent := range eventsToStore {
storeEvent()
}
if updateAccountPeers {
am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceGroup, Operation: types.UpdateOperationCreate})
allGroupIDs, directPeerIDs := collectGroupChangeAffectedGroups(ctx, am.Store, accountID, groupIDs)
affectedPeerIDs := am.resolvePeerIDs(ctx, am.Store, accountID, allGroupIDs, directPeerIDs)
if len(affectedPeerIDs) > 0 {
am.UpdateAffectedPeers(ctx, accountID, affectedPeerIDs)
}
return globalErr
@@ -273,7 +265,6 @@ func (am *DefaultAccountManager) UpdateGroups(ctx context.Context, accountID, us
}
var eventsToStore []func()
var updateAccountPeers bool
var globalErr error
groupIDs := make([]string, 0, len(groups))
@@ -311,17 +302,14 @@ func (am *DefaultAccountManager) UpdateGroups(ctx context.Context, accountID, us
}
}
updateAccountPeers, err = areGroupChangesAffectPeers(ctx, am.Store, accountID, groupIDs)
if err != nil {
return err
}
for _, storeEvent := range eventsToStore {
storeEvent()
}
if updateAccountPeers {
am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceGroup, Operation: types.UpdateOperationUpdate})
allGroupIDs, directPeerIDs := collectGroupChangeAffectedGroups(ctx, am.Store, accountID, groupIDs)
affectedPeerIDs := am.resolvePeerIDs(ctx, am.Store, accountID, allGroupIDs, directPeerIDs)
if len(affectedPeerIDs) > 0 {
am.UpdateAffectedPeers(ctx, accountID, affectedPeerIDs)
}
return globalErr
@@ -473,27 +461,25 @@ func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountID, us
// GroupAddPeer appends peer to the group
func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, groupID, peerID string) error {
var updateAccountPeers bool
var affectedPeerIDs []string
var err error
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, []string{groupID})
if err != nil {
return err
}
if err = transaction.AddPeerToGroup(ctx, accountID, peerID, groupID); err != nil {
return err
}
allGroupIDs, directPeerIDs := collectGroupChangeAffectedGroups(ctx, transaction, accountID, []string{groupID})
affectedPeerIDs = am.resolvePeerIDs(ctx, transaction, accountID, allGroupIDs, directPeerIDs)
return transaction.IncrementNetworkSerial(ctx, accountID)
})
if err != nil {
return err
}
if updateAccountPeers {
am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceGroup, Operation: types.UpdateOperationUpdate})
if len(affectedPeerIDs) > 0 {
am.UpdateAffectedPeers(ctx, accountID, affectedPeerIDs)
}
return nil
@@ -502,7 +488,7 @@ func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, gr
// GroupAddResource appends resource to the group
func (am *DefaultAccountManager) GroupAddResource(ctx context.Context, accountID, groupID string, resource types.Resource) error {
var group *types.Group
var updateAccountPeers bool
var affectedPeerIDs []string
var err error
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
@@ -515,23 +501,21 @@ func (am *DefaultAccountManager) GroupAddResource(ctx context.Context, accountID
return nil
}
updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, []string{groupID})
if err != nil {
return err
}
if err = transaction.UpdateGroup(ctx, group); err != nil {
return err
}
allGroupIDs, directPeerIDs := collectGroupChangeAffectedGroups(ctx, transaction, accountID, []string{groupID})
affectedPeerIDs = am.resolvePeerIDs(ctx, transaction, accountID, allGroupIDs, directPeerIDs)
return transaction.IncrementNetworkSerial(ctx, accountID)
})
if err != nil {
return err
}
if updateAccountPeers {
am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceGroup, Operation: types.UpdateOperationUpdate})
if len(affectedPeerIDs) > 0 {
am.UpdateAffectedPeers(ctx, accountID, affectedPeerIDs)
}
return nil
@@ -539,14 +523,13 @@ func (am *DefaultAccountManager) GroupAddResource(ctx context.Context, accountID
// GroupDeletePeer removes peer from the group
func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID, groupID, peerID string) error {
var updateAccountPeers bool
var affectedPeerIDs []string
var err error
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, []string{groupID})
if err != nil {
return err
}
// Resolve before removing, so the peer being removed is still included
allGroupIDs, directPeerIDs := collectGroupChangeAffectedGroups(ctx, transaction, accountID, []string{groupID})
affectedPeerIDs = am.resolvePeerIDs(ctx, transaction, accountID, allGroupIDs, directPeerIDs)
if err = transaction.RemovePeerFromGroup(ctx, peerID, groupID); err != nil {
return err
@@ -558,8 +541,8 @@ func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID,
return err
}
if updateAccountPeers {
am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceGroup, Operation: types.UpdateOperationUpdate})
if len(affectedPeerIDs) > 0 {
am.UpdateAffectedPeers(ctx, accountID, affectedPeerIDs)
}
return nil
@@ -568,7 +551,7 @@ func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID,
// GroupDeleteResource removes resource from the group
func (am *DefaultAccountManager) GroupDeleteResource(ctx context.Context, accountID, groupID string, resource types.Resource) error {
var group *types.Group
var updateAccountPeers bool
var affectedPeerIDs []string
var err error
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
@@ -581,23 +564,21 @@ func (am *DefaultAccountManager) GroupDeleteResource(ctx context.Context, accoun
return nil
}
updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, []string{groupID})
if err != nil {
return err
}
if err = transaction.UpdateGroup(ctx, group); err != nil {
return err
}
allGroupIDs, directPeerIDs := collectGroupChangeAffectedGroups(ctx, transaction, accountID, []string{groupID})
affectedPeerIDs = am.resolvePeerIDs(ctx, transaction, accountID, allGroupIDs, directPeerIDs)
return transaction.IncrementNetworkSerial(ctx, accountID)
})
if err != nil {
return err
}
if updateAccountPeers {
am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceGroup, Operation: types.UpdateOperationUpdate})
if len(affectedPeerIDs) > 0 {
am.UpdateAffectedPeers(ctx, accountID, affectedPeerIDs)
}
return nil
@@ -840,18 +821,177 @@ func areGroupChangesAffectPeers(ctx context.Context, transaction store.Store, ac
return false, nil
}
// anyGroupHasPeersOrResources checks if any of the given groups in the account have peers or resources.
func anyGroupHasPeersOrResources(ctx context.Context, transaction store.Store, accountID string, groupIDs []string) (bool, error) {
groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthNone, accountID, groupIDs)
if err != nil {
return false, err
// collectGroupChangeAffectedGroups walks policies, routes, nameservers, DNS settings,
// and network routers to collect all group IDs and direct peer IDs affected by the changed groups.
func collectGroupChangeAffectedGroups(ctx context.Context, transaction store.Store, accountID string, changedGroupIDs []string) (allGroupIDs []string, directPeerIDs []string) {
if len(changedGroupIDs) == 0 {
return nil, nil
}
for _, group := range groups {
if group.HasPeers() || group.HasResources() {
return true, nil
changedSet := make(map[string]struct{}, len(changedGroupIDs))
for _, id := range changedGroupIDs {
changedSet[id] = struct{}{}
}
log.WithContext(ctx).Tracef("collecting affected groups for changed groups %v", changedGroupIDs)
groupSet := make(map[string]struct{})
peerSet := make(map[string]struct{})
policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID)
if err != nil {
log.WithContext(ctx).Errorf("failed to get policies for group change resolution: %v", err)
} else {
for _, policy := range policies {
if !policyReferencesGroups(policy, changedSet) {
continue
}
log.WithContext(ctx).Tracef("policy %s (%s) references changed groups, adding rule groups", policy.ID, policy.Name)
for _, gID := range policy.RuleGroups() {
groupSet[gID] = struct{}{}
}
for _, rule := range policy.Rules {
if rule.SourceResource.Type == types.ResourceTypePeer && rule.SourceResource.ID != "" {
log.WithContext(ctx).Tracef("policy %s rule %s has direct source peer %s", policy.ID, rule.ID, rule.SourceResource.ID)
peerSet[rule.SourceResource.ID] = struct{}{}
}
if rule.DestinationResource.Type == types.ResourceTypePeer && rule.DestinationResource.ID != "" {
log.WithContext(ctx).Tracef("policy %s rule %s has direct destination peer %s", policy.ID, rule.ID, rule.DestinationResource.ID)
peerSet[rule.DestinationResource.ID] = struct{}{}
}
}
}
}
return false, nil
routes, err := transaction.GetAccountRoutes(ctx, store.LockingStrengthNone, accountID)
if err != nil {
log.WithContext(ctx).Errorf("failed to get routes for group change resolution: %v", err)
} else {
for _, r := range routes {
if !routeReferencesGroups(r, changedSet) {
continue
}
log.WithContext(ctx).Tracef("route %s (%s) references changed groups", r.ID, r.Description)
for _, gID := range r.Groups {
groupSet[gID] = struct{}{}
}
for _, gID := range r.PeerGroups {
groupSet[gID] = struct{}{}
}
for _, gID := range r.AccessControlGroups {
groupSet[gID] = struct{}{}
}
if r.Peer != "" {
log.WithContext(ctx).Tracef("route %s has direct peer %s", r.ID, r.Peer)
peerSet[r.Peer] = struct{}{}
}
}
}
nsGroups, err := transaction.GetAccountNameServerGroups(ctx, store.LockingStrengthNone, accountID)
if err != nil {
log.WithContext(ctx).Errorf("failed to get nameserver groups for group change resolution: %v", err)
} else {
for _, ns := range nsGroups {
for _, gID := range ns.Groups {
if _, ok := changedSet[gID]; ok {
log.WithContext(ctx).Tracef("nameserver group %s (%s) references changed group %s", ns.ID, ns.Name, gID)
for _, g := range ns.Groups {
groupSet[g] = struct{}{}
}
break
}
}
}
}
dnsSettings, err := transaction.GetAccountDNSSettings(ctx, store.LockingStrengthNone, accountID)
if err != nil {
log.WithContext(ctx).Errorf("failed to get DNS settings for group change resolution: %v", err)
} else {
for _, gID := range dnsSettings.DisabledManagementGroups {
if _, ok := changedSet[gID]; ok {
log.WithContext(ctx).Tracef("DNS disabled management group %s matches changed group", gID)
groupSet[gID] = struct{}{}
}
}
}
routers, err := transaction.GetNetworkRoutersByAccountID(ctx, store.LockingStrengthNone, accountID)
if err != nil {
log.WithContext(ctx).Errorf("failed to get network routers for group change resolution: %v", err)
} else {
for _, router := range routers {
if !routerReferencesGroups(router, changedSet) {
continue
}
log.WithContext(ctx).Tracef("network router %s references changed groups", router.ID)
for _, gID := range router.PeerGroups {
groupSet[gID] = struct{}{}
}
if router.Peer != "" {
log.WithContext(ctx).Tracef("network router %s has direct peer %s", router.ID, router.Peer)
peerSet[router.Peer] = struct{}{}
}
}
}
allGroupIDs = make([]string, 0, len(groupSet))
for gID := range groupSet {
allGroupIDs = append(allGroupIDs, gID)
}
directPeerIDs = make([]string, 0, len(peerSet))
for pID := range peerSet {
directPeerIDs = append(directPeerIDs, pID)
}
log.WithContext(ctx).Tracef("affected groups resolution: changed=%v -> affectedGroups=%v, directPeers=%v", changedGroupIDs, allGroupIDs, directPeerIDs)
return allGroupIDs, directPeerIDs
}
func policyReferencesGroups(policy *types.Policy, groupSet map[string]struct{}) bool {
for _, rule := range policy.Rules {
for _, gID := range rule.Sources {
if _, ok := groupSet[gID]; ok {
return true
}
}
for _, gID := range rule.Destinations {
if _, ok := groupSet[gID]; ok {
return true
}
}
}
return false
}
func routeReferencesGroups(r *route.Route, groupSet map[string]struct{}) bool {
for _, gID := range r.Groups {
if _, ok := groupSet[gID]; ok {
return true
}
}
for _, gID := range r.PeerGroups {
if _, ok := groupSet[gID]; ok {
return true
}
}
for _, gID := range r.AccessControlGroups {
if _, ok := groupSet[gID]; ok {
return true
}
}
return false
}
func routerReferencesGroups(router *routerTypes.NetworkRouter, groupSet map[string]struct{}) bool {
for _, gID := range router.PeerGroups {
if _, ok := groupSet[gID]; ok {
return true
}
}
return false
}

View File

@@ -129,6 +129,8 @@ type MockAccountManager struct {
AllowSyncFunc func(string, uint64) bool
UpdateAccountPeersFunc func(ctx context.Context, accountID string, reason types.UpdateReason)
UpdateAffectedPeersFunc func(ctx context.Context, accountID string, peerIDs []string)
BufferUpdateAffectedPeersFunc func(ctx context.Context, accountID string, peerIDs []string)
BufferUpdateAccountPeersFunc func(ctx context.Context, accountID string, reason types.UpdateReason)
RecalculateNetworkMapCacheFunc func(ctx context.Context, accountId string) error
@@ -206,6 +208,18 @@ func (am *MockAccountManager) UpdateAccountPeers(ctx context.Context, accountID
}
}
func (am *MockAccountManager) UpdateAffectedPeers(ctx context.Context, accountID string, peerIDs []string) {
if am.UpdateAffectedPeersFunc != nil {
am.UpdateAffectedPeersFunc(ctx, accountID, peerIDs)
}
}
func (am *MockAccountManager) BufferUpdateAffectedPeers(ctx context.Context, accountID string, peerIDs []string) {
if am.BufferUpdateAffectedPeersFunc != nil {
am.BufferUpdateAffectedPeersFunc(ctx, accountID, peerIDs)
}
}
func (am *MockAccountManager) BufferUpdateAccountPeers(ctx context.Context, accountID string, reason types.UpdateReason) {
if am.BufferUpdateAccountPeersFunc != nil {
am.BufferUpdateAccountPeersFunc(ctx, accountID, reason)

View File

@@ -4,6 +4,7 @@ import (
"context"
"errors"
"fmt"
"slices"
"strings"
"unicode/utf8"
@@ -57,22 +58,19 @@ func (am *DefaultAccountManager) CreateNameServerGroup(ctx context.Context, acco
SearchDomainsEnabled: searchDomainEnabled,
}
var updateAccountPeers bool
var affectedPeerIDs []string
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
if err = validateNameServerGroup(ctx, transaction, accountID, newNSGroup); err != nil {
return err
}
updateAccountPeers, err = anyGroupHasPeersOrResources(ctx, transaction, accountID, newNSGroup.Groups)
if err != nil {
return err
}
if err = transaction.SaveNameServerGroup(ctx, newNSGroup); err != nil {
return err
}
affectedPeerIDs = am.resolvePeerIDs(ctx, transaction, accountID, newNSGroup.Groups, nil)
return transaction.IncrementNetworkSerial(ctx, accountID)
})
if err != nil {
@@ -81,8 +79,8 @@ func (am *DefaultAccountManager) CreateNameServerGroup(ctx context.Context, acco
am.StoreEvent(ctx, userID, newNSGroup.ID, accountID, activity.NameserverGroupCreated, newNSGroup.EventMeta())
if updateAccountPeers {
am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceNameServerGroup, Operation: types.UpdateOperationCreate})
if len(affectedPeerIDs) > 0 {
am.UpdateAffectedPeers(ctx, accountID, affectedPeerIDs)
}
return newNSGroup.Copy(), nil
@@ -102,7 +100,7 @@ func (am *DefaultAccountManager) SaveNameServerGroup(ctx context.Context, accoun
return status.NewPermissionDeniedError()
}
var updateAccountPeers bool
var affectedPeerIDs []string
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
oldNSGroup, err := transaction.GetNameServerGroupByID(ctx, store.LockingStrengthNone, accountID, nsGroupToSave.ID)
@@ -115,15 +113,13 @@ func (am *DefaultAccountManager) SaveNameServerGroup(ctx context.Context, accoun
return err
}
updateAccountPeers, err = areNameServerGroupChangesAffectPeers(ctx, transaction, nsGroupToSave, oldNSGroup)
if err != nil {
return err
}
if err = transaction.SaveNameServerGroup(ctx, nsGroupToSave); err != nil {
return err
}
allGroups := slices.Concat(nsGroupToSave.Groups, oldNSGroup.Groups)
affectedPeerIDs = am.resolvePeerIDs(ctx, transaction, accountID, allGroups, nil)
return transaction.IncrementNetworkSerial(ctx, accountID)
})
if err != nil {
@@ -132,8 +128,8 @@ func (am *DefaultAccountManager) SaveNameServerGroup(ctx context.Context, accoun
am.StoreEvent(ctx, userID, nsGroupToSave.ID, accountID, activity.NameserverGroupUpdated, nsGroupToSave.EventMeta())
if updateAccountPeers {
am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceNameServerGroup, Operation: types.UpdateOperationUpdate})
if len(affectedPeerIDs) > 0 {
am.UpdateAffectedPeers(ctx, accountID, affectedPeerIDs)
}
return nil
@@ -150,7 +146,7 @@ func (am *DefaultAccountManager) DeleteNameServerGroup(ctx context.Context, acco
}
var nsGroup *nbdns.NameServerGroup
var updateAccountPeers bool
var affectedPeerIDs []string
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
nsGroup, err = transaction.GetNameServerGroupByID(ctx, store.LockingStrengthUpdate, accountID, nsGroupID)
@@ -158,10 +154,7 @@ func (am *DefaultAccountManager) DeleteNameServerGroup(ctx context.Context, acco
return err
}
updateAccountPeers, err = anyGroupHasPeersOrResources(ctx, transaction, accountID, nsGroup.Groups)
if err != nil {
return err
}
affectedPeerIDs = am.resolvePeerIDs(ctx, transaction, accountID, nsGroup.Groups, nil)
if err = transaction.DeleteNameServerGroup(ctx, accountID, nsGroupID); err != nil {
return err
@@ -175,8 +168,8 @@ func (am *DefaultAccountManager) DeleteNameServerGroup(ctx context.Context, acco
am.StoreEvent(ctx, userID, nsGroup.ID, accountID, activity.NameserverGroupDeleted, nsGroup.EventMeta())
if updateAccountPeers {
am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceNameServerGroup, Operation: types.UpdateOperationDelete})
if len(affectedPeerIDs) > 0 {
am.UpdateAffectedPeers(ctx, accountID, affectedPeerIDs)
}
return nil
@@ -224,24 +217,6 @@ func validateNameServerGroup(ctx context.Context, transaction store.Store, accou
return validateGroups(nameserverGroup.Groups, groups)
}
// areNameServerGroupChangesAffectPeers checks if the changes in the nameserver group affect the peers.
func areNameServerGroupChangesAffectPeers(ctx context.Context, transaction store.Store, newNSGroup, oldNSGroup *nbdns.NameServerGroup) (bool, error) {
if !newNSGroup.Enabled && !oldNSGroup.Enabled {
return false, nil
}
hasPeers, err := anyGroupHasPeersOrResources(ctx, transaction, newNSGroup.AccountID, newNSGroup.Groups)
if err != nil {
return false, err
}
if hasPeers {
return true, nil
}
return anyGroupHasPeersOrResources(ctx, transaction, oldNSGroup.AccountID, oldNSGroup.Groups)
}
func validateDomainInput(primary bool, domains []string, searchDomainsEnabled bool) error {
if !primary && len(domains) == 0 {
return status.Errorf(status.InvalidArgument, "nameserver group primary status is false and domains are empty,"+

View File

@@ -5,6 +5,7 @@ import (
"fmt"
"github.com/rs/xid"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/activity"
@@ -15,7 +16,7 @@ import (
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/store"
serverTypes "github.com/netbirdio/netbird/management/server/types"
nbTypes "github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/management/status"
)
@@ -112,6 +113,14 @@ func (m *managerImpl) UpdateNetwork(ctx context.Context, userID string, network
return network, m.store.SaveNetwork(ctx, network)
}
// networkAffectedPeersData holds data loaded inside the transaction for affected peer resolution.
type networkAffectedPeersData struct {
resourceGroupIDs []string
routerPeerGroups []string
directPeerIDs []string
policies []*nbTypes.Policy
}
func (m *managerImpl) DeleteNetwork(ctx context.Context, accountID, userID, networkID string) error {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Networks, operations.Delete)
if err != nil {
@@ -127,13 +136,22 @@ func (m *managerImpl) DeleteNetwork(ctx context.Context, accountID, userID, netw
}
var eventsToStore []func()
var affectedData *networkAffectedPeersData
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
resources, err := transaction.GetNetworkResourcesByNetID(ctx, store.LockingStrengthUpdate, accountID, networkID)
if err != nil {
return fmt.Errorf("failed to get resources in network: %w", err)
}
var resourceGroupIDs []string
for _, resource := range resources {
groups, err := transaction.GetResourceGroups(ctx, store.LockingStrengthNone, accountID, resource.ID)
if err == nil {
for _, g := range groups {
resourceGroupIDs = append(resourceGroupIDs, g.ID)
}
}
event, err := m.resourcesManager.DeleteResourceInTransaction(ctx, transaction, accountID, userID, networkID, resource.ID)
if err != nil {
return fmt.Errorf("failed to delete resource: %w", err)
@@ -141,12 +159,19 @@ func (m *managerImpl) DeleteNetwork(ctx context.Context, accountID, userID, netw
eventsToStore = append(eventsToStore, event...)
}
routers, err := transaction.GetNetworkRoutersByNetID(ctx, store.LockingStrengthUpdate, accountID, networkID)
netRouters, err := transaction.GetNetworkRoutersByNetID(ctx, store.LockingStrengthUpdate, accountID, networkID)
if err != nil {
return fmt.Errorf("failed to get routers in network: %w", err)
}
for _, router := range routers {
var routerPeerGroups []string
var directPeerIDs []string
for _, router := range netRouters {
routerPeerGroups = append(routerPeerGroups, router.PeerGroups...)
if router.Peer != "" {
directPeerIDs = append(directPeerIDs, router.Peer)
}
event, err := m.routersManager.DeleteRouterInTransaction(ctx, transaction, accountID, userID, networkID, router.ID)
if err != nil {
return fmt.Errorf("failed to delete router: %w", err)
@@ -154,6 +179,24 @@ func (m *managerImpl) DeleteNetwork(ctx context.Context, accountID, userID, netw
eventsToStore = append(eventsToStore, event)
}
// load policies before deleting so group memberships are still present
var policies []*nbTypes.Policy
if len(resourceGroupIDs) > 0 {
policies, err = transaction.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID)
if err != nil {
log.WithContext(ctx).Errorf("failed to get policies for affected peers: %v", err)
}
}
if len(resourceGroupIDs) > 0 || len(routerPeerGroups) > 0 || len(directPeerIDs) > 0 {
affectedData = &networkAffectedPeersData{
resourceGroupIDs: resourceGroupIDs,
routerPeerGroups: routerPeerGroups,
directPeerIDs: directPeerIDs,
policies: policies,
}
}
err = transaction.DeleteNetwork(ctx, accountID, networkID)
if err != nil {
return fmt.Errorf("failed to delete network: %w", err)
@@ -178,11 +221,82 @@ func (m *managerImpl) DeleteNetwork(ctx context.Context, accountID, userID, netw
event()
}
go m.accountManager.UpdateAccountPeers(ctx, accountID, serverTypes.UpdateReason{Resource: serverTypes.UpdateResourceNetwork, Operation: serverTypes.UpdateOperationDelete})
if affectedData != nil {
affectedPeerIDs := resolveNetworkAffectedPeers(ctx, m.store, accountID, affectedData)
if len(affectedPeerIDs) > 0 {
go m.accountManager.UpdateAffectedPeers(ctx, accountID, affectedPeerIDs)
}
}
return nil
}
// resolveNetworkAffectedPeers computes affected peer IDs from preloaded data outside the transaction.
func resolveNetworkAffectedPeers(ctx context.Context, s store.Store, accountID string, data *networkAffectedPeersData) []string {
groupSet := make(map[string]struct{})
for _, gID := range data.routerPeerGroups {
groupSet[gID] = struct{}{}
}
if len(data.resourceGroupIDs) > 0 {
destSet := make(map[string]struct{}, len(data.resourceGroupIDs))
for _, gID := range data.resourceGroupIDs {
destSet[gID] = struct{}{}
groupSet[gID] = struct{}{}
}
for _, policy := range data.policies {
if policy == nil || !policy.Enabled {
continue
}
for _, rule := range policy.Rules {
if rule == nil || !rule.Enabled {
continue
}
for _, gID := range rule.Destinations {
if _, ok := destSet[gID]; ok {
for _, srcGID := range rule.Sources {
groupSet[srcGID] = struct{}{}
}
break
}
}
}
}
}
if len(groupSet) == 0 && len(data.directPeerIDs) == 0 {
return nil
}
groupIDs := make([]string, 0, len(groupSet))
for gID := range groupSet {
groupIDs = append(groupIDs, gID)
}
peerIDs, err := s.GetPeerIDsByGroups(ctx, accountID, groupIDs)
if err != nil {
log.WithContext(ctx).Errorf("failed to resolve peer IDs: %v", err)
return nil
}
if len(data.directPeerIDs) > 0 {
seen := make(map[string]struct{}, len(peerIDs))
for _, id := range peerIDs {
seen[id] = struct{}{}
}
for _, id := range data.directPeerIDs {
if _, exists := seen[id]; !exists {
peerIDs = append(peerIDs, id)
seen[id] = struct{}{}
}
}
}
return peerIDs
}
func NewManagerMock() Manager {
return &mockManager{}
}

View File

@@ -114,6 +114,7 @@ func (m *managerImpl) CreateResource(ctx context.Context, userID string, resourc
}
var eventsToStore []func()
var affectedData *resourceAffectedPeersData
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
_, err = transaction.GetNetworkResourceByName(ctx, store.LockingStrengthNone, resource.AccountID, resource.Name)
if err == nil {
@@ -152,6 +153,11 @@ func (m *managerImpl) CreateResource(ctx context.Context, userID string, resourc
return fmt.Errorf("failed to increment network serial: %w", err)
}
affectedData, err = loadResourceAffectedPeersData(ctx, transaction, resource.AccountID, resource.NetworkID, resource.GroupIDs)
if err != nil {
log.WithContext(ctx).Errorf("failed to load affected peers data: %v", err)
}
return nil
})
if err != nil {
@@ -162,7 +168,9 @@ func (m *managerImpl) CreateResource(ctx context.Context, userID string, resourc
event()
}
go m.accountManager.UpdateAccountPeers(ctx, resource.AccountID, nbtypes.UpdateReason{Resource: nbtypes.UpdateResourceNetworkResource, Operation: nbtypes.UpdateOperationCreate})
if affectedPeerIDs := m.resolveResourceAffectedPeers(ctx, resource.AccountID, affectedData); len(affectedPeerIDs) > 0 {
go m.accountManager.UpdateAffectedPeers(ctx, resource.AccountID, affectedPeerIDs)
}
return resource, nil
}
@@ -207,6 +215,7 @@ func (m *managerImpl) UpdateResource(ctx context.Context, userID string, resourc
resource.Prefix = prefix
var eventsToStore []func()
var affectedData *resourceAffectedPeersData
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
network, err := transaction.GetNetworkByID(ctx, store.LockingStrengthUpdate, resource.AccountID, resource.NetworkID)
if err != nil {
@@ -232,6 +241,15 @@ func (m *managerImpl) UpdateResource(ctx context.Context, userID string, resourc
return fmt.Errorf("failed to get network resource: %w", err)
}
oldGroups, err := m.groupsManager.GetResourceGroupsInTransaction(ctx, transaction, store.LockingStrengthNone, oldResource.AccountID, oldResource.ID)
if err != nil {
return fmt.Errorf("failed to get old resource groups: %w", err)
}
var oldGroupIDs []string
for _, g := range oldGroups {
oldGroupIDs = append(oldGroupIDs, g.ID)
}
err = transaction.SaveNetworkResource(ctx, resource)
if err != nil {
return fmt.Errorf("failed to save network resource: %w", err)
@@ -247,6 +265,11 @@ func (m *managerImpl) UpdateResource(ctx context.Context, userID string, resourc
m.accountManager.StoreEvent(ctx, userID, resource.ID, resource.AccountID, activity.NetworkResourceUpdated, resource.EventMeta(network))
})
affectedData, err = loadResourceAffectedPeersData(ctx, transaction, resource.AccountID, resource.NetworkID, append(resource.GroupIDs, oldGroupIDs...))
if err != nil {
log.WithContext(ctx).Errorf("failed to load affected peers data: %v", err)
}
err = transaction.IncrementNetworkSerial(ctx, resource.AccountID)
if err != nil {
return fmt.Errorf("failed to increment network serial: %w", err)
@@ -270,7 +293,9 @@ func (m *managerImpl) UpdateResource(ctx context.Context, userID string, resourc
}
}()
go m.accountManager.UpdateAccountPeers(ctx, resource.AccountID, nbtypes.UpdateReason{Resource: nbtypes.UpdateResourceNetworkResource, Operation: nbtypes.UpdateOperationUpdate})
if affectedPeerIDs := m.resolveResourceAffectedPeers(ctx, resource.AccountID, affectedData); len(affectedPeerIDs) > 0 {
go m.accountManager.UpdateAffectedPeers(ctx, resource.AccountID, affectedPeerIDs)
}
return resource, nil
}
@@ -331,7 +356,22 @@ func (m *managerImpl) DeleteResource(ctx context.Context, accountID, userID, net
}
var events []func()
var affectedData *resourceAffectedPeersData
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
groups, err := m.groupsManager.GetResourceGroupsInTransaction(ctx, transaction, store.LockingStrengthNone, accountID, resourceID)
if err != nil {
return fmt.Errorf("failed to get resource groups: %w", err)
}
var resourceGroupIDs []string
for _, g := range groups {
resourceGroupIDs = append(resourceGroupIDs, g.ID)
}
affectedData, err = loadResourceAffectedPeersData(ctx, transaction, accountID, networkID, resourceGroupIDs)
if err != nil {
log.WithContext(ctx).Errorf("failed to load affected peers data: %v", err)
}
events, err = m.DeleteResourceInTransaction(ctx, transaction, accountID, userID, networkID, resourceID)
if err != nil {
return fmt.Errorf("failed to delete resource: %w", err)
@@ -352,7 +392,9 @@ func (m *managerImpl) DeleteResource(ctx context.Context, accountID, userID, net
event()
}
go m.accountManager.UpdateAccountPeers(ctx, accountID, nbtypes.UpdateReason{Resource: nbtypes.UpdateResourceNetworkResource, Operation: nbtypes.UpdateOperationDelete})
if affectedPeerIDs := m.resolveResourceAffectedPeers(ctx, accountID, affectedData); len(affectedPeerIDs) > 0 {
go m.accountManager.UpdateAffectedPeers(ctx, accountID, affectedPeerIDs)
}
return nil
}
@@ -399,6 +441,127 @@ func (m *managerImpl) DeleteResourceInTransaction(ctx context.Context, transacti
return eventsToStore, nil
}
// resourceAffectedPeersData holds data loaded inside a transaction for affected peer resolution.
type resourceAffectedPeersData struct {
resourceGroupIDs []string
policies []*nbtypes.Policy
routerPeerGroups []string
routerDirectPeers []string
}
// loadResourceAffectedPeersData loads the data needed to determine affected peers within a transaction.
func loadResourceAffectedPeersData(ctx context.Context, transaction store.Store, accountID, networkID string, resourceGroupIDs []string) (*resourceAffectedPeersData, error) {
if len(resourceGroupIDs) == 0 {
return nil, nil
}
policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return nil, fmt.Errorf("failed to get policies: %w", err)
}
routers, err := transaction.GetNetworkRoutersByNetID(ctx, store.LockingStrengthNone, accountID, networkID)
if err != nil {
return nil, fmt.Errorf("failed to get routers: %w", err)
}
var routerPeerGroups []string
var routerDirectPeers []string
for _, router := range routers {
if !router.Enabled {
continue
}
routerPeerGroups = append(routerPeerGroups, router.PeerGroups...)
if router.Peer != "" {
routerDirectPeers = append(routerDirectPeers, router.Peer)
}
}
return &resourceAffectedPeersData{
resourceGroupIDs: resourceGroupIDs,
policies: policies,
routerPeerGroups: routerPeerGroups,
routerDirectPeers: routerDirectPeers,
}, nil
}
// resolveResourceAffectedPeers computes affected peer IDs from preloaded data outside the transaction.
func (m *managerImpl) resolveResourceAffectedPeers(ctx context.Context, accountID string, data *resourceAffectedPeersData) []string {
if data == nil {
return nil
}
groupSet := make(map[string]struct{})
var directPeerIDs []string
destSet := make(map[string]struct{}, len(data.resourceGroupIDs))
for _, gID := range data.resourceGroupIDs {
destSet[gID] = struct{}{}
}
for _, policy := range data.policies {
if policy == nil || !policy.Enabled {
continue
}
for _, rule := range policy.Rules {
if rule == nil || !rule.Enabled {
continue
}
referencesResource := false
for _, gID := range rule.Destinations {
if _, ok := destSet[gID]; ok {
referencesResource = true
break
}
}
if !referencesResource {
continue
}
for _, gID := range rule.Sources {
groupSet[gID] = struct{}{}
}
if rule.SourceResource.Type == nbtypes.ResourceTypePeer && rule.SourceResource.ID != "" {
directPeerIDs = append(directPeerIDs, rule.SourceResource.ID)
}
}
}
for _, gID := range data.routerPeerGroups {
groupSet[gID] = struct{}{}
}
directPeerIDs = append(directPeerIDs, data.routerDirectPeers...)
if len(groupSet) == 0 && len(directPeerIDs) == 0 {
return nil
}
groupIDs := make([]string, 0, len(groupSet))
for gID := range groupSet {
groupIDs = append(groupIDs, gID)
}
peerIDs, err := m.store.GetPeerIDsByGroups(ctx, accountID, groupIDs)
if err != nil {
log.WithContext(ctx).Errorf("failed to resolve peer IDs: %v", err)
return nil
}
if len(directPeerIDs) > 0 {
seen := make(map[string]struct{}, len(peerIDs))
for _, id := range peerIDs {
seen[id] = struct{}{}
}
for _, id := range directPeerIDs {
if _, exists := seen[id]; !exists {
peerIDs = append(peerIDs, id)
seen[id] = struct{}{}
}
}
}
return peerIDs
}
func NewManagerMock() Manager {
return &mockManager{}
}

View File

@@ -6,6 +6,7 @@ import (
"fmt"
"github.com/rs/xid"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/activity"
@@ -15,7 +16,7 @@ import (
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/store"
serverTypes "github.com/netbirdio/netbird/management/server/types"
nbtypes "github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/management/status"
)
@@ -90,6 +91,7 @@ func (m *managerImpl) CreateRouter(ctx context.Context, userID string, router *t
}
var network *networkTypes.Network
var affectedData *routerAffectedPeersData
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
network, err = transaction.GetNetworkByID(ctx, store.LockingStrengthNone, router.AccountID, router.NetworkID)
if err != nil {
@@ -112,6 +114,11 @@ func (m *managerImpl) CreateRouter(ctx context.Context, userID string, router *t
return fmt.Errorf("failed to increment network serial: %w", err)
}
affectedData, err = loadRouterAffectedPeersData(ctx, transaction, router.AccountID, router.NetworkID, router.PeerGroups, router.Peer)
if err != nil {
log.WithContext(ctx).Errorf("failed to load affected peers data: %v", err)
}
return nil
})
if err != nil {
@@ -120,7 +127,9 @@ func (m *managerImpl) CreateRouter(ctx context.Context, userID string, router *t
m.accountManager.StoreEvent(ctx, userID, router.ID, router.AccountID, activity.NetworkRouterCreated, router.EventMeta(network))
go m.accountManager.UpdateAccountPeers(ctx, router.AccountID, serverTypes.UpdateReason{Resource: serverTypes.UpdateResourceNetworkRouter, Operation: serverTypes.UpdateOperationCreate})
if affectedPeerIDs := m.resolveRouterAffectedPeers(ctx, router.AccountID, affectedData); len(affectedPeerIDs) > 0 {
go m.accountManager.UpdateAffectedPeers(ctx, router.AccountID, affectedPeerIDs)
}
return router, nil
}
@@ -156,6 +165,7 @@ func (m *managerImpl) UpdateRouter(ctx context.Context, userID string, router *t
}
var network *networkTypes.Network
var affectedData *routerAffectedPeersData
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
network, err = transaction.GetNetworkByID(ctx, store.LockingStrengthNone, router.AccountID, router.NetworkID)
if err != nil {
@@ -166,6 +176,16 @@ func (m *managerImpl) UpdateRouter(ctx context.Context, userID string, router *t
return status.NewRouterNotPartOfNetworkError(router.ID, router.NetworkID)
}
allPeerGroups := router.PeerGroups
directPeers := []string{router.Peer}
oldRouter, err := transaction.GetNetworkRouterByID(ctx, store.LockingStrengthNone, router.AccountID, router.ID)
if err == nil {
allPeerGroups = append(allPeerGroups, oldRouter.PeerGroups...)
if oldRouter.Peer != "" {
directPeers = append(directPeers, oldRouter.Peer)
}
}
err = transaction.SaveNetworkRouter(ctx, router)
if err != nil {
return fmt.Errorf("failed to update network router: %w", err)
@@ -176,6 +196,11 @@ func (m *managerImpl) UpdateRouter(ctx context.Context, userID string, router *t
return fmt.Errorf("failed to increment network serial: %w", err)
}
affectedData, err = loadRouterAffectedPeersData(ctx, transaction, router.AccountID, router.NetworkID, allPeerGroups, directPeers...)
if err != nil {
log.WithContext(ctx).Errorf("failed to load affected peers data: %v", err)
}
return nil
})
if err != nil {
@@ -184,7 +209,9 @@ func (m *managerImpl) UpdateRouter(ctx context.Context, userID string, router *t
m.accountManager.StoreEvent(ctx, userID, router.ID, router.AccountID, activity.NetworkRouterUpdated, router.EventMeta(network))
go m.accountManager.UpdateAccountPeers(ctx, router.AccountID, serverTypes.UpdateReason{Resource: serverTypes.UpdateResourceNetworkRouter, Operation: serverTypes.UpdateOperationUpdate})
if affectedPeerIDs := m.resolveRouterAffectedPeers(ctx, router.AccountID, affectedData); len(affectedPeerIDs) > 0 {
go m.accountManager.UpdateAffectedPeers(ctx, router.AccountID, affectedPeerIDs)
}
return router, nil
}
@@ -199,7 +226,19 @@ func (m *managerImpl) DeleteRouter(ctx context.Context, accountID, userID, netwo
}
var event func()
var affectedData *routerAffectedPeersData
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
router, err := transaction.GetNetworkRouterByID(ctx, store.LockingStrengthNone, accountID, routerID)
if err != nil {
return fmt.Errorf("failed to get router: %w", err)
}
// load before delete so group memberships are still present
affectedData, err = loadRouterAffectedPeersData(ctx, transaction, accountID, networkID, router.PeerGroups, router.Peer)
if err != nil {
log.WithContext(ctx).Errorf("failed to load affected peers data: %v", err)
}
event, err = m.DeleteRouterInTransaction(ctx, transaction, accountID, userID, networkID, routerID)
if err != nil {
return fmt.Errorf("failed to delete network router: %w", err)
@@ -218,7 +257,9 @@ func (m *managerImpl) DeleteRouter(ctx context.Context, accountID, userID, netwo
event()
go m.accountManager.UpdateAccountPeers(ctx, accountID, serverTypes.UpdateReason{Resource: serverTypes.UpdateResourceNetworkRouter, Operation: serverTypes.UpdateOperationDelete})
if affectedPeerIDs := m.resolveRouterAffectedPeers(ctx, accountID, affectedData); len(affectedPeerIDs) > 0 {
go m.accountManager.UpdateAffectedPeers(ctx, accountID, affectedPeerIDs)
}
return nil
}
@@ -250,6 +291,136 @@ func (m *managerImpl) DeleteRouterInTransaction(ctx context.Context, transaction
return event, nil
}
// routerAffectedPeersData holds data loaded inside a transaction for affected peer resolution.
type routerAffectedPeersData struct {
routerPeerGroups []string
directPeerIDs []string
resourceGroupIDs []string
policies []*nbtypes.Policy
}
// loadRouterAffectedPeersData loads the data needed to determine affected peers within a transaction.
func loadRouterAffectedPeersData(ctx context.Context, transaction store.Store, accountID, networkID string, routerPeerGroups []string, directPeers ...string) (*routerAffectedPeersData, error) {
var directPeerIDs []string
for _, p := range directPeers {
if p != "" {
directPeerIDs = append(directPeerIDs, p)
}
}
if len(routerPeerGroups) == 0 && len(directPeerIDs) == 0 {
return nil, nil
}
resources, err := transaction.GetNetworkResourcesByNetID(ctx, store.LockingStrengthNone, accountID, networkID)
if err != nil {
return nil, fmt.Errorf("failed to get network resources: %w", err)
}
var resourceGroupIDs []string
for _, resource := range resources {
if !resource.Enabled {
continue
}
groups, err := transaction.GetResourceGroups(ctx, store.LockingStrengthNone, accountID, resource.ID)
if err != nil {
return nil, fmt.Errorf("failed to get groups for resource %s: %w", resource.ID, err)
}
for _, g := range groups {
resourceGroupIDs = append(resourceGroupIDs, g.ID)
}
}
var policies []*nbtypes.Policy
if len(resourceGroupIDs) > 0 {
policies, err = transaction.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return nil, fmt.Errorf("failed to get policies: %w", err)
}
}
return &routerAffectedPeersData{
routerPeerGroups: routerPeerGroups,
directPeerIDs: directPeerIDs,
resourceGroupIDs: resourceGroupIDs,
policies: policies,
}, nil
}
// resolveRouterAffectedPeers computes affected peer IDs from preloaded data outside the transaction.
func (m *managerImpl) resolveRouterAffectedPeers(ctx context.Context, accountID string, data *routerAffectedPeersData) []string {
if data == nil {
return nil
}
groupSet := make(map[string]struct{})
for _, gID := range data.routerPeerGroups {
groupSet[gID] = struct{}{}
}
if len(data.resourceGroupIDs) > 0 {
destSet := make(map[string]struct{}, len(data.resourceGroupIDs))
for _, gID := range data.resourceGroupIDs {
destSet[gID] = struct{}{}
}
for _, policy := range data.policies {
if policy == nil || !policy.Enabled {
continue
}
for _, rule := range policy.Rules {
if rule == nil || !rule.Enabled {
continue
}
referencesResource := false
for _, gID := range rule.Destinations {
if _, ok := destSet[gID]; ok {
referencesResource = true
break
}
}
if !referencesResource {
continue
}
for _, gID := range rule.Sources {
groupSet[gID] = struct{}{}
}
}
}
}
if len(groupSet) == 0 && len(data.directPeerIDs) == 0 {
return nil
}
groupIDs := make([]string, 0, len(groupSet))
for gID := range groupSet {
groupIDs = append(groupIDs, gID)
}
peerIDs, err := m.store.GetPeerIDsByGroups(ctx, accountID, groupIDs)
if err != nil {
log.WithContext(ctx).Errorf("failed to resolve peer IDs: %v", err)
return nil
}
if len(data.directPeerIDs) > 0 {
seen := make(map[string]struct{}, len(peerIDs))
for _, id := range peerIDs {
seen[id] = struct{}{}
}
for _, id := range data.directPeerIDs {
if _, exists := seen[id]; !exists {
peerIDs = append(peerIDs, id)
seen[id] = struct{}{}
}
}
}
return peerIDs
}
func NewManagerMock() Manager {
return &mockManager{}
}

View File

@@ -110,7 +110,9 @@ func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubK
}
if expired {
err = am.networkMapController.OnPeersUpdated(ctx, accountID, []string{peer.ID})
changedPeerIDs := []string{peer.ID}
affectedPeerIDs := am.resolveAffectedPeersForPeerChanges(ctx, am.Store, accountID, changedPeerIDs)
err = am.networkMapController.OnPeersUpdated(ctx, accountID, changedPeerIDs, affectedPeerIDs)
if err != nil {
return fmt.Errorf("notify network map controller of peer update: %w", err)
}
@@ -294,7 +296,9 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user
}
}
err = am.networkMapController.OnPeersUpdated(ctx, accountID, []string{peer.ID})
changedPeerIDs := []string{peer.ID}
affectedPeerIDs := am.resolveAffectedPeersForPeerChanges(ctx, am.Store, accountID, changedPeerIDs)
err = am.networkMapController.OnPeersUpdated(ctx, accountID, changedPeerIDs, affectedPeerIDs)
if err != nil {
return nil, fmt.Errorf("notify network map controller of peer update: %w", err)
}
@@ -452,6 +456,7 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer
var peer *nbpeer.Peer
var settings *types.Settings
var eventsToStore []func()
var affectedPeerIDs []string
serviceID, err := am.serviceManager.GetServiceIDByTargetID(ctx, accountID, peerID)
if err != nil {
@@ -476,6 +481,8 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer
return err
}
affectedPeerIDs = am.resolveAffectedPeersForPeerChanges(ctx, transaction, accountID, []string{peerID})
eventsToStore, err = deletePeers(ctx, am, transaction, accountID, userID, []*nbpeer.Peer{peer}, settings)
if err != nil {
return fmt.Errorf("failed to delete peer: %w", err)
@@ -499,7 +506,7 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer
log.WithContext(ctx).Errorf("failed to delete peer %s from integrated validator: %v", peerID, err)
}
if err = am.networkMapController.OnPeersDeleted(ctx, accountID, []string{peerID}); err != nil {
if err = am.networkMapController.OnPeersDeleted(ctx, accountID, []string{peerID}, affectedPeerIDs); err != nil {
log.WithContext(ctx).Errorf("failed to delete peer %s from network map: %v", peerID, err)
}
@@ -826,7 +833,9 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, accountID, setupKe
am.StoreEvent(ctx, opEvent.InitiatorID, opEvent.TargetID, opEvent.AccountID, opEvent.Activity, opEvent.Meta)
}
if err := am.networkMapController.OnPeersAdded(ctx, accountID, []string{newPeer.ID}); err != nil {
changedPeerIDs := []string{newPeer.ID}
affectedPeerIDs := am.resolveAffectedPeersForPeerChanges(ctx, am.Store, accountID, changedPeerIDs)
if err := am.networkMapController.OnPeersAdded(ctx, accountID, changedPeerIDs, affectedPeerIDs); err != nil {
log.WithContext(ctx).Errorf("failed to update network map cache for peer %s: %v", newPeer.ID, err)
}
@@ -909,7 +918,9 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync types.PeerSy
}
if isStatusChanged || sync.UpdateAccountPeers || (updated && (len(postureChecks) > 0 || versionChanged)) {
err = am.networkMapController.OnPeersUpdated(ctx, accountID, []string{peer.ID})
changedPeerIDs := []string{peer.ID}
affectedPeerIDs := am.resolveAffectedPeersForPeerChanges(ctx, am.Store, accountID, changedPeerIDs)
err = am.networkMapController.OnPeersUpdated(ctx, accountID, changedPeerIDs, affectedPeerIDs)
if err != nil {
return nil, nil, nil, 0, fmt.Errorf("notify network map controller of peer update: %w", err)
}
@@ -1036,7 +1047,9 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login types.Peer
}
if updateRemotePeers || isStatusChanged || (isPeerUpdated && len(postureChecks) > 0) {
err = am.networkMapController.OnPeersUpdated(ctx, accountID, []string{peer.ID})
changedPeerIDs := []string{peer.ID}
affectedPeerIDs := am.resolveAffectedPeersForPeerChanges(ctx, am.Store, accountID, changedPeerIDs)
err = am.networkMapController.OnPeersUpdated(ctx, accountID, changedPeerIDs, affectedPeerIDs)
if err != nil {
return nil, nil, nil, fmt.Errorf("notify network map controller of peer update: %w", err)
}
@@ -1228,6 +1241,62 @@ func (am *DefaultAccountManager) UpdateAccountPeers(ctx context.Context, account
_ = am.networkMapController.UpdateAccountPeers(ctx, accountID, reason)
}
// UpdateAffectedPeers updates only the specified peers that belong to an account.
func (am *DefaultAccountManager) UpdateAffectedPeers(ctx context.Context, accountID string, peerIDs []string) {
log.WithContext(ctx).Tracef("UpdateAffectedPeers: %d peers for account %s", len(peerIDs), accountID)
_ = am.networkMapController.UpdateAffectedPeers(ctx, accountID, peerIDs)
}
// resolvePeerIDs resolves group IDs and direct peer IDs into a deduplicated peer ID list.
func (am *DefaultAccountManager) resolvePeerIDs(ctx context.Context, s store.Store, accountID string, groupIDs []string, directPeerIDs []string) []string {
peerIDs, err := s.GetPeerIDsByGroups(ctx, accountID, groupIDs)
if err != nil {
log.WithContext(ctx).Errorf("failed to resolve peer IDs by groups: %v", err)
return nil
}
if len(directPeerIDs) == 0 {
log.WithContext(ctx).Tracef("resolvePeerIDs: groups=%v -> %d peers", groupIDs, len(peerIDs))
return peerIDs
}
seen := make(map[string]struct{}, len(peerIDs))
for _, id := range peerIDs {
seen[id] = struct{}{}
}
for _, id := range directPeerIDs {
if _, exists := seen[id]; !exists {
peerIDs = append(peerIDs, id)
seen[id] = struct{}{}
}
}
log.WithContext(ctx).Tracef("resolvePeerIDs: groups=%v + directPeers=%v -> %d peers", groupIDs, directPeerIDs, len(peerIDs))
return peerIDs
}
// BufferUpdateAffectedPeers accumulates peer IDs and flushes them after the buffer interval.
func (am *DefaultAccountManager) BufferUpdateAffectedPeers(ctx context.Context, accountID string, peerIDs []string) {
_ = am.networkMapController.BufferUpdateAffectedPeers(ctx, accountID, peerIDs)
}
// resolveAffectedPeersForPeerChanges resolves changed peer IDs into the full set of affected peer IDs.
func (am *DefaultAccountManager) resolveAffectedPeersForPeerChanges(ctx context.Context, s store.Store, accountID string, changedPeerIDs []string) []string {
groupIDs, err := s.GetGroupIDsByPeerIDs(ctx, accountID, changedPeerIDs)
if err != nil {
log.WithContext(ctx).Errorf("failed to get group IDs for changed peers: %v", err)
return nil
}
log.WithContext(ctx).Tracef("resolveAffectedPeersForPeerChanges: changedPeers=%v -> groups=%v", changedPeerIDs, groupIDs)
allGroupIDs, directPeerIDs := collectGroupChangeAffectedGroups(ctx, s, accountID, groupIDs)
result := am.resolvePeerIDs(ctx, s, accountID, allGroupIDs, directPeerIDs)
log.WithContext(ctx).Tracef("resolveAffectedPeersForPeerChanges: changedPeers=%v -> %d affected peers", changedPeerIDs, len(result))
return result
}
func (am *DefaultAccountManager) BufferUpdateAccountPeers(ctx context.Context, accountID string, reason types.UpdateReason) {
_ = am.networkMapController.BufferUpdateAccountPeers(ctx, accountID, reason)
}

View File

@@ -45,12 +45,13 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user
}
var isUpdate = policy.ID != ""
var updateAccountPeers bool
var existingPolicy *types.Policy
var action = activity.PolicyAdded
var unchanged bool
var affectedPeerIDs []string
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
existingPolicy, err := validatePolicy(ctx, transaction, accountID, policy)
existingPolicy, err = validatePolicy(ctx, transaction, accountID, policy)
if err != nil {
return err
}
@@ -64,25 +65,18 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user
action = activity.PolicyUpdated
updateAccountPeers, err = arePolicyChangesAffectPeersWithExisting(ctx, transaction, policy, existingPolicy)
if err != nil {
return err
}
if err = transaction.SavePolicy(ctx, policy); err != nil {
return err
}
} else {
updateAccountPeers, err = arePolicyChangesAffectPeers(ctx, transaction, policy)
if err != nil {
return err
}
if err = transaction.CreatePolicy(ctx, policy); err != nil {
return err
}
}
groupIDs, directPeerIDs := collectPolicyAffectedGroupsAndPeers(policy, existingPolicy)
affectedPeerIDs = am.resolvePeerIDs(ctx, transaction, accountID, groupIDs, directPeerIDs)
return transaction.IncrementNetworkSerial(ctx, accountID)
})
if err != nil {
@@ -95,12 +89,8 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user
am.StoreEvent(ctx, userID, policy.ID, accountID, action, policy.EventMeta())
if updateAccountPeers {
policyOp := types.UpdateOperationCreate
if isUpdate {
policyOp = types.UpdateOperationUpdate
}
am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourcePolicy, Operation: policyOp})
if len(affectedPeerIDs) > 0 {
am.UpdateAffectedPeers(ctx, accountID, affectedPeerIDs)
}
return policy, nil
@@ -117,7 +107,7 @@ func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, po
}
var policy *types.Policy
var updateAccountPeers bool
var affectedPeerIDs []string
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
policy, err = transaction.GetPolicyByID(ctx, store.LockingStrengthUpdate, accountID, policyID)
@@ -125,10 +115,8 @@ func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, po
return err
}
updateAccountPeers, err = arePolicyChangesAffectPeers(ctx, transaction, policy)
if err != nil {
return err
}
groupIDs, directPeerIDs := collectPolicyAffectedGroupsAndPeers(policy)
affectedPeerIDs = am.resolvePeerIDs(ctx, transaction, accountID, groupIDs, directPeerIDs)
if err = transaction.DeletePolicy(ctx, accountID, policyID); err != nil {
return err
@@ -142,8 +130,8 @@ func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, po
am.StoreEvent(ctx, userID, policyID, accountID, activity.PolicyRemoved, policy.EventMeta())
if updateAccountPeers {
am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourcePolicy, Operation: types.UpdateOperationDelete})
if len(affectedPeerIDs) > 0 {
am.UpdateAffectedPeers(ctx, accountID, affectedPeerIDs)
}
return nil
@@ -162,44 +150,23 @@ func (am *DefaultAccountManager) ListPolicies(ctx context.Context, accountID, us
return am.Store.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID)
}
// arePolicyChangesAffectPeers checks if a policy (being created or deleted) will affect any associated peers.
func arePolicyChangesAffectPeers(ctx context.Context, transaction store.Store, policy *types.Policy) (bool, error) {
for _, rule := range policy.Rules {
if rule.SourceResource.Type != "" || rule.DestinationResource.Type != "" {
return true, nil
// collectPolicyAffectedGroupsAndPeers returns group IDs and direct peer IDs from the given policies.
func collectPolicyAffectedGroupsAndPeers(policies ...*types.Policy) (groupIDs []string, directPeerIDs []string) {
for _, policy := range policies {
if policy == nil {
continue
}
groupIDs = append(groupIDs, policy.RuleGroups()...)
for _, rule := range policy.Rules {
if rule.SourceResource.Type == types.ResourceTypePeer && rule.SourceResource.ID != "" {
directPeerIDs = append(directPeerIDs, rule.SourceResource.ID)
}
if rule.DestinationResource.Type == types.ResourceTypePeer && rule.DestinationResource.ID != "" {
directPeerIDs = append(directPeerIDs, rule.DestinationResource.ID)
}
}
}
return anyGroupHasPeersOrResources(ctx, transaction, policy.AccountID, policy.RuleGroups())
}
func arePolicyChangesAffectPeersWithExisting(ctx context.Context, transaction store.Store, policy *types.Policy, existingPolicy *types.Policy) (bool, error) {
if !policy.Enabled && !existingPolicy.Enabled {
return false, nil
}
for _, rule := range existingPolicy.Rules {
if rule.SourceResource.Type != "" || rule.DestinationResource.Type != "" {
return true, nil
}
}
hasPeers, err := anyGroupHasPeersOrResources(ctx, transaction, policy.AccountID, existingPolicy.RuleGroups())
if err != nil {
return false, err
}
if hasPeers {
return true, nil
}
for _, rule := range policy.Rules {
if rule.SourceResource.Type != "" || rule.DestinationResource.Type != "" {
return true, nil
}
}
return anyGroupHasPeersOrResources(ctx, transaction, policy.AccountID, policy.RuleGroups())
return
}
// validatePolicy validates the policy and its rules. For updates it returns

View File

@@ -11,7 +11,6 @@ import (
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/management/status"
)
@@ -41,9 +40,9 @@ func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountI
return nil, status.NewPermissionDeniedError()
}
var updateAccountPeers bool
var isUpdate = postureChecks.ID != ""
var action = activity.PostureCheckCreated
var affectedPeerIDs []string
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
if err = validatePostureChecks(ctx, transaction, accountID, postureChecks); err != nil {
@@ -51,12 +50,10 @@ func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountI
}
if isUpdate {
updateAccountPeers, err = arePostureCheckChangesAffectPeers(ctx, transaction, accountID, postureChecks.ID)
if err != nil {
return err
}
action = activity.PostureCheckUpdated
groupIDs, directPeerIDs := collectPostureCheckAffectedGroupsAndPeers(ctx, transaction, accountID, postureChecks.ID)
affectedPeerIDs = am.resolvePeerIDs(ctx, transaction, accountID, groupIDs, directPeerIDs)
}
postureChecks.AccountID = accountID
@@ -76,12 +73,8 @@ func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountI
am.StoreEvent(ctx, userID, postureChecks.ID, accountID, action, postureChecks.EventMeta())
if updateAccountPeers {
postureOp := types.UpdateOperationCreate
if isUpdate {
postureOp = types.UpdateOperationUpdate
}
am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourcePostureCheck, Operation: postureOp})
if len(affectedPeerIDs) > 0 {
am.UpdateAffectedPeers(ctx, accountID, affectedPeerIDs)
}
return postureChecks, nil
@@ -137,27 +130,22 @@ func (am *DefaultAccountManager) ListPostureChecks(ctx context.Context, accountI
return am.Store.GetAccountPostureChecks(ctx, store.LockingStrengthNone, accountID)
}
// arePostureCheckChangesAffectPeers checks if the changes in posture checks are affecting peers.
func arePostureCheckChangesAffectPeers(ctx context.Context, transaction store.Store, accountID, postureCheckID string) (bool, error) {
// collectPostureCheckAffectedGroupsAndPeers returns group IDs and peer IDs from policies referencing the posture check.
func collectPostureCheckAffectedGroupsAndPeers(ctx context.Context, transaction store.Store, accountID, postureCheckID string) (groupIDs []string, directPeerIDs []string) {
policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return false, err
return nil, nil
}
for _, policy := range policies {
if slices.Contains(policy.SourcePostureChecks, postureCheckID) {
hasPeers, err := anyGroupHasPeersOrResources(ctx, transaction, accountID, policy.RuleGroups())
if err != nil {
return false, err
}
if hasPeers {
return true, nil
}
gIDs, pIDs := collectPolicyAffectedGroupsAndPeers(policy)
groupIDs = append(groupIDs, gIDs...)
directPeerIDs = append(directPeerIDs, pIDs...)
}
}
return false, nil
return groupIDs, directPeerIDs
}
// validatePostureChecks validates the posture checks.

View File

@@ -503,21 +503,20 @@ func TestArePostureCheckChangesAffectPeers(t *testing.T) {
require.NoError(t, err, "failed to save policy")
t.Run("posture check exists and is linked to policy with peers", func(t *testing.T) {
result, err := arePostureCheckChangesAffectPeers(context.Background(), manager.Store, account.Id, postureCheckA.ID)
require.NoError(t, err)
assert.True(t, result)
groupIDs, _ := collectPostureCheckAffectedGroupsAndPeers(context.Background(), manager.Store, account.Id, postureCheckA.ID)
assert.NotEmpty(t, groupIDs)
})
t.Run("posture check exists but is not linked to any policy", func(t *testing.T) {
result, err := arePostureCheckChangesAffectPeers(context.Background(), manager.Store, account.Id, postureCheckB.ID)
require.NoError(t, err)
assert.False(t, result)
groupIDs, directPeerIDs := collectPostureCheckAffectedGroupsAndPeers(context.Background(), manager.Store, account.Id, postureCheckB.ID)
assert.Empty(t, groupIDs)
assert.Empty(t, directPeerIDs)
})
t.Run("posture check does not exist", func(t *testing.T) {
result, err := arePostureCheckChangesAffectPeers(context.Background(), manager.Store, account.Id, "unknown")
require.NoError(t, err)
assert.False(t, result)
groupIDs, directPeerIDs := collectPostureCheckAffectedGroupsAndPeers(context.Background(), manager.Store, account.Id, "unknown")
assert.Empty(t, groupIDs)
assert.Empty(t, directPeerIDs)
})
t.Run("posture check is linked to policy with no peers in source groups", func(t *testing.T) {
@@ -526,9 +525,8 @@ func TestArePostureCheckChangesAffectPeers(t *testing.T) {
_, err = manager.SavePolicy(context.Background(), account.Id, adminUserID, policy, true)
require.NoError(t, err, "failed to update policy")
result, err := arePostureCheckChangesAffectPeers(context.Background(), manager.Store, account.Id, postureCheckA.ID)
require.NoError(t, err)
assert.True(t, result)
groupIDs, _ := collectPostureCheckAffectedGroupsAndPeers(context.Background(), manager.Store, account.Id, postureCheckA.ID)
assert.NotEmpty(t, groupIDs)
})
t.Run("posture check is linked to policy with no peers in destination groups", func(t *testing.T) {
@@ -537,9 +535,8 @@ func TestArePostureCheckChangesAffectPeers(t *testing.T) {
_, err = manager.SavePolicy(context.Background(), account.Id, adminUserID, policy, true)
require.NoError(t, err, "failed to update policy")
result, err := arePostureCheckChangesAffectPeers(context.Background(), manager.Store, account.Id, postureCheckA.ID)
require.NoError(t, err)
assert.True(t, result)
groupIDs, _ := collectPostureCheckAffectedGroupsAndPeers(context.Background(), manager.Store, account.Id, postureCheckA.ID)
assert.NotEmpty(t, groupIDs)
})
t.Run("posture check is linked to policy but no peers in groups", func(t *testing.T) {
@@ -547,9 +544,9 @@ func TestArePostureCheckChangesAffectPeers(t *testing.T) {
err = manager.UpdateGroup(context.Background(), account.Id, adminUserID, groupA)
require.NoError(t, err, "failed to save groups")
result, err := arePostureCheckChangesAffectPeers(context.Background(), manager.Store, account.Id, postureCheckA.ID)
require.NoError(t, err)
assert.False(t, result)
// The collector returns groups even if they have no peers — the groups are still referenced
groupIDs, _ := collectPostureCheckAffectedGroupsAndPeers(context.Background(), manager.Store, account.Id, postureCheckA.ID)
assert.NotEmpty(t, groupIDs)
})
t.Run("posture check is linked to policy with non-existent group", func(t *testing.T) {
@@ -558,8 +555,10 @@ func TestArePostureCheckChangesAffectPeers(t *testing.T) {
_, err = manager.SavePolicy(context.Background(), account.Id, adminUserID, policy, true)
require.NoError(t, err, "failed to update policy")
result, err := arePostureCheckChangesAffectPeers(context.Background(), manager.Store, account.Id, postureCheckA.ID)
require.NoError(t, err)
assert.False(t, result)
// Non-existent groups are filtered out during SavePolicy validation,
// so the saved policy has empty Sources/Destinations
groupIDs, directPeerIDs := collectPostureCheckAffectedGroupsAndPeers(context.Background(), manager.Store, account.Id, postureCheckA.ID)
assert.Empty(t, groupIDs)
assert.Empty(t, directPeerIDs)
})
}

View File

@@ -147,7 +147,7 @@ func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID stri
}
var newRoute *route.Route
var updateAccountPeers bool
var affectedPeerIDs []string
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
newRoute = &route.Route{
@@ -173,15 +173,13 @@ func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID stri
return err
}
updateAccountPeers, err = areRouteChangesAffectPeers(ctx, transaction, newRoute)
if err != nil {
return err
}
if err = transaction.SaveRoute(ctx, newRoute); err != nil {
return err
}
groupIDs, directPeerIDs := collectRouteAffectedGroupsAndPeers(newRoute)
affectedPeerIDs = am.resolvePeerIDs(ctx, transaction, accountID, groupIDs, directPeerIDs)
return transaction.IncrementNetworkSerial(ctx, accountID)
})
if err != nil {
@@ -190,8 +188,8 @@ func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID stri
am.StoreEvent(ctx, userID, string(newRoute.ID), accountID, activity.RouteCreated, newRoute.EventMeta())
if updateAccountPeers {
am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceRoute, Operation: types.UpdateOperationCreate})
if len(affectedPeerIDs) > 0 {
am.UpdateAffectedPeers(ctx, accountID, affectedPeerIDs)
}
return newRoute, nil
@@ -208,8 +206,7 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI
}
var oldRoute *route.Route
var oldRouteAffectsPeers bool
var newRouteAffectsPeers bool
var affectedPeerIDs []string
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
if err = validateRoute(ctx, transaction, accountID, routeToSave); err != nil {
@@ -221,21 +218,15 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI
return err
}
oldRouteAffectsPeers, err = areRouteChangesAffectPeers(ctx, transaction, oldRoute)
if err != nil {
return err
}
newRouteAffectsPeers, err = areRouteChangesAffectPeers(ctx, transaction, routeToSave)
if err != nil {
return err
}
routeToSave.AccountID = accountID
if err = transaction.SaveRoute(ctx, routeToSave); err != nil {
return err
}
groupIDs, directPeerIDs := collectRouteAffectedGroupsAndPeers(routeToSave, oldRoute)
affectedPeerIDs = am.resolvePeerIDs(ctx, transaction, accountID, groupIDs, directPeerIDs)
return transaction.IncrementNetworkSerial(ctx, accountID)
})
if err != nil {
@@ -244,8 +235,8 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI
am.StoreEvent(ctx, userID, string(routeToSave.ID), accountID, activity.RouteUpdated, routeToSave.EventMeta())
if oldRouteAffectsPeers || newRouteAffectsPeers {
am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceRoute, Operation: types.UpdateOperationUpdate})
if len(affectedPeerIDs) > 0 {
am.UpdateAffectedPeers(ctx, accountID, affectedPeerIDs)
}
return nil
@@ -261,19 +252,17 @@ func (am *DefaultAccountManager) DeleteRoute(ctx context.Context, accountID stri
return status.NewPermissionDeniedError()
}
var route *route.Route
var updateAccountPeers bool
var rt *route.Route
var affectedPeerIDs []string
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
route, err = transaction.GetRouteByID(ctx, store.LockingStrengthUpdate, accountID, string(routeID))
rt, err = transaction.GetRouteByID(ctx, store.LockingStrengthUpdate, accountID, string(routeID))
if err != nil {
return err
}
updateAccountPeers, err = areRouteChangesAffectPeers(ctx, transaction, route)
if err != nil {
return err
}
groupIDs, directPeerIDs := collectRouteAffectedGroupsAndPeers(rt)
affectedPeerIDs = am.resolvePeerIDs(ctx, transaction, accountID, groupIDs, directPeerIDs)
if err = transaction.DeleteRoute(ctx, accountID, string(routeID)); err != nil {
return err
@@ -285,10 +274,10 @@ func (am *DefaultAccountManager) DeleteRoute(ctx context.Context, accountID stri
return fmt.Errorf("failed to delete route %s: %w", routeID, err)
}
am.StoreEvent(ctx, userID, string(route.ID), accountID, activity.RouteRemoved, route.EventMeta())
am.StoreEvent(ctx, userID, string(rt.ID), accountID, activity.RouteRemoved, rt.EventMeta())
if updateAccountPeers {
am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceRoute, Operation: types.UpdateOperationDelete})
if len(affectedPeerIDs) > 0 {
am.UpdateAffectedPeers(ctx, accountID, affectedPeerIDs)
}
return nil
@@ -377,23 +366,20 @@ func getPlaceholderIP() netip.Prefix {
return netip.PrefixFrom(netip.AddrFrom4([4]byte{192, 0, 2, 0}), 32)
}
// areRouteChangesAffectPeers checks if a given route affects peers by determining
// if it has a routing peer, distribution, or peer groups that include peers.
func areRouteChangesAffectPeers(ctx context.Context, transaction store.Store, route *route.Route) (bool, error) {
if route.Peer != "" {
return true, nil
// collectRouteAffectedGroupsAndPeers returns group IDs and direct peer IDs from the given routes.
func collectRouteAffectedGroupsAndPeers(routes ...*route.Route) (groupIDs []string, directPeerIDs []string) {
for _, r := range routes {
if r == nil {
continue
}
groupIDs = append(groupIDs, r.Groups...)
groupIDs = append(groupIDs, r.PeerGroups...)
groupIDs = append(groupIDs, r.AccessControlGroups...)
if r.Peer != "" {
directPeerIDs = append(directPeerIDs, r.Peer)
}
}
hasPeers, err := anyGroupHasPeersOrResources(ctx, transaction, route.AccountID, route.Groups)
if err != nil {
return false, err
}
if hasPeers {
return true, nil
}
return anyGroupHasPeersOrResources(ctx, transaction, route.AccountID, route.PeerGroups)
return
}
// GetRoutesByPrefixOrDomains return list of routes by account and route prefix

View File

@@ -4662,6 +4662,40 @@ func (s *SqlStore) GetPeersByGroupIDs(ctx context.Context, accountID string, gro
return peers, nil
}
func (s *SqlStore) GetPeerIDsByGroups(ctx context.Context, accountID string, groupIDs []string) ([]string, error) {
if len(groupIDs) == 0 {
return nil, nil
}
var peerIDs []string
result := s.db.Model(&types.GroupPeer{}).
Select("DISTINCT peer_id").
Where("account_id = ? AND group_id IN ?", accountID, groupIDs).
Pluck("peer_id", &peerIDs)
if result.Error != nil {
return nil, status.Errorf(status.Internal, "failed to get peer IDs by groups: %s", result.Error)
}
return peerIDs, nil
}
func (s *SqlStore) GetGroupIDsByPeerIDs(ctx context.Context, accountID string, peerIDs []string) ([]string, error) {
if len(peerIDs) == 0 {
return nil, nil
}
var groupIDs []string
result := s.db.Model(&types.GroupPeer{}).
Select("DISTINCT group_id").
Where("account_id = ? AND peer_id IN ?", accountID, peerIDs).
Pluck("group_id", &groupIDs)
if result.Error != nil {
return nil, status.Errorf(status.Internal, "failed to get group IDs by peers: %s", result.Error)
}
return groupIDs, nil
}
func (s *SqlStore) GetUserIDByPeerKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (string, error) {
tx := s.db
if lockStrength != LockingStrengthNone {
@@ -5437,35 +5471,13 @@ func (s *SqlStore) SaveProxy(ctx context.Context, p *proxy.Proxy) error {
return nil
}
// DisconnectProxy marks a proxy as disconnected only if the session ID matches.
// This prevents a slow-to-close old session from overwriting a newer reconnection.
func (s *SqlStore) DisconnectProxy(ctx context.Context, proxyID, sessionID string) error {
now := time.Now()
result := s.db.
Model(&proxy.Proxy{}).
Where("id = ? AND session_id = ?", proxyID, sessionID).
Updates(map[string]any{
"status": "disconnected",
"disconnected_at": now,
"last_seen": now,
})
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to disconnect proxy %s session %s: %v", proxyID, sessionID, result.Error)
return status.Errorf(status.Internal, "failed to disconnect proxy")
}
if result.RowsAffected == 0 {
log.WithContext(ctx).Debugf("proxy %s session %s: no row updated (superseded by newer session)", proxyID, sessionID)
}
return nil
}
// UpdateProxyHeartbeat updates the last_seen timestamp for the proxy's current session.
func (s *SqlStore) UpdateProxyHeartbeat(ctx context.Context, p *proxy.Proxy) error {
// UpdateProxyHeartbeat updates the last_seen timestamp for a proxy or creates a new entry if it doesn't exist
func (s *SqlStore) UpdateProxyHeartbeat(ctx context.Context, proxyID, clusterAddress, ipAddress string) error {
now := time.Now()
result := s.db.
Model(&proxy.Proxy{}).
Where("id = ? AND session_id = ?", p.ID, p.SessionID).
Where("id = ? AND status = ?", proxyID, "connected").
Update("last_seen", now)
if result.Error != nil {
@@ -5474,11 +5486,17 @@ func (s *SqlStore) UpdateProxyHeartbeat(ctx context.Context, p *proxy.Proxy) err
}
if result.RowsAffected == 0 {
p.LastSeen = now
p.ConnectedAt = &now
p.Status = "connected"
if err := s.db.Create(p).Error; err != nil {
log.WithContext(ctx).Debugf("proxy %s session %s: heartbeat fallback insert skipped: %v", p.ID, p.SessionID, err)
p := &proxy.Proxy{
ID: proxyID,
ClusterAddress: clusterAddress,
IPAddress: ipAddress,
LastSeen: now,
ConnectedAt: &now,
Status: "connected",
}
if err := s.db.Save(p).Error; err != nil {
log.WithContext(ctx).Errorf("failed to create proxy on heartbeat: %v", err)
return status.Errorf(status.Internal, "failed to create proxy on heartbeat")
}
}

View File

@@ -159,6 +159,8 @@ type Store interface {
GetPeerByID(ctx context.Context, lockStrength LockingStrength, accountID string, peerID string) (*nbpeer.Peer, error)
GetPeersByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, peerIDs []string) (map[string]*nbpeer.Peer, error)
GetPeersByGroupIDs(ctx context.Context, accountID string, groupIDs []string) ([]*nbpeer.Peer, error)
GetPeerIDsByGroups(ctx context.Context, accountID string, groupIDs []string) ([]string, error)
GetGroupIDsByPeerIDs(ctx context.Context, accountID string, peerIDs []string) ([]string, error)
GetAccountPeersWithExpiration(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbpeer.Peer, error)
GetAccountPeersWithInactivity(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbpeer.Peer, error)
GetAllEphemeralPeers(ctx context.Context, lockStrength LockingStrength) ([]*nbpeer.Peer, error)
@@ -284,8 +286,7 @@ type Store interface {
DeleteServiceTargets(ctx context.Context, accountID string, serviceID string) error
SaveProxy(ctx context.Context, proxy *proxy.Proxy) error
DisconnectProxy(ctx context.Context, proxyID, sessionID string) error
UpdateProxyHeartbeat(ctx context.Context, p *proxy.Proxy) error
UpdateProxyHeartbeat(ctx context.Context, proxyID, clusterAddress, ipAddress string) error
GetActiveProxyClusterAddresses(ctx context.Context) ([]string, error)
GetActiveProxyClusters(ctx context.Context) ([]proxy.Cluster, error)
GetClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool

View File

@@ -1853,6 +1853,36 @@ func (mr *MockStoreMockRecorder) GetPeersByGroupIDs(ctx, accountID, groupIDs int
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPeersByGroupIDs", reflect.TypeOf((*MockStore)(nil).GetPeersByGroupIDs), ctx, accountID, groupIDs)
}
// GetPeerIDsByGroups mocks base method.
func (m *MockStore) GetPeerIDsByGroups(ctx context.Context, accountID string, groupIDs []string) ([]string, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetPeerIDsByGroups", ctx, accountID, groupIDs)
ret0, _ := ret[0].([]string)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetPeerIDsByGroups indicates an expected call of GetPeerIDsByGroups.
func (mr *MockStoreMockRecorder) GetPeerIDsByGroups(ctx, accountID, groupIDs interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPeerIDsByGroups", reflect.TypeOf((*MockStore)(nil).GetPeerIDsByGroups), ctx, accountID, groupIDs)
}
// GetGroupIDsByPeerIDs mocks base method.
func (m *MockStore) GetGroupIDsByPeerIDs(ctx context.Context, accountID string, peerIDs []string) ([]string, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetGroupIDsByPeerIDs", ctx, accountID, peerIDs)
ret0, _ := ret[0].([]string)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetGroupIDsByPeerIDs indicates an expected call of GetGroupIDsByPeerIDs.
func (mr *MockStoreMockRecorder) GetGroupIDsByPeerIDs(ctx, accountID, peerIDs interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetGroupIDsByPeerIDs", reflect.TypeOf((*MockStore)(nil).GetGroupIDsByPeerIDs), ctx, accountID, peerIDs)
}
// GetPeersByIDs mocks base method.
func (m *MockStore) GetPeersByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, peerIDs []string) (map[string]*peer.Peer, error) {
m.ctrl.T.Helper()
@@ -2800,20 +2830,6 @@ func (mr *MockStoreMockRecorder) SaveProxy(ctx, proxy interface{}) *gomock.Call
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SaveProxy", reflect.TypeOf((*MockStore)(nil).SaveProxy), ctx, proxy)
}
// DisconnectProxy mocks base method.
func (m *MockStore) DisconnectProxy(ctx context.Context, proxyID, sessionID string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DisconnectProxy", ctx, proxyID, sessionID)
ret0, _ := ret[0].(error)
return ret0
}
// DisconnectProxy indicates an expected call of DisconnectProxy.
func (mr *MockStoreMockRecorder) DisconnectProxy(ctx, proxyID, sessionID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DisconnectProxy", reflect.TypeOf((*MockStore)(nil).DisconnectProxy), ctx, proxyID, sessionID)
}
// SaveProxyAccessToken mocks base method.
func (m *MockStore) SaveProxyAccessToken(ctx context.Context, token *types2.ProxyAccessToken) error {
m.ctrl.T.Helper()
@@ -3010,17 +3026,17 @@ func (mr *MockStoreMockRecorder) UpdateGroups(ctx, accountID, groups interface{}
}
// UpdateProxyHeartbeat mocks base method.
func (m *MockStore) UpdateProxyHeartbeat(ctx context.Context, p *proxy.Proxy) error {
func (m *MockStore) UpdateProxyHeartbeat(ctx context.Context, proxyID, clusterAddress, ipAddress string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "UpdateProxyHeartbeat", ctx, p)
ret := m.ctrl.Call(m, "UpdateProxyHeartbeat", ctx, proxyID, clusterAddress, ipAddress)
ret0, _ := ret[0].(error)
return ret0
}
// UpdateProxyHeartbeat indicates an expected call of UpdateProxyHeartbeat.
func (mr *MockStoreMockRecorder) UpdateProxyHeartbeat(ctx, p interface{}) *gomock.Call {
func (mr *MockStoreMockRecorder) UpdateProxyHeartbeat(ctx, proxyID, clusterAddress, ipAddress interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateProxyHeartbeat", reflect.TypeOf((*MockStore)(nil).UpdateProxyHeartbeat), ctx, p)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateProxyHeartbeat", reflect.TypeOf((*MockStore)(nil).UpdateProxyHeartbeat), ctx, proxyID, clusterAddress, ipAddress)
}
// UpdateService mocks base method.

View File

@@ -1155,7 +1155,8 @@ func (am *DefaultAccountManager) expireAndUpdatePeers(ctx context.Context, accou
}
}
err = am.networkMapController.OnPeersUpdated(ctx, accountID, peerIDs)
affectedPeerIDs := am.resolveAffectedPeersForPeerChanges(ctx, am.Store, accountID, peerIDs)
err = am.networkMapController.OnPeersUpdated(ctx, accountID, peerIDs, affectedPeerIDs)
if err != nil {
return fmt.Errorf("notify network map controller of peer update: %w", err)
}
@@ -1271,6 +1272,7 @@ func (am *DefaultAccountManager) deleteRegularUser(ctx context.Context, accountI
var userPeers []*nbpeer.Peer
var targetUser *types.User
var settings *types.Settings
var affectedPeerIDs []string
var err error
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
@@ -1291,6 +1293,14 @@ func (am *DefaultAccountManager) deleteRegularUser(ctx context.Context, accountI
if len(userPeers) > 0 {
updateAccountPeers = true
var peerIDs []string
for _, peer := range userPeers {
peerIDs = append(peerIDs, peer.ID)
}
// Resolve before delete so group memberships are still present.
affectedPeerIDs = am.resolveAffectedPeersForPeerChanges(ctx, transaction, accountID, peerIDs)
addPeerRemovedEvents, err = deletePeers(ctx, am, transaction, accountID, targetUserInfo.ID, userPeers, settings)
if err != nil {
return fmt.Errorf("failed to delete user peers: %w", err)
@@ -1314,7 +1324,7 @@ func (am *DefaultAccountManager) deleteRegularUser(ctx context.Context, accountI
log.WithContext(ctx).Errorf("failed to delete peer %s from integrated validator: %v", peer.ID, err)
}
}
if err := am.networkMapController.OnPeersDeleted(ctx, accountID, peerIDs); err != nil {
if err := am.networkMapController.OnPeersDeleted(ctx, accountID, peerIDs, affectedPeerIDs); err != nil {
log.WithContext(ctx).Errorf("failed to delete peers %s from network map: %v", peerIDs, err)
}

View File

@@ -846,7 +846,7 @@ func TestUser_DeleteUser_regularUser(t *testing.T) {
ctrl := gomock.NewController(t)
networkMapControllerMock := network_map.NewMockController(ctrl)
networkMapControllerMock.EXPECT().
OnPeersDeleted(gomock.Any(), gomock.Any(), gomock.Any()).
OnPeersDeleted(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).
Return(nil)
permissionsManager := permissions.NewManager(store)
@@ -962,7 +962,7 @@ func TestUser_DeleteUser_RegularUsers(t *testing.T) {
ctrl := gomock.NewController(t)
networkMapControllerMock := network_map.NewMockController(ctrl)
networkMapControllerMock.EXPECT().
OnPeersDeleted(gomock.Any(), gomock.Any(), gomock.Any()).
OnPeersDeleted(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).
Return(nil).
AnyTimes()
@@ -2022,7 +2022,7 @@ func TestUser_Operations_WithEmbeddedIDP(t *testing.T) {
ctrl := gomock.NewController(t)
networkMapControllerMock := network_map.NewMockController(ctrl)
networkMapControllerMock.EXPECT().
OnPeersDeleted(gomock.Any(), gomock.Any(), gomock.Any()).
OnPeersDeleted(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).
Return(nil).
AnyTimes()

View File

@@ -201,15 +201,15 @@ func (m *testAccessLogManager) GetAllAccessLogs(_ context.Context, _, _ string,
// testProxyManager is a mock implementation of proxy.Manager for testing.
type testProxyManager struct{}
func (m *testProxyManager) Connect(_ context.Context, proxyID, sessionID, _, _ string, _ *nbproxy.Capabilities) (*nbproxy.Proxy, error) {
return &nbproxy.Proxy{ID: proxyID, SessionID: sessionID, Status: "connected"}, nil
}
func (m *testProxyManager) Disconnect(_ context.Context, _, _ string) error {
func (m *testProxyManager) Connect(_ context.Context, _, _, _ string, _ *nbproxy.Capabilities) error {
return nil
}
func (m *testProxyManager) Heartbeat(_ context.Context, _ *nbproxy.Proxy) error {
func (m *testProxyManager) Disconnect(_ context.Context, _ string) error {
return nil
}
func (m *testProxyManager) Heartbeat(_ context.Context, _, _, _ string) error {
return nil
}
@@ -656,72 +656,3 @@ func TestIntegration_ProxyConnection_MultipleProxiesReceiveUpdates(t *testing.T)
assert.Equal(t, 2, count, "Proxy %s should receive 2 mappings", proxyID)
}
}
// TestIntegration_ProxyConnection_FastReconnectDoesNotLoseState verifies that
// when a proxy reconnects before the old stream's cleanup runs, the new
// connection is NOT removed by the stale defer.
func TestIntegration_ProxyConnection_FastReconnectDoesNotLoseState(t *testing.T) {
setup := setupIntegrationTest(t)
defer setup.cleanup()
clusterAddress := "test.proxy.io"
proxyID := "test-proxy-race"
conn, err := grpc.NewClient(setup.grpcAddr, grpc.WithTransportCredentials(insecure.NewCredentials()))
require.NoError(t, err)
defer conn.Close()
client := proto.NewProxyServiceClient(conn)
ctx1, cancel1 := context.WithCancel(context.Background())
stream1, err := client.GetMappingUpdate(ctx1, &proto.GetMappingUpdateRequest{
ProxyId: proxyID,
Version: "test-v1",
Address: clusterAddress,
})
require.NoError(t, err)
for i := 0; i < 2; i++ {
_, err := stream1.Recv()
require.NoError(t, err)
}
require.Contains(t, setup.proxyService.GetConnectedProxies(), proxyID,
"proxy should be registered after first connection")
ctx2, cancel2 := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel2()
stream2, err := client.GetMappingUpdate(ctx2, &proto.GetMappingUpdateRequest{
ProxyId: proxyID,
Version: "test-v1",
Address: clusterAddress,
})
require.NoError(t, err)
for i := 0; i < 2; i++ {
_, err := stream2.Recv()
require.NoError(t, err)
}
cancel1()
time.Sleep(200 * time.Millisecond)
assert.Contains(t, setup.proxyService.GetConnectedProxies(), proxyID,
"proxy should still be registered after old connection cleanup — old defer must not remove new connection")
setup.proxyService.SendServiceUpdate(&proto.GetMappingUpdateResponse{
Mapping: []*proto.ProxyMapping{{
Type: proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED,
Id: "rp-1",
AccountId: "test-account-1",
Domain: "app1.test.proxy.io",
}},
})
msg, err := stream2.Recv()
require.NoError(t, err, "new stream should still receive updates")
require.NotEmpty(t, msg.GetMapping(), "update should contain the mapping")
assert.Equal(t, "rp-1", msg.GetMapping()[0].GetId())
}

View File

@@ -246,23 +246,27 @@ func (c *GrpcClient) handleJobStream(
for {
jobReq, err := c.receiveJobRequest(ctx, stream, serverPubKey)
if err != nil {
if ctx.Err() != nil {
log.Debugf("job stream context has been canceled, this usually indicates shutdown")
return nil
}
if s, ok := gstatus.FromError(err); ok {
switch s.Code() {
case codes.PermissionDenied:
c.notifyDisconnected(err)
return backoff.Permanent(err) // unrecoverable error, propagate to the upper layer
case codes.Canceled:
log.Debugf("job stream context has been canceled, this usually indicates shutdown")
return err
case codes.Unimplemented:
log.Warn("Job feature is not supported by the current management server version. " +
"Please update the management service to use this feature.")
return nil
default:
log.Warnf("job stream disconnected, will retry silently. Reason: %v", err)
return err
}
} else {
// non-gRPC error
log.Warnf("job stream disconnected, will retry silently. Reason: %v", err)
return err
}
log.Warnf("job stream disconnected, will retry silently. Reason: %v", err)
return err
}
if jobReq == nil || len(jobReq.ID) == 0 {
@@ -377,15 +381,22 @@ func (c *GrpcClient) handleSyncStream(ctx context.Context, serverPubKey wgtypes.
err = c.receiveUpdatesEvents(stream, serverPubKey, msgHandler)
if err != nil {
c.notifyDisconnected(err)
if ctx.Err() != nil {
log.Debugf("management connection context has been canceled, this usually indicates shutdown")
return nil
if s, ok := gstatus.FromError(err); ok {
switch s.Code() {
case codes.PermissionDenied:
return backoff.Permanent(err) // unrecoverable error, propagate to the upper layer
case codes.Canceled:
log.Debugf("management connection context has been canceled, this usually indicates shutdown")
return nil
default:
log.Warnf("disconnected from the Management service but will retry silently. Reason: %v", err)
return err
}
} else {
// non-gRPC error
log.Warnf("disconnected from the Management service but will retry silently. Reason: %v", err)
return err
}
if s, ok := gstatus.FromError(err); ok && s.Code() == codes.PermissionDenied {
return backoff.Permanent(err) // unrecoverable error, propagate to the upper layer
}
log.Warnf("disconnected from the Management service but will retry silently. Reason: %v", err)
return err
}
return nil

View File

@@ -167,7 +167,7 @@ func (c *GrpcClient) Receive(ctx context.Context, msgHandler func(msg *proto.Mes
// start receiving messages from the Signal stream (from other peers through signal)
err = c.receive(stream)
if err != nil {
if ctx.Err() != nil {
if s, ok := status.FromError(err); ok && s.Code() == codes.Canceled {
log.Debugf("signal connection context has been canceled, this usually indicates shutdown")
return nil
}

View File

@@ -10,13 +10,15 @@ import (
"context"
"fmt"
"net"
"sync"
"time"
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
"github.com/google/gopacket/routing"
"github.com/libp2p/go-netroute"
"github.com/mdlayher/socket"
log "github.com/sirupsen/logrus"
"github.com/vishvananda/netlink"
"golang.org/x/sync/errgroup"
"golang.org/x/sys/unix"
@@ -35,6 +37,8 @@ type SharedSocket struct {
conn6 *socket.Conn
port int
mtu uint16
routerMux sync.RWMutex
router routing.Router
packetDemux chan rcvdPacket
cancel context.CancelFunc
}
@@ -78,6 +82,11 @@ func Listen(port int, filter BPFFilter, mtu uint16) (_ net.PacketConn, err error
}
}()
rawSock.router, err = netroute.New()
if err != nil {
return nil, fmt.Errorf("failed to create raw socket router: %w", err)
}
rawSock.conn4, err = socket.Socket(unix.AF_INET, unix.SOCK_RAW, unix.IPPROTO_UDP, "raw_udp4", nil)
if err != nil {
return nil, fmt.Errorf("failed to create ipv4 raw socket: %w", err)
@@ -118,26 +127,31 @@ func Listen(port int, filter BPFFilter, mtu uint16) (_ net.PacketConn, err error
go rawSock.read(rawSock.conn6.Recvfrom)
}
go rawSock.updateRouter()
return rawSock, nil
}
// resolveSrc returns the source IP the kernel will pick for a packet sent to
// dst by these raw sockets, mirroring the fwmark the kernel will see on send.
func (s *SharedSocket) resolveSrc(dst net.IP) (net.IP, error) {
opts := &netlink.RouteGetOptions{}
if nbnet.AdvancedRouting() {
opts.Mark = nbnet.ControlPlaneMark
}
routes, err := netlink.RouteGetWithOptions(dst, opts)
if err != nil {
return nil, fmt.Errorf("route get %s: %w", dst, err)
}
for _, r := range routes {
if r.Src != nil {
return r.Src, nil
// updateRouter updates the listener routing table client
// this is needed to avoid outdated information across different client networks
func (s *SharedSocket) updateRouter() {
ticker := time.NewTicker(15 * time.Second)
defer ticker.Stop()
for {
select {
case <-s.ctx.Done():
return
case <-ticker.C:
router, err := netroute.New()
if err != nil {
log.Errorf("Failed to create and update packet router for stunListener: %s", err)
continue
}
s.routerMux.Lock()
s.router = router
s.routerMux.Unlock()
}
}
return nil, fmt.Errorf("no source IP for %s", dst)
}
// LocalAddr returns the local address, preferring IPv4 for backward compatibility.
@@ -296,15 +310,15 @@ func (s *SharedSocket) WriteTo(buf []byte, rAddr net.Addr) (n int, err error) {
DstPort: layers.UDPPort(rUDPAddr.Port),
}
src, err := s.resolveSrc(rUDPAddr.IP)
s.routerMux.RLock()
defer s.routerMux.RUnlock()
_, _, src, err := s.router.Route(rUDPAddr.IP)
if err != nil {
return 0, fmt.Errorf("resolve source for %s: %w", rUDPAddr.IP, err)
return 0, fmt.Errorf("got an error while checking route, err: %w", err)
}
rSockAddr, conn, nwLayer := s.getWriterObjects(src, rUDPAddr.IP)
if conn == nil {
return 0, fmt.Errorf("no raw socket for %s", rUDPAddr.IP)
}
if err := udp.SetNetworkLayerForChecksum(nwLayer); err != nil {
return -1, fmt.Errorf("failed to set network layer for checksum: %w", err)

View File

@@ -10,7 +10,6 @@ import (
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
"github.com/miekg/dns"
)
// TextWriter writes human-readable one-line-per-packet summaries.
@@ -593,45 +592,19 @@ func formatDNSResponse(d *layers.DNS, rd string, plen int) string {
anCount := d.ANCount
nsCount := d.NSCount
arCount := d.ARCount
ede := formatEDE(d)
if d.ResponseCode != layers.DNSResponseCodeNoErr {
return fmt.Sprintf("%04x %d/%d/%d %s%s (%d)", d.ID, anCount, nsCount, arCount, d.ResponseCode, ede, plen)
return fmt.Sprintf("%04x %d/%d/%d %s (%d)", d.ID, anCount, nsCount, arCount, d.ResponseCode, plen)
}
if anCount > 0 && len(d.Answers) > 0 {
rr := d.Answers[0]
if rdata := shortRData(&rr); rdata != "" {
return fmt.Sprintf("%04x %d/%d/%d %s %s%s (%d)", d.ID, anCount, nsCount, arCount, rr.Type, rdata, ede, plen)
return fmt.Sprintf("%04x %d/%d/%d %s %s (%d)", d.ID, anCount, nsCount, arCount, rr.Type, rdata, plen)
}
}
return fmt.Sprintf("%04x %d/%d/%d%s (%d)", d.ID, anCount, nsCount, arCount, ede, plen)
}
// dnsOPTCodeEDE is the EDNS0 option code for Extended DNS Errors (RFC 8914).
const dnsOPTCodeEDE layers.DNSOptionCode = layers.DNSOptionCode(dns.EDNS0EDE)
// formatEDE returns " EDE=Name" for the first Extended DNS Error option
// found in the response, or empty string if none is present.
func formatEDE(d *layers.DNS) string {
for _, rr := range d.Additionals {
if rr.Type != layers.DNSTypeOPT {
continue
}
for _, opt := range rr.OPT {
if opt.Code != dnsOPTCodeEDE || len(opt.Data) < 2 {
continue
}
info := binary.BigEndian.Uint16(opt.Data[:2])
name, ok := dns.ExtendedErrorCodeToString[info]
if !ok {
name = fmt.Sprintf("%d", info)
}
return " EDE=" + name
}
}
return ""
return fmt.Sprintf("%04x %d/%d/%d (%d)", d.ID, anCount, nsCount, arCount, plen)
}
func shortRData(rr *layers.DNSResourceRecord) string {