Compare commits

..

8 Commits

Author SHA1 Message Date
crn4
8610a714be merge main 2026-05-04 18:18:32 +02:00
crn4
16cc204c89 upd for api spec 2026-05-04 17:59:22 +02:00
crn4
0a2b88a008 remove reseller invite link + priceID fix 2026-04-29 13:11:53 +02:00
crn4
dedb7e51e9 remove owner email from msp fields 2026-04-29 11:57:34 +02:00
crn4
b7c58dfdfb add email to msp model 2026-04-24 18:27:03 +02:00
crn4
d185b68fe1 resellers customer id field 2026-04-23 14:32:20 +02:00
crn4
49757e99e2 reseller openapi spec 2026-04-09 10:38:24 +02:00
crn4
daf7f41d69 openapi spec for reseller layer 2026-03-27 16:37:12 +01:00
48 changed files with 1219 additions and 2396 deletions

View File

@@ -1,130 +0,0 @@
body:
- type: markdown
attributes:
value: |
## Ideas & Feature Requests
Use this category for feature requests, enhancements, integrations, and product ideas.
NetBird uses community traction in discussions — upvotes, replies, affected users, and use-case detail — as an input when deciding what should become a maintainer-curated issue or roadmap item. A clear problem statement is more useful than a solution-only request.
Please search first and add your use case to an existing discussion when one already exists.
- type: checkboxes
id: preflight
attributes:
label: Before posting
options:
- label: I searched existing discussions and issues for similar requests.
required: true
- label: I checked the documentation to confirm this is not already supported.
required: true
- label: This is a product idea or enhancement request, not a support question.
required: true
- label: I removed or anonymized sensitive details from examples and screenshots.
required: true
- type: dropdown
id: area
attributes:
label: Product area
description: Select every area this request touches.
multiple: true
options:
- Client / Agent
- CLI
- Desktop UI
- Mobile app
- Dashboard / Admin UI
- Management service / API
- Signal service
- Relay
- DNS
- Routes / Exit nodes
- NetBird SSH
- Access control policies
- Posture checks
- Identity provider / SSO
- Self-hosting / Deployment
- Kubernetes / Operator
- Terraform / Automation
- Documentation
- Other / not sure
validations:
required: true
- type: textarea
id: problem
attributes:
label: Problem or use case
description: What are you trying to accomplish, and what is difficult or impossible today?
placeholder: |
As a ...
I want to ...
Because ...
validations:
required: true
- type: textarea
id: proposal
attributes:
label: Proposed solution
description: Describe the behavior, workflow, API, UI, or integration you would like to see.
validations:
required: true
- type: textarea
id: alternatives
attributes:
label: Alternatives or workarounds considered
description: What have you tried today? Why is the current workaround not enough?
- type: textarea
id: impact
attributes:
label: Community impact and priority
description: Help us understand who benefits and how urgent this is.
placeholder: |
- Number of users/teams/peers affected:
- Deployment type: Cloud / self-hosted / both
- Frequency: daily / weekly / occasional
- Blocking production adoption? yes/no
- Related comments, discussions, or customer requests:
validations:
required: true
- type: textarea
id: examples
attributes:
label: Examples from other tools or products
description: If another tool solves this well, link or describe the behavior.
- type: textarea
id: security
attributes:
label: Security, privacy, and compatibility considerations
description: Note any access-control, audit, data retention, network, platform, or backward-compatibility concerns.
- type: textarea
id: implementation
attributes:
label: Implementation ideas
description: Optional. If you are familiar with the codebase or API, share possible implementation notes.
- type: dropdown
id: contribution
attributes:
label: Are you willing to help?
options:
- Yes, I can submit a PR if the approach is accepted.
- Yes, I can test or validate a proposed implementation.
- Yes, I can provide more use-case details.
- Not at this time.
validations:
required: true
- type: textarea
id: additional-context
attributes:
label: Additional context
description: Add screenshots, diagrams, links, or anything else that helps explain the request.

View File

@@ -1,237 +0,0 @@
body:
- type: markdown
attributes:
value: |
## Issue Triage
Use this category for reproducible bugs and regressions in NetBird.
The more context you include, the faster we can validate and act on your report. If you're not sure whether something is a bug, **Q&A / Support** is a good starting point — we can always move the conversation here once we've confirmed it's a product issue.
Intermittent issues are useful too. Include the trigger, frequency, timing, and any logs or debug evidence you have, and we'll work from there.
Please don't include secrets, tokens, private keys, internal hostnames, or public IPs. Security vulnerabilities should be reported through the repository security policy rather than a public discussion.
- type: checkboxes
id: preflight
attributes:
label: Before posting
options:
- label: I searched existing discussions and issues, including closed ones, and checked the relevant docs.
required: true
- label: I believe this is a product bug rather than a configuration or setup question.
required: true
- label: I can reproduce this issue, or for intermittent issues I've included trigger, frequency, and timing details below.
required: true
- label: I removed or anonymized sensitive data from logs, screenshots, and configuration.
required: true
- type: dropdown
id: area
attributes:
label: Affected area
description: Select every area this report touches.
multiple: true
options:
- Client / Agent
- Reverse Proxy
- CLI
- Desktop UI
- Mobile app
- Peer connectivity
- DNS
- Routes / Exit nodes
- NetBird SSH
- Relay / Signal / NAT traversal
- Login / Authentication / IdP
- Dashboard / Admin UI
- Management service / API
- Access control policies / Posture checks
- Self-hosting / Deployment
- Kubernetes / Operator
- Documentation
- Other / not sure
validations:
required: true
- type: dropdown
id: deployment
attributes:
label: Deployment type
options:
- NetBird Cloud
- Self-hosted - quickstart script
- Self-hosted - advanced/custom deployment
- Local development build
- Not sure / environment I do not fully control
validations:
required: true
- type: dropdown
id: platform
attributes:
label: Operating system or environment
description: Select every environment involved in the reproduction.
multiple: true
options:
- Linux
- macOS
- Windows
- Android
- iOS
- FreeBSD
- OpenWRT
- Docker
- Kubernetes
- Synology
- Browser
- Other / not sure
validations:
required: true
- type: textarea
id: version
attributes:
label: NetBird version and upgrade status
description: Run `netbird version` where applicable. For self-hosted deployments, include management, signal, relay, and dashboard versions if available. If you cannot test on a current/supported version, explain why.
placeholder: |
Example:
- Client: 0.30.2
- Management: 0.30.2
- Signal: 0.30.2
- Relay: 0.30.2
- Dashboard: 0.30.2
- Upgrade status: reproduced on current version / cannot upgrade because ...
validations:
required: true
- type: dropdown
id: regression
attributes:
label: Did this work before?
options:
- Yes, this worked before
- No, this never worked
- Not sure
validations:
required: true
- type: textarea
id: regression-details
attributes:
label: Regression details
description: If this worked before, include the last known working version, first known broken version, and any recent upgrade, configuration, network, or IdP changes.
placeholder: |
- Last known working version:
- First known broken version:
- Recent changes:
- type: textarea
id: summary
attributes:
label: Summary
description: Briefly describe the reproducible bug.
placeholder: What is broken?
validations:
required: true
- type: textarea
id: current-behavior
attributes:
label: Current behavior
description: What happens now? Include exact errors, timeouts, UI messages, or failed commands when possible.
validations:
required: true
- type: textarea
id: expected-behavior
attributes:
label: Expected behavior
description: What did you expect to happen instead?
validations:
required: true
- type: textarea
id: reproduction
attributes:
label: Steps to reproduce
description: Provide the smallest set of steps that reliably reproduces the bug. If the issue is intermittent, include the trigger, frequency, timing, and relevant timestamps.
placeholder: |
1. Configure ...
2. Run ...
3. Observe ...
For intermittent issues:
- Trigger:
- Frequency:
- Timing/timestamps:
validations:
required: true
- type: textarea
id: environment
attributes:
label: Environment and topology
description: Include the relevant topology and software involved in the reproduction. For UI/docs-only reports, write `N/A` if this does not apply. Use `None`, `Unknown`, or `N/A` where appropriate.
placeholder: |
- Peer A:
- Peer B:
- Same LAN or different networks:
- NAT/CGNAT/corporate firewall/mobile network:
- Other VPN software:
- Firewall, DNS, or endpoint security software:
- Routes, DNS, policies, posture checks, or SSH rules involved:
- IdP, reverse proxy, or browser involved:
validations:
required: true
- type: textarea
id: self-hosted-details
attributes:
label: Self-hosted details, if available
description: Optional. If you use self-hosting and have access to these details, include them. If you do not administer the environment, provide what you know and say what you cannot access.
placeholder: |
- Deployment method: quickstart / Docker Compose / Helm / operator / custom
- Management/signal/relay/dashboard versions:
- Reverse proxy:
- IdP/provider:
- STUN/TURN/coturn/relay details:
- Relevant component logs:
- type: textarea
id: logs
attributes:
label: Logs, status output, or debug evidence
description: |
For client, connectivity, DNS, route, relay/signal, or self-hosted reports, logs are essential — please include anonymized output from `netbird status -dA`, or a debug bundle via `netbird debug for 1m -AS -U`. Debug bundles are automatically deleted after 30 days.
For UI, dashboard, or documentation reports, leave the pre-filled `N/A`.
value: "N/A"
render: shell
validations:
required: true
- type: textarea
id: related-reports
attributes:
label: Related issues or discussions
description: Optional. Link similar reports you found while searching, if any.
placeholder: |
- Related issue/discussion:
- Why this may be the same or different:
- type: textarea
id: impact
attributes:
label: Impact
description: Optional. Help us understand priority. How many users, peers, environments, or workflows are affected? Is there a workaround?
placeholder: |
- Affected users/peers:
- Business or production impact:
- Workaround available:
- type: textarea
id: additional-context
attributes:
label: Additional context
description: Add links to related discussions, issues, docs, screenshots, recordings, or anything else that may help validation.

View File

@@ -1,146 +0,0 @@
body:
- type: markdown
attributes:
value: |
## Q&A / Support
Use this category for questions about configuration, setup, self-hosted deployments, troubleshooting, and general NetBird usage.
This is community support and does not provide an SLA. For NetBird Cloud support, use the official support channel linked from the issue creation page. Please do not post secrets, tokens, private keys, internal hostnames, or public IPs unless you intentionally want them public.
If your question turns into a reproducible product defect, DevRel or a maintainer may ask you to open or move the conversation to Issue Triage.
- type: checkboxes
id: preflight
attributes:
label: Before posting
options:
- label: I searched existing discussions and issues for similar questions.
required: true
- label: I reviewed the relevant NetBird documentation or troubleshooting guide.
required: true
- label: I removed or anonymized sensitive data from logs, screenshots, and configuration.
required: true
- type: dropdown
id: topic
attributes:
label: Topic
multiple: true
options:
- Getting started
- Self-hosting
- Client / Agent
- CLI
- Desktop UI
- Mobile app
- Dashboard / Admin UI
- DNS
- Routes / Exit nodes
- NetBird SSH
- Relay
- Access control policies
- Posture checks
- Identity provider / SSO
- API
- Kubernetes / Operator
- Terraform / Automation
- Documentation
- Other / not sure
validations:
required: true
- type: dropdown
id: deployment
attributes:
label: Deployment type
options:
- NetBird Cloud
- Self-hosted - quickstart script
- Self-hosted - advanced/custom deployment
- Local development build
- Not sure
validations:
required: true
- type: dropdown
id: platform
attributes:
label: Operating system or environment
multiple: true
options:
- Linux
- macOS
- Windows
- Android
- iOS
- FreeBSD
- OpenWRT
- Docker
- Kubernetes
- Synology
- Browser
- Other / not sure
validations:
required: true
- type: input
id: version
attributes:
label: NetBird version
description: Run `netbird version` where applicable. For self-hosted deployments, include component versions if relevant.
placeholder: "Example: client 0.30.2, management 0.30.2"
- type: textarea
id: question
attributes:
label: Question
description: What are you trying to understand or accomplish?
placeholder: Describe your question clearly.
validations:
required: true
- type: textarea
id: goal
attributes:
label: Desired outcome
description: What would a successful answer help you do?
placeholder: |
I want to configure ...
I expected ...
I need help deciding ...
- type: textarea
id: attempted
attributes:
label: What have you tried?
description: Include commands, documentation links, configuration attempts, or troubleshooting steps already tried.
placeholder: |
- Read ...
- Ran ...
- Changed ...
- Observed ...
- type: textarea
id: environment
attributes:
label: Relevant environment details
description: Include redacted topology, IdP/provider, reverse proxy, firewall, DNS, route, policy, or self-hosted setup details that may affect the answer.
placeholder: |
- Deployment:
- Components involved:
- Network/topology:
- Related config:
- type: textarea
id: logs
attributes:
label: Logs or output
description: Optional. Include anonymized logs, command output, screenshots, or `netbird status -dA` if relevant.
render: shell
- type: textarea
id: additional-context
attributes:
label: Additional context
description: Add links, diagrams, screenshots, or other details that may help the community answer.

View File

@@ -0,0 +1,71 @@
---
name: Bug/Issue report
about: Create a report to help us improve
title: ''
labels: ['triage-needed']
assignees: ''
---
**Describe the problem**
A clear and concise description of what the problem is.
**To Reproduce**
Steps to reproduce the behavior:
1. Go to '...'
2. Click on '....'
3. Scroll down to '....'
4. See error
**Expected behavior**
A clear and concise description of what you expected to happen.
**Are you using NetBird Cloud?**
Please specify whether you use NetBird Cloud or self-host NetBird's control plane.
**NetBird version**
`netbird version`
**Is any other VPN software installed?**
If yes, which one?
**Debug output**
To help us resolve the problem, please attach the following anonymized status output
netbird status -dA
Create and upload a debug bundle, and share the returned file key:
netbird debug for 1m -AS -U
*Uploaded files are automatically deleted after 30 days.*
Alternatively, create the file only and attach it here manually:
netbird debug for 1m -AS
**Screenshots**
If applicable, add screenshots to help explain your problem.
**Additional context**
Add any other context about the problem here.
**Have you tried these troubleshooting steps?**
- [ ] Reviewed [client troubleshooting](https://docs.netbird.io/how-to/troubleshooting-client) (if applicable)
- [ ] Checked for newer NetBird versions
- [ ] Searched for similar issues on GitHub (including closed ones)
- [ ] Restarted the NetBird client
- [ ] Disabled other VPN software
- [ ] Checked firewall settings

View File

@@ -1,26 +1,14 @@
blank_issues_enabled: false blank_issues_enabled: true
contact_links: contact_links:
- name: Start an Issue Triage discussion - name: Community Support
url: https://github.com/netbirdio/netbird/discussions/new?category=issue-triage
about: Report a bug, regression, or unexpected behavior so DevRel can validate it before it becomes an issue.
- name: Propose an idea or feature request
url: https://github.com/netbirdio/netbird/discussions/new?category=ideas-feature-requests
about: Share feature requests, enhancements, and integration ideas for community feedback and prioritization.
- name: Ask a Q&A / Support question
url: https://github.com/netbirdio/netbird/discussions/new?category=q-a-support
about: Get help with setup, configuration, self-hosting, troubleshooting, and general usage.
- name: Security vulnerability disclosure
url: https://github.com/netbirdio/netbird/security/policy
about: Please do not report security vulnerabilities in public issues or discussions.
- name: Community Support Forum
url: https://forum.netbird.io/ url: https://forum.netbird.io/
about: Community support forum. about: Community support forum
- name: Cloud Support - name: Cloud Support
url: https://docs.netbird.io/help/report-bug-issues url: https://docs.netbird.io/help/report-bug-issues
about: Contact NetBird for Cloud support. about: Contact us for support
- name: Client / Connection Troubleshooting - name: Client/Connection Troubleshooting
url: https://docs.netbird.io/help/troubleshooting-client url: https://docs.netbird.io/help/troubleshooting-client
about: See the client troubleshooting guide for common connectivity issues. about: See our client troubleshooting guide for help addressing common issues
- name: Self-host Troubleshooting - name: Self-host Troubleshooting
url: https://docs.netbird.io/selfhosted/troubleshooting url: https://docs.netbird.io/selfhosted/troubleshooting
about: See the self-host troubleshooting guide for common deployment issues. about: See our self-host troubleshooting guide for help addressing common issues

View File

@@ -0,0 +1,20 @@
---
name: Feature request
about: Suggest an idea for this project
title: ''
labels: ['feature-request']
assignees: ''
---
**Is your feature request related to a problem? Please describe.**
A clear and concise description of what the problem is. Ex. I'm always frustrated when [...]
**Describe the solution you'd like**
A clear and concise description of what you want to happen.
**Describe alternatives you've considered**
A clear and concise description of any alternative solutions or features you've considered.
**Additional context**
Add any other context or screenshots about the feature request here.

View File

@@ -1,128 +0,0 @@
name: Validated issue
description: Maintainer/DevRel only. Create an issue after a discussion has been validated or for internally validated work.
title: "[Validated]: "
body:
- type: markdown
attributes:
value: |
## Discussion-first issue policy
Issues are maintainer-curated work items. Community reports and feature requests should start in [Discussions](https://github.com/netbirdio/netbird/discussions) so DevRel can validate, reproduce, and route them before engineering time is committed.
Use this form when:
- A discussion has been validated and should become actionable work.
- A maintainer is opening internally validated work that can bypass the discussion-first flow.
Issues opened without a relevant validated discussion or maintainer context may be closed and redirected to Discussions.
- type: checkboxes
id: validation-checks
attributes:
label: Validation checklist
options:
- label: This issue is linked to a validated discussion, or it is being opened directly by a maintainer.
required: true
- label: The report has enough context for engineering to act on it without re-triaging from scratch.
required: true
- label: Sensitive data, secrets, private keys, internal hostnames, and public IPs have been removed or intentionally disclosed.
required: true
- type: dropdown
id: issue-type
attributes:
label: Issue type
options:
- Bug / Regression
- Feature / Enhancement
- Documentation
- Maintenance / Refactor
- Cross-repository coordination
- Other
validations:
required: true
- type: input
id: source-discussion
attributes:
label: Source discussion
description: Link the GitHub Discussion that was validated. Maintainers bypassing the flow can write "Maintainer-created" and explain why below.
placeholder: https://github.com/netbirdio/netbird/discussions/1234
validations:
required: true
- type: input
id: validation-owner
attributes:
label: Validation owner
description: GitHub handle of the DevRel team member or maintainer who validated this work.
placeholder: "@username"
validations:
required: true
- type: dropdown
id: target-repository
attributes:
label: Target repository
description: Where should the implementation work happen?
options:
- netbirdio/netbird
- netbirdio/dashboard
- netbirdio/kubernetes-operator
- netbirdio/docs
- Multiple repositories
- Unknown / needs routing
validations:
required: true
- type: textarea
id: summary
attributes:
label: Summary
description: Concise description of the validated work.
placeholder: What needs to be fixed, changed, documented, or built?
validations:
required: true
- type: textarea
id: evidence
attributes:
label: Validation evidence
description: For bugs, include reproduction status, affected versions, logs, and environment. For features, include community traction, affected users, and alignment notes.
placeholder: |
- Reproduced by:
- Affected versions / platforms:
- Community signal:
- Related logs or screenshots:
validations:
required: true
- type: textarea
id: scope
attributes:
label: Proposed scope
description: Describe what is in scope and, if helpful, what is explicitly out of scope.
placeholder: |
In scope:
- ...
Out of scope:
- ...
validations:
required: true
- type: textarea
id: acceptance-criteria
attributes:
label: Acceptance criteria
description: What must be true for this issue to be closed?
placeholder: |
- [ ] ...
- [ ] ...
validations:
required: true
- type: textarea
id: additional-context
attributes:
label: Additional context
description: Links to related PRs, docs, issues in other repositories, roadmap items, or implementation notes.

View File

@@ -58,11 +58,6 @@ linters:
govet: govet:
enable: enable:
- nilness - nilness
disable:
# The inline analyzer flags x/exp/maps Clone/Clear with //go:fix inline
# directives but cannot perform the rewrite due to generic type
# parameter inference limitations in the Go inliner.
- inline
enable-all: false enable-all: false
revive: revive:
rules: rules:

View File

@@ -10,7 +10,6 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/spf13/cobra" "github.com/spf13/cobra"
"golang.org/x/term"
"google.golang.org/grpc/codes" "google.golang.org/grpc/codes"
gstatus "google.golang.org/grpc/status" gstatus "google.golang.org/grpc/status"
@@ -24,7 +23,6 @@ import (
func init() { func init() {
loginCmd.PersistentFlags().BoolVar(&noBrowser, noBrowserFlag, false, noBrowserDesc) loginCmd.PersistentFlags().BoolVar(&noBrowser, noBrowserFlag, false, noBrowserDesc)
loginCmd.PersistentFlags().BoolVar(&showQR, showQRFlag, false, showQRDesc)
loginCmd.PersistentFlags().StringVar(&profileName, profileNameFlag, "", profileNameDesc) loginCmd.PersistentFlags().StringVar(&profileName, profileNameFlag, "", profileNameDesc)
loginCmd.PersistentFlags().StringVarP(&configPath, "config", "c", "", "(DEPRECATED) Netbird config file location") loginCmd.PersistentFlags().StringVarP(&configPath, "config", "c", "", "(DEPRECATED) Netbird config file location")
} }
@@ -258,7 +256,7 @@ func doForegroundLogin(ctx context.Context, cmd *cobra.Command, setupKey string,
} }
func handleSSOLogin(ctx context.Context, cmd *cobra.Command, loginResp *proto.LoginResponse, client proto.DaemonServiceClient, pm *profilemanager.ProfileManager) error { func handleSSOLogin(ctx context.Context, cmd *cobra.Command, loginResp *proto.LoginResponse, client proto.DaemonServiceClient, pm *profilemanager.ProfileManager) error {
openURL(cmd, loginResp.VerificationURIComplete, loginResp.UserCode, noBrowser, showQR) openURL(cmd, loginResp.VerificationURIComplete, loginResp.UserCode, noBrowser)
resp, err := client.WaitSSOLogin(ctx, &proto.WaitSSOLoginRequest{UserCode: loginResp.UserCode, Hostname: hostName}) resp, err := client.WaitSSOLogin(ctx, &proto.WaitSSOLoginRequest{UserCode: loginResp.UserCode, Hostname: hostName})
if err != nil { if err != nil {
@@ -326,7 +324,7 @@ func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *pro
return nil, fmt.Errorf("getting a request OAuth flow info failed: %v", err) return nil, fmt.Errorf("getting a request OAuth flow info failed: %v", err)
} }
openURL(cmd, flowInfo.VerificationURIComplete, flowInfo.UserCode, noBrowser, showQR) openURL(cmd, flowInfo.VerificationURIComplete, flowInfo.UserCode, noBrowser)
tokenInfo, err := oAuthFlow.WaitToken(context.TODO(), flowInfo) tokenInfo, err := oAuthFlow.WaitToken(context.TODO(), flowInfo)
if err != nil { if err != nil {
@@ -336,7 +334,7 @@ func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *pro
return &tokenInfo, nil return &tokenInfo, nil
} }
func openURL(cmd *cobra.Command, verificationURIComplete, userCode string, noBrowser, showQR bool) { func openURL(cmd *cobra.Command, verificationURIComplete, userCode string, noBrowser bool) {
var codeMsg string var codeMsg string
if userCode != "" && !strings.Contains(verificationURIComplete, userCode) { if userCode != "" && !strings.Contains(verificationURIComplete, userCode) {
codeMsg = fmt.Sprintf("and enter the code %s to authenticate.", userCode) codeMsg = fmt.Sprintf("and enter the code %s to authenticate.", userCode)
@@ -350,12 +348,6 @@ func openURL(cmd *cobra.Command, verificationURIComplete, userCode string, noBro
verificationURIComplete + " " + codeMsg) verificationURIComplete + " " + codeMsg)
} }
if showQR {
if f, ok := cmd.OutOrStdout().(*os.File); ok && term.IsTerminal(int(f.Fd())) {
printQRCode(f, verificationURIComplete)
}
}
cmd.Println("") cmd.Println("")
if !noBrowser { if !noBrowser {

View File

@@ -1,25 +0,0 @@
package cmd
import (
"io"
"github.com/mdp/qrterminal/v3"
)
// printQRCode prints a QR code for the given URL to the writer.
// Called only when the user explicitly requests QR output via --qr.
func printQRCode(w io.Writer, url string) {
if url == "" {
return
}
qrterminal.GenerateWithConfig(url, qrterminal.Config{
Level: qrterminal.M,
Writer: w,
HalfBlocks: true,
BlackChar: qrterminal.BLACK_BLACK,
WhiteChar: qrterminal.WHITE_WHITE,
BlackWhiteChar: qrterminal.BLACK_WHITE,
WhiteBlackChar: qrterminal.WHITE_BLACK,
QuietZone: qrterminal.QUIET_ZONE,
})
}

View File

@@ -1,26 +0,0 @@
package cmd
import (
"bytes"
"testing"
)
func TestPrintQRCode_EmptyURL(t *testing.T) {
var buf bytes.Buffer
printQRCode(&buf, "")
if buf.Len() != 0 {
t.Error("expected no output for empty URL")
}
}
func TestPrintQRCode_WritesOutput(t *testing.T) {
var buf bytes.Buffer
printQRCode(&buf, "https://example.com/auth")
if buf.Len() == 0 {
t.Error("expected QR code output for non-empty URL")
}
}

View File

@@ -39,9 +39,6 @@ const (
noBrowserFlag = "no-browser" noBrowserFlag = "no-browser"
noBrowserDesc = "do not open the browser for SSO login" noBrowserDesc = "do not open the browser for SSO login"
showQRFlag = "qr"
showQRDesc = "show QR code for the SSO login URL (useful for headless machines without browser access)"
profileNameFlag = "profile" profileNameFlag = "profile"
profileNameDesc = "profile name to use for the login. If not specified, the last used profile will be used." profileNameDesc = "profile name to use for the login. If not specified, the last used profile will be used."
) )
@@ -51,7 +48,6 @@ var (
dnsLabels []string dnsLabels []string
dnsLabelsValidated domain.List dnsLabelsValidated domain.List
noBrowser bool noBrowser bool
showQR bool
profileName string profileName string
configPath string configPath string
@@ -84,7 +80,6 @@ func init() {
) )
upCmd.PersistentFlags().BoolVar(&noBrowser, noBrowserFlag, false, noBrowserDesc) upCmd.PersistentFlags().BoolVar(&noBrowser, noBrowserFlag, false, noBrowserDesc)
upCmd.PersistentFlags().BoolVar(&showQR, showQRFlag, false, showQRDesc)
upCmd.PersistentFlags().StringVar(&profileName, profileNameFlag, "", profileNameDesc) upCmd.PersistentFlags().StringVar(&profileName, profileNameFlag, "", profileNameDesc)
upCmd.PersistentFlags().StringVarP(&configPath, "config", "c", "", "(DEPRECATED) NetBird config file location. ") upCmd.PersistentFlags().StringVarP(&configPath, "config", "c", "", "(DEPRECATED) NetBird config file location. ")

View File

@@ -21,7 +21,6 @@ import (
"time" "time"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"google.golang.org/protobuf/encoding/protojson" "google.golang.org/protobuf/encoding/protojson"
"github.com/netbirdio/netbird/client/anonymize" "github.com/netbirdio/netbird/client/anonymize"
@@ -584,9 +583,6 @@ func isSensitiveEnvVar(key string) bool {
func (g *BundleGenerator) addCommonConfigFields(configContent *strings.Builder) { func (g *BundleGenerator) addCommonConfigFields(configContent *strings.Builder) {
configContent.WriteString("NetBird Client Configuration:\n\n") configContent.WriteString("NetBird Client Configuration:\n\n")
if key, err := wgtypes.ParseKey(g.internalConfig.PrivateKey); err == nil {
configContent.WriteString(fmt.Sprintf("PublicKey: %s\n", key.PublicKey().String()))
}
configContent.WriteString(fmt.Sprintf("WgIface: %s\n", g.internalConfig.WgIface)) configContent.WriteString(fmt.Sprintf("WgIface: %s\n", g.internalConfig.WgIface))
configContent.WriteString(fmt.Sprintf("WgPort: %d\n", g.internalConfig.WgPort)) configContent.WriteString(fmt.Sprintf("WgPort: %d\n", g.internalConfig.WgPort))
if g.internalConfig.NetworkMonitor != nil { if g.internalConfig.NetworkMonitor != nil {
@@ -611,12 +607,6 @@ func (g *BundleGenerator) addCommonConfigFields(configContent *strings.Builder)
if g.internalConfig.EnableSSHRemotePortForwarding != nil { if g.internalConfig.EnableSSHRemotePortForwarding != nil {
configContent.WriteString(fmt.Sprintf("EnableSSHRemotePortForwarding: %v\n", *g.internalConfig.EnableSSHRemotePortForwarding)) configContent.WriteString(fmt.Sprintf("EnableSSHRemotePortForwarding: %v\n", *g.internalConfig.EnableSSHRemotePortForwarding))
} }
if g.internalConfig.DisableSSHAuth != nil {
configContent.WriteString(fmt.Sprintf("DisableSSHAuth: %v\n", *g.internalConfig.DisableSSHAuth))
}
if g.internalConfig.SSHJWTCacheTTL != nil {
configContent.WriteString(fmt.Sprintf("SSHJWTCacheTTL: %d\n", *g.internalConfig.SSHJWTCacheTTL))
}
configContent.WriteString(fmt.Sprintf("DisableClientRoutes: %v\n", g.internalConfig.DisableClientRoutes)) configContent.WriteString(fmt.Sprintf("DisableClientRoutes: %v\n", g.internalConfig.DisableClientRoutes))
configContent.WriteString(fmt.Sprintf("DisableServerRoutes: %v\n", g.internalConfig.DisableServerRoutes)) configContent.WriteString(fmt.Sprintf("DisableServerRoutes: %v\n", g.internalConfig.DisableServerRoutes))
@@ -643,7 +633,6 @@ func (g *BundleGenerator) addCommonConfigFields(configContent *strings.Builder)
} }
configContent.WriteString(fmt.Sprintf("LazyConnectionEnabled: %v\n", g.internalConfig.LazyConnectionEnabled)) configContent.WriteString(fmt.Sprintf("LazyConnectionEnabled: %v\n", g.internalConfig.LazyConnectionEnabled))
configContent.WriteString(fmt.Sprintf("MTU: %d\n", g.internalConfig.MTU))
} }
func (g *BundleGenerator) addProf() (err error) { func (g *BundleGenerator) addProf() (err error) {

View File

@@ -5,21 +5,16 @@ import (
"bytes" "bytes"
"encoding/json" "encoding/json"
"net" "net"
"net/url"
"os" "os"
"path/filepath" "path/filepath"
"reflect"
"strings" "strings"
"testing" "testing"
"time"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/anonymize" "github.com/netbirdio/netbird/client/anonymize"
"github.com/netbirdio/netbird/client/configs" "github.com/netbirdio/netbird/client/configs"
"github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/shared/management/domain"
mgmProto "github.com/netbirdio/netbird/shared/management/proto" mgmProto "github.com/netbirdio/netbird/shared/management/proto"
) )
@@ -771,127 +766,3 @@ Chain OUTPUT (policy ACCEPT 0 packets, 0 bytes)
assert.Contains(t, anonNftables, "chain input {") assert.Contains(t, anonNftables, "chain input {")
assert.Contains(t, anonNftables, "type filter hook input priority filter; policy accept;") assert.Contains(t, anonNftables, "type filter hook input priority filter; policy accept;")
} }
// TestAddConfig_AllFieldsCovered uses reflection to ensure every field in
// profilemanager.Config is either rendered in the debug bundle or explicitly
// excluded. When a new field is added to Config, this test fails until the
// developer either dumps it in addConfig/addCommonConfigFields or adds it to
// the excluded set with a justification.
func TestAddConfig_AllFieldsCovered(t *testing.T) {
excluded := map[string]string{
"PrivateKey": "sensitive: WireGuard private key",
"PreSharedKey": "sensitive: WireGuard pre-shared key",
"SSHKey": "sensitive: SSH private key",
"ClientCertKeyPair": "non-config: parsed cert pair, not serialized",
}
mURL, _ := url.Parse("https://api.example.com:443")
aURL, _ := url.Parse("https://admin.example.com:443")
bTrue := true
iVal := 42
cfg := &profilemanager.Config{
PrivateKey: "priv",
PreSharedKey: "psk",
ManagementURL: mURL,
AdminURL: aURL,
WgIface: "wt0",
WgPort: 51820,
NetworkMonitor: &bTrue,
IFaceBlackList: []string{"eth0"},
DisableIPv6Discovery: true,
RosenpassEnabled: true,
RosenpassPermissive: true,
ServerSSHAllowed: &bTrue,
EnableSSHRoot: &bTrue,
EnableSSHSFTP: &bTrue,
EnableSSHLocalPortForwarding: &bTrue,
EnableSSHRemotePortForwarding: &bTrue,
DisableSSHAuth: &bTrue,
SSHJWTCacheTTL: &iVal,
DisableClientRoutes: true,
DisableServerRoutes: true,
DisableDNS: true,
DisableFirewall: true,
BlockLANAccess: true,
BlockInbound: true,
DisableNotifications: &bTrue,
DNSLabels: domain.List{},
SSHKey: "sshkey",
NATExternalIPs: []string{"1.2.3.4"},
CustomDNSAddress: "1.1.1.1:53",
DisableAutoConnect: true,
DNSRouteInterval: 5 * time.Second,
ClientCertPath: "/tmp/cert",
ClientCertKeyPath: "/tmp/key",
LazyConnectionEnabled: true,
MTU: 1280,
}
for _, anonymize := range []bool{false, true} {
t.Run("anonymize="+map[bool]string{true: "true", false: "false"}[anonymize], func(t *testing.T) {
g := &BundleGenerator{
anonymizer: newAnonymizerForTest(),
internalConfig: cfg,
anonymize: anonymize,
}
var sb strings.Builder
g.addCommonConfigFields(&sb)
rendered := sb.String() + renderAddConfigSpecific(g)
val := reflect.ValueOf(cfg).Elem()
typ := val.Type()
var missing []string
for i := 0; i < typ.NumField(); i++ {
name := typ.Field(i).Name
if _, ok := excluded[name]; ok {
continue
}
if !strings.Contains(rendered, name+":") {
missing = append(missing, name)
}
}
if len(missing) > 0 {
t.Fatalf("Config field(s) not present in debug bundle output: %v\n"+
"Either render the field in addCommonConfigFields/addConfig, "+
"or add it to the excluded map with a justification.", missing)
}
})
}
}
// renderAddConfigSpecific renders the fields handled by the anonymize/non-anonymize
// branches in addConfig (ManagementURL, AdminURL, NATExternalIPs, CustomDNSAddress).
// addCommonConfigFields covers the rest. Keeping this in the test mirrors the
// production shape without needing to write an actual zip.
func renderAddConfigSpecific(g *BundleGenerator) string {
var sb strings.Builder
if g.anonymize {
if g.internalConfig.ManagementURL != nil {
sb.WriteString("ManagementURL: " + g.anonymizer.AnonymizeURI(g.internalConfig.ManagementURL.String()) + "\n")
}
if g.internalConfig.AdminURL != nil {
sb.WriteString("AdminURL: " + g.anonymizer.AnonymizeURI(g.internalConfig.AdminURL.String()) + "\n")
}
sb.WriteString("NATExternalIPs: x\n")
if g.internalConfig.CustomDNSAddress != "" {
sb.WriteString("CustomDNSAddress: " + g.anonymizer.AnonymizeString(g.internalConfig.CustomDNSAddress) + "\n")
}
} else {
if g.internalConfig.ManagementURL != nil {
sb.WriteString("ManagementURL: " + g.internalConfig.ManagementURL.String() + "\n")
}
if g.internalConfig.AdminURL != nil {
sb.WriteString("AdminURL: " + g.internalConfig.AdminURL.String() + "\n")
}
sb.WriteString("NATExternalIPs: x\n")
if g.internalConfig.CustomDNSAddress != "" {
sb.WriteString("CustomDNSAddress: " + g.internalConfig.CustomDNSAddress + "\n")
}
}
return sb.String()
}
func newAnonymizerForTest() *anonymize.Anonymizer {
return anonymize.NewAnonymizer(anonymize.DefaultAddresses())
}

View File

@@ -5,7 +5,6 @@ import (
"fmt" "fmt"
"net" "net"
"net/netip" "net/netip"
"runtime"
"sync" "sync"
"time" "time"
@@ -178,12 +177,7 @@ func getDefaultGateway() (gateway net.IP, localIP net.IP, err error) {
return nil, nil, err return nil, nil, err
} }
dst := net.IPv4zero _, gateway, localIP, err = router.Route(net.IPv4zero)
if runtime.GOOS == "linux" {
// go-netroute v0.4.0 rejects unspecified destinations client-side on Linux.
dst = net.IPv4(0, 0, 0, 1)
}
_, gateway, localIP, err = router.Route(dst)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
@@ -202,12 +196,7 @@ func getDefaultGateway6() (gateway net.IP, localIP net.IP, err error) {
return nil, nil, err return nil, nil, err
} }
dst := net.IPv6zero _, gateway, localIP, err = router.Route(net.IPv6zero)
if runtime.GOOS == "linux" {
// ::2
dst = net.IP{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2}
}
_, gateway, localIP, err = router.Route(dst)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }

View File

@@ -342,22 +342,6 @@ func GetNextHop(ip netip.Addr) (Nexthop, error) {
if err != nil { if err != nil {
return Nexthop{}, fmt.Errorf("new netroute: %w", err) return Nexthop{}, fmt.Errorf("new netroute: %w", err)
} }
// go-netroute v0.4.0 rejects unspecified destinations on Linux with a hard
// client-side check. Substitute the lowest non-loopback address so the
// lookup falls through to the default route (::1 / 127.0.0.1 would match
// loopback, ::/0.0.0.0 are unspec). BSD/Windows pass the query straight to
// the kernel and need no substitution.
if runtime.GOOS == "linux" && ip.IsUnspecified() {
if ip.Is6() {
// ::2
ip = netip.AddrFrom16([16]byte{15: 2})
} else {
// 0.0.0.1
ip = netip.AddrFrom4([4]byte{0, 0, 0, 1})
}
}
intf, gateway, preferredSrc, err := r.Route(ip.AsSlice()) intf, gateway, preferredSrc, err := r.Route(ip.AsSlice())
if err != nil { if err != nil {
log.Debugf("Failed to get route for %s: %v", ip, err) log.Debugf("Failed to get route for %s: %v", ip, err)

View File

@@ -354,13 +354,9 @@ func TestAddRouteToNonVPNIntf(t *testing.T) {
require.NoError(t, err, "Should be able to get IPv4 default route") require.NoError(t, err, "Should be able to get IPv4 default route")
t.Logf("Initial IPv4 next hop: %s", initialNextHopV4) t.Logf("Initial IPv4 next hop: %s", initialNextHopV4)
if testCase.prefix.Addr().Is6() && !testCase.expectError {
ensureIPv6DefaultRoute(t)
}
initialNextHopV6, err := GetNextHop(netip.IPv6Unspecified()) initialNextHopV6, err := GetNextHop(netip.IPv6Unspecified())
if testCase.prefix.Addr().Is6() && if testCase.prefix.Addr().Is6() &&
initialNextHopV6.Intf != nil && strings.HasPrefix(initialNextHopV6.Intf.Name, "utun") { (errors.Is(err, vars.ErrRouteNotFound) || initialNextHopV6.Intf != nil && strings.HasPrefix(initialNextHopV6.Intf.Name, "utun")) {
t.Skip("Skipping test as no ipv6 default route is available") t.Skip("Skipping test as no ipv6 default route is available")
} }
if err != nil && !errors.Is(err, vars.ErrRouteNotFound) { if err != nil && !errors.Is(err, vars.ErrRouteNotFound) {

View File

@@ -1,30 +0,0 @@
//go:build darwin || dragonfly || freebsd || netbsd || openbsd
package systemops
import (
"bytes"
"os/exec"
"testing"
)
// ensureIPv6DefaultRoute installs an IPv6 default route via the loopback
// interface so route lookups for global IPv6 prefixes resolve in environments
// without v6 connectivity. If a default already exists it is left alone.
func ensureIPv6DefaultRoute(t *testing.T) {
t.Helper()
out, err := exec.Command("route", "-6", "add", "default", "-iface", "lo0").CombinedOutput()
if err != nil {
// Existing default; nothing to install or clean up.
if bytes.Contains(out, []byte("route already in table")) {
return
}
t.Skipf("install IPv6 fallback default route: %v: %s", err, out)
}
t.Cleanup(func() {
if out, err := exec.Command("route", "-6", "delete", "default").CombinedOutput(); err != nil {
t.Logf("delete IPv6 fallback default route: %v: %s", err, out)
}
})
}

View File

@@ -1,41 +0,0 @@
//go:build linux && !android
package systemops
import (
"errors"
"net"
"syscall"
"testing"
"github.com/stretchr/testify/require"
"github.com/vishvananda/netlink"
)
// ensureIPv6DefaultRoute installs a low-preference IPv6 default route via the
// loopback interface so route lookups for global IPv6 prefixes resolve in
// environments without v6 connectivity. Any pre-existing default route wins
// because of its lower metric.
func ensureIPv6DefaultRoute(t *testing.T) {
t.Helper()
lo, err := netlink.LinkByName("lo")
require.NoError(t, err, "find loopback interface")
route := &netlink.Route{
Dst: &net.IPNet{IP: net.IPv6zero, Mask: net.CIDRMask(0, 128)},
LinkIndex: lo.Attrs().Index,
Priority: 1 << 20,
}
if err := netlink.RouteAdd(route); err != nil {
if errors.Is(err, syscall.EEXIST) {
return
}
t.Skipf("install IPv6 fallback default route: %v", err)
}
t.Cleanup(func() {
if err := netlink.RouteDel(route); err != nil && !errors.Is(err, syscall.ESRCH) {
t.Logf("delete IPv6 fallback default route: %v", err)
}
})
}

View File

@@ -1,34 +0,0 @@
//go:build windows
package systemops
import (
"bytes"
"os/exec"
"testing"
)
const loopbackIfaceWindows = "Loopback Pseudo-Interface 1"
// ensureIPv6DefaultRoute installs an IPv6 default route via the loopback
// interface so route lookups for global IPv6 prefixes resolve in environments
// without v6 connectivity. If a default already exists it is left alone.
func ensureIPv6DefaultRoute(t *testing.T) {
t.Helper()
script := `New-NetRoute -DestinationPrefix "::/0" -InterfaceAlias "` + loopbackIfaceWindows + `" -RouteMetric 9999 -PolicyStore ActiveStore -ErrorAction Stop`
out, err := exec.Command("powershell", "-Command", script).CombinedOutput()
if err != nil {
// Existing default; nothing to install or clean up.
if bytes.Contains(out, []byte("already exists")) {
return
}
t.Skipf("install IPv6 fallback default route: %v: %s", err, out)
}
t.Cleanup(func() {
script := `Remove-NetRoute -DestinationPrefix "::/0" -InterfaceAlias "` + loopbackIfaceWindows + `" -Confirm:$false -ErrorAction Stop`
if out, err := exec.Command("powershell", "-Command", script).CombinedOutput(); err != nil {
t.Logf("delete IPv6 fallback default route: %v: %s", err, out)
}
})
}

View File

@@ -6,6 +6,7 @@ import (
"fmt" "fmt"
"net" "net"
"runtime" "runtime"
"time"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
@@ -27,10 +28,6 @@ func NewWGIfaceMonitor() *WGIfaceMonitor {
// Start begins monitoring the WireGuard interface. // Start begins monitoring the WireGuard interface.
// It relies on the provided context cancellation to stop. // It relies on the provided context cancellation to stop.
//
// On Linux the watcher is event-driven (RTNLGRP_LINK netlink subscription)
// to avoid the allocation churn of repeatedly dumping the kernel link
// table; on other platforms it falls back to a low-frequency poll.
func (m *WGIfaceMonitor) Start(ctx context.Context, ifaceName string) (shouldRestart bool, err error) { func (m *WGIfaceMonitor) Start(ctx context.Context, ifaceName string) (shouldRestart bool, err error) {
defer close(m.done) defer close(m.done)
@@ -59,7 +56,31 @@ func (m *WGIfaceMonitor) Start(ctx context.Context, ifaceName string) (shouldRes
log.Infof("Interface monitor: watching %s (index: %d)", ifaceName, expectedIndex) log.Infof("Interface monitor: watching %s (index: %d)", ifaceName, expectedIndex)
return watchInterface(ctx, ifaceName, expectedIndex) ticker := time.NewTicker(2 * time.Second)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
log.Infof("Interface monitor: stopped for %s", ifaceName)
return false, fmt.Errorf("wg interface monitor stopped: %v", ctx.Err())
case <-ticker.C:
currentIndex, err := getInterfaceIndex(ifaceName)
if err != nil {
// Interface was deleted
log.Infof("Interface monitor: %s deleted", ifaceName)
return true, fmt.Errorf("interface %s deleted: %w", ifaceName, err)
}
// Check if interface index changed (interface was recreated)
if currentIndex != expectedIndex {
log.Infof("Interface monitor: %s recreated (index changed from %d to %d), restarting engine",
ifaceName, expectedIndex, currentIndex)
return true, nil
}
}
}
} }
// getInterfaceIndex returns the index of a network interface by name. // getInterfaceIndex returns the index of a network interface by name.

View File

@@ -1,134 +0,0 @@
//go:build linux
package internal
import (
"context"
"fmt"
"syscall"
log "github.com/sirupsen/logrus"
"github.com/vishvananda/netlink"
)
// watchInterface uses an RTNLGRP_LINK netlink subscription to detect
// deletion or recreation of the WireGuard interface.
//
// The previous implementation polled net.InterfaceByName every 2 s, which
// on Linux issues syscall.NetlinkRIB(RTM_GETLINK, ...) and dumps the
// entire kernel link table on every call. On hosts with many veth
// interfaces (containers, bridges) the resulting allocation churn was on
// the order of ~1 GB/day from this single ticker, which on small ARM
// hosts manifested as a slow RSS climb (see netbirdio/netbird#3678).
//
// The event-driven version below allocates only when the kernel actually
// publishes a link event for the tracked interface — typically zero
// allocations between events.
func watchInterface(ctx context.Context, ifaceName string, expectedIndex int) (bool, error) {
done := make(chan struct{})
defer close(done)
// Buffer the channel to absorb event bursts (e.g. when many veth
// pairs are created/destroyed at once by container runtimes).
linkChan := make(chan netlink.LinkUpdate, 32)
if err := netlink.LinkSubscribe(linkChan, done); err != nil {
// Return shouldRestart=true so the engine recovers monitoring
// via triggerClientRestart instead of silently losing it for
// the rest of the process lifetime.
return true, fmt.Errorf("subscribe to link updates: %w", err)
}
// Race window: the interface could have been deleted (or recreated)
// between the initial getInterfaceIndex() in Start and LinkSubscribe
// completing its handshake with the kernel. Re-check explicitly so we
// do not block forever waiting for an event that already fired.
if currentIndex, err := getInterfaceIndex(ifaceName); err != nil {
log.Infof("Interface monitor: %s deleted before subscription completed", ifaceName)
return true, fmt.Errorf("interface %s deleted: %w", ifaceName, err)
} else if currentIndex != expectedIndex {
log.Infof("Interface monitor: %s recreated (index changed from %d to %d) before subscription completed",
ifaceName, expectedIndex, currentIndex)
return true, nil
}
for {
select {
case <-ctx.Done():
log.Infof("Interface monitor: stopped for %s", ifaceName)
return false, fmt.Errorf("wg interface monitor stopped: %w", ctx.Err())
case update, ok := <-linkChan:
if !ok {
// The vishvananda/netlink subscription goroutine closes
// the channel on receive errors. Signal the engine to
// restart so monitoring is re-established instead of
// silently ending.
log.Warnf("Interface monitor: link subscription channel closed unexpectedly for %s", ifaceName)
return true, fmt.Errorf("link subscription channel closed unexpectedly")
}
if restart, err := inspectLinkEvent(update, ifaceName, expectedIndex); restart {
return true, err
}
}
}
}
// inspectLinkEvent classifies a single netlink link update against the
// tracked WireGuard interface. It returns (true, err) when the engine
// should restart monitoring; (false, nil) means the event is unrelated
// and the caller should keep waiting.
//
// The error component, when non-nil, describes the kernel-side reason
// (deletion or rename); the recreation case returns (true, nil) since
// no error condition is reported.
func inspectLinkEvent(update netlink.LinkUpdate, ifaceName string, expectedIndex int) (bool, error) {
eventIndex := int(update.Index)
eventName := ""
if attrs := update.Attrs(); attrs != nil {
eventName = attrs.Name
}
switch update.Header.Type {
case syscall.RTM_DELLINK:
return inspectDelLink(eventIndex, ifaceName, expectedIndex)
case syscall.RTM_NEWLINK:
return inspectNewLink(eventIndex, eventName, ifaceName, expectedIndex)
}
return false, nil
}
// inspectDelLink reports a restart when an RTM_DELLINK arrives for the
// tracked interface index.
func inspectDelLink(eventIndex int, ifaceName string, expectedIndex int) (bool, error) {
if eventIndex != expectedIndex {
return false, nil
}
log.Infof("Interface monitor: %s deleted", ifaceName)
return true, fmt.Errorf("interface %s deleted", ifaceName)
}
// inspectNewLink reports a restart when an RTM_NEWLINK either:
//
// 1. Introduces a link with our name at a different index (recreation
// after a delete), or
//
// 2. Reports a link still at our index but with a different name
// (in-place rename). The previous polling implementation caught
// this implicitly because net.InterfaceByName(ifaceName) would
// start failing; the event-driven version has to test it.
//
// Same name + same index is just a flag/state change on the existing
// interface and is ignored.
func inspectNewLink(eventIndex int, eventName, ifaceName string, expectedIndex int) (bool, error) {
if eventName == ifaceName && eventIndex != expectedIndex {
log.Infof("Interface monitor: %s recreated (index changed from %d to %d), restarting engine",
ifaceName, expectedIndex, eventIndex)
return true, nil
}
if eventIndex == expectedIndex && eventName != "" && eventName != ifaceName {
log.Infof("Interface monitor: %s renamed to %s (index %d), restarting engine",
ifaceName, eventName, expectedIndex)
return true, fmt.Errorf("interface %s renamed to %s", ifaceName, eventName)
}
return false, nil
}

View File

@@ -1,56 +0,0 @@
//go:build !linux
package internal
import (
"context"
"fmt"
"time"
log "github.com/sirupsen/logrus"
)
// watchInterface polls net.InterfaceByName at a fixed interval to detect
// deletion or recreation of the WireGuard interface.
//
// This is the fallback used on non-Linux desktop and server platforms
// (darwin, windows, freebsd). It is also compiled on android and ios so
// the package builds on every supported GOOS, but it is never reached
// at runtime there because Start() in wg_iface_monitor.go exits early
// on mobile platforms.
//
// The Linux build (see wg_iface_monitor_linux.go) uses an event-driven
// RTNLGRP_LINK netlink subscription instead, because on Linux
// net.InterfaceByName issues syscall.NetlinkRIB(RTM_GETLINK, ...) which
// dumps the entire kernel link table on every call and produces
// significant allocation churn (netbirdio/netbird#3678).
//
// Windows is also reported in #3678 as affected by RSS climb. A future
// follow-up could implement an event-driven watcher there using
// NotifyIpInterfaceChange from iphlpapi.
func watchInterface(ctx context.Context, ifaceName string, expectedIndex int) (bool, error) {
ticker := time.NewTicker(2 * time.Second)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
log.Infof("Interface monitor: stopped for %s", ifaceName)
return false, fmt.Errorf("wg interface monitor stopped: %w", ctx.Err())
case <-ticker.C:
currentIndex, err := getInterfaceIndex(ifaceName)
if err != nil {
// Interface was deleted
log.Infof("Interface monitor: %s deleted", ifaceName)
return true, fmt.Errorf("interface %s deleted: %w", ifaceName, err)
}
// Check if interface index changed (interface was recreated)
if currentIndex != expectedIndex {
log.Infof("Interface monitor: %s recreated (index changed from %d to %d), restarting engine",
ifaceName, expectedIndex, currentIndex)
return true, nil
}
}
}
}

View File

@@ -224,20 +224,15 @@ func (m *Manager) buildHostPatterns(peer PeerSSHInfo) []string {
func (m *Manager) writeSSHConfig(sshConfig string) error { func (m *Manager) writeSSHConfig(sshConfig string) error {
sshConfigPath := filepath.Join(m.sshConfigDir, m.sshConfigFile) sshConfigPath := filepath.Join(m.sshConfigDir, m.sshConfigFile)
sshConfigPathTmp := sshConfigPath + ".tmp"
if err := os.MkdirAll(m.sshConfigDir, 0755); err != nil { if err := os.MkdirAll(m.sshConfigDir, 0755); err != nil {
return fmt.Errorf("create SSH config directory %s: %w", m.sshConfigDir, err) return fmt.Errorf("create SSH config directory %s: %w", m.sshConfigDir, err)
} }
if err := writeFileWithTimeout(sshConfigPathTmp, []byte(sshConfig), 0644); err != nil { if err := writeFileWithTimeout(sshConfigPath, []byte(sshConfig), 0644); err != nil {
return fmt.Errorf("write SSH config file %s: %w", sshConfigPath, err) return fmt.Errorf("write SSH config file %s: %w", sshConfigPath, err)
} }
if err := os.Rename(sshConfigPathTmp, sshConfigPath); err != nil {
return fmt.Errorf("rename ssh config %s -> %s: %w", sshConfigPathTmp, sshConfigPath, err)
}
log.Infof("Created NetBird SSH client config: %s", sshConfigPath) log.Infof("Created NetBird SSH client config: %s", sshConfigPath)
return nil return nil
} }

20
go.mod
View File

@@ -17,8 +17,8 @@ require (
github.com/spf13/cobra v1.10.1 github.com/spf13/cobra v1.10.1
github.com/spf13/pflag v1.0.9 github.com/spf13/pflag v1.0.9
github.com/vishvananda/netlink v1.3.1 github.com/vishvananda/netlink v1.3.1
golang.org/x/crypto v0.50.0 golang.org/x/crypto v0.49.0
golang.org/x/sys v0.43.0 golang.org/x/sys v0.42.0
golang.zx2c4.com/wireguard v0.0.0-20230704135630-469159ecf7d1 golang.zx2c4.com/wireguard v0.0.0-20230704135630-469159ecf7d1
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6 golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6
golang.zx2c4.com/wireguard/windows v0.5.3 golang.zx2c4.com/wireguard/windows v0.5.3
@@ -68,10 +68,9 @@ require (
github.com/jackc/pgx/v5 v5.5.5 github.com/jackc/pgx/v5 v5.5.5
github.com/libdns/route53 v1.5.0 github.com/libdns/route53 v1.5.0
github.com/libp2p/go-nat v0.2.0 github.com/libp2p/go-nat v0.2.0
github.com/libp2p/go-netroute v0.4.0 github.com/libp2p/go-netroute v0.2.1
github.com/lrh3321/ipset-go v0.0.0-20250619021614-54a0a98ace81 github.com/lrh3321/ipset-go v0.0.0-20250619021614-54a0a98ace81
github.com/mdlayher/socket v0.5.1 github.com/mdlayher/socket v0.5.1
github.com/mdp/qrterminal/v3 v3.2.1
github.com/miekg/dns v1.1.59 github.com/miekg/dns v1.1.59
github.com/mitchellh/hashstructure/v2 v2.0.2 github.com/mitchellh/hashstructure/v2 v2.0.2
github.com/netbirdio/management-integrations/integrations v0.0.0-20260416123949-2355d972be42 github.com/netbirdio/management-integrations/integrations v0.0.0-20260416123949-2355d972be42
@@ -118,11 +117,11 @@ require (
goauthentik.io/api/v3 v3.2023051.3 goauthentik.io/api/v3 v3.2023051.3
golang.org/x/exp v0.0.0-20250620022241-b7579e27df2b golang.org/x/exp v0.0.0-20250620022241-b7579e27df2b
golang.org/x/mobile v0.0.0-20251113184115-a159579294ab golang.org/x/mobile v0.0.0-20251113184115-a159579294ab
golang.org/x/mod v0.34.0 golang.org/x/mod v0.33.0
golang.org/x/net v0.53.0 golang.org/x/net v0.52.0
golang.org/x/oauth2 v0.36.0 golang.org/x/oauth2 v0.36.0
golang.org/x/sync v0.20.0 golang.org/x/sync v0.20.0
golang.org/x/term v0.42.0 golang.org/x/term v0.41.0
golang.org/x/time v0.15.0 golang.org/x/time v0.15.0
google.golang.org/api v0.276.0 google.golang.org/api v0.276.0
gopkg.in/yaml.v3 v3.0.1 gopkg.in/yaml.v3 v3.0.1
@@ -303,14 +302,13 @@ require (
go.uber.org/multierr v1.11.0 // indirect go.uber.org/multierr v1.11.0 // indirect
go.yaml.in/yaml/v2 v2.4.3 // indirect go.yaml.in/yaml/v2 v2.4.3 // indirect
golang.org/x/image v0.33.0 // indirect golang.org/x/image v0.33.0 // indirect
golang.org/x/text v0.36.0 // indirect golang.org/x/text v0.35.0 // indirect
golang.org/x/tools v0.43.0 // indirect golang.org/x/tools v0.42.0 // indirect
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20260401024825-9d38bb4040a9 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20260401024825-9d38bb4040a9 // indirect
gopkg.in/square/go-jose.v2 v2.6.0 // indirect gopkg.in/square/go-jose.v2 v2.6.0 // indirect
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 // indirect gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 // indirect
gopkg.in/yaml.v2 v2.4.0 // indirect gopkg.in/yaml.v2 v2.4.0 // indirect
rsc.io/qr v0.2.0 // indirect
) )
replace github.com/kardianos/service => github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 replace github.com/kardianos/service => github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502
@@ -323,6 +321,8 @@ replace github.com/cloudflare/circl => github.com/cunicu/circl v0.0.0-2023080111
replace github.com/pion/ice/v4 => github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51 replace github.com/pion/ice/v4 => github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51
replace github.com/libp2p/go-netroute => github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944
replace github.com/dexidp/dex => github.com/netbirdio/dex v0.244.0 replace github.com/dexidp/dex => github.com/netbirdio/dex v0.244.0
replace github.com/mailru/easyjson => github.com/netbirdio/easyjson v0.9.0 replace github.com/mailru/easyjson => github.com/netbirdio/easyjson v0.9.0

36
go.sum
View File

@@ -395,8 +395,6 @@ github.com/libdns/route53 v1.5.0 h1:2SKdpPFl/qgWsXQvsLNJJAoX7rSxlk7zgoL4jnWdXVA=
github.com/libdns/route53 v1.5.0/go.mod h1:joT4hKmaTNKHEwb7GmZ65eoDz1whTu7KKYPS8ZqIh6Q= github.com/libdns/route53 v1.5.0/go.mod h1:joT4hKmaTNKHEwb7GmZ65eoDz1whTu7KKYPS8ZqIh6Q=
github.com/libp2p/go-nat v0.2.0 h1:Tyz+bUFAYqGyJ/ppPPymMGbIgNRH+WqC5QrT5fKrrGk= github.com/libp2p/go-nat v0.2.0 h1:Tyz+bUFAYqGyJ/ppPPymMGbIgNRH+WqC5QrT5fKrrGk=
github.com/libp2p/go-nat v0.2.0/go.mod h1:3MJr+GRpRkyT65EpVPBstXLvOlAPzUVlG6Pwg9ohLJk= github.com/libp2p/go-nat v0.2.0/go.mod h1:3MJr+GRpRkyT65EpVPBstXLvOlAPzUVlG6Pwg9ohLJk=
github.com/libp2p/go-netroute v0.4.0 h1:sZZx9hyANYUx9PZyqcgE/E1GUG3iEtTZHUEvdtXT7/Q=
github.com/libp2p/go-netroute v0.4.0/go.mod h1:Nkd5ShYgSMS5MUKy/MU2T57xFoOKvvLR92Lic48LEyA=
github.com/lrh3321/ipset-go v0.0.0-20250619021614-54a0a98ace81 h1:J56rFEfUTFT9j9CiRXhi1r8lUJ4W5idG3CiaBZGojNU= github.com/lrh3321/ipset-go v0.0.0-20250619021614-54a0a98ace81 h1:J56rFEfUTFT9j9CiRXhi1r8lUJ4W5idG3CiaBZGojNU=
github.com/lrh3321/ipset-go v0.0.0-20250619021614-54a0a98ace81/go.mod h1:RD8ML/YdXctQ7qbcizZkw5mZ6l8Ogrl1dodBzVJduwI= github.com/lrh3321/ipset-go v0.0.0-20250619021614-54a0a98ace81/go.mod h1:RD8ML/YdXctQ7qbcizZkw5mZ6l8Ogrl1dodBzVJduwI=
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0/go.mod h1:zJYVVT2jmtg6P3p1VtQj7WsuWi/y4VnjVBn7F8KPB3I= github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0/go.mod h1:zJYVVT2jmtg6P3p1VtQj7WsuWi/y4VnjVBn7F8KPB3I=
@@ -417,8 +415,6 @@ github.com/mdlayher/netlink v1.7.3-0.20250113171957-fbb4dce95f42 h1:A1Cq6Ysb0GM0
github.com/mdlayher/netlink v1.7.3-0.20250113171957-fbb4dce95f42/go.mod h1:BB4YCPDOzfy7FniQ/lxuYQ3dgmM2cZumHbK8RpTjN2o= github.com/mdlayher/netlink v1.7.3-0.20250113171957-fbb4dce95f42/go.mod h1:BB4YCPDOzfy7FniQ/lxuYQ3dgmM2cZumHbK8RpTjN2o=
github.com/mdlayher/socket v0.5.1 h1:VZaqt6RkGkt2OE9l3GcC6nZkqD3xKeQLyfleW/uBcos= github.com/mdlayher/socket v0.5.1 h1:VZaqt6RkGkt2OE9l3GcC6nZkqD3xKeQLyfleW/uBcos=
github.com/mdlayher/socket v0.5.1/go.mod h1:TjPLHI1UgwEv5J1B5q0zTZq12A/6H7nKmtTanQE37IQ= github.com/mdlayher/socket v0.5.1/go.mod h1:TjPLHI1UgwEv5J1B5q0zTZq12A/6H7nKmtTanQE37IQ=
github.com/mdp/qrterminal/v3 v3.2.1 h1:6+yQjiiOsSuXT5n9/m60E54vdgFsw0zhADHhHLrFet4=
github.com/mdp/qrterminal/v3 v3.2.1/go.mod h1:jOTmXvnBsMy5xqLniO0R++Jmjs2sTm9dFSuQ5kpz/SU=
github.com/mholt/acmez/v2 v2.0.1 h1:3/3N0u1pLjMK4sNEAFSI+bcvzbPhRpY383sy1kLHJ6k= github.com/mholt/acmez/v2 v2.0.1 h1:3/3N0u1pLjMK4sNEAFSI+bcvzbPhRpY383sy1kLHJ6k=
github.com/mholt/acmez/v2 v2.0.1/go.mod h1:fX4c9r5jYwMyMsC+7tkYRxHibkOTgta5DIFGoe67e1U= github.com/mholt/acmez/v2 v2.0.1/go.mod h1:fX4c9r5jYwMyMsC+7tkYRxHibkOTgta5DIFGoe67e1U=
github.com/miekg/dns v1.1.59 h1:C9EXc/UToRwKLhK5wKU/I4QVsBUc8kE6MkHBkeypWZs= github.com/miekg/dns v1.1.59 h1:C9EXc/UToRwKLhK5wKU/I4QVsBUc8kE6MkHBkeypWZs=
@@ -455,6 +451,8 @@ github.com/netbirdio/dex v0.244.0 h1:1GOvi8wnXYassnKGildzNqRHq0RbcfEUw7LKYpKIN7U
github.com/netbirdio/dex v0.244.0/go.mod h1:STGInJhPcAflrHmDO7vyit2kSq03PdL+8zQPoGALtcU= github.com/netbirdio/dex v0.244.0/go.mod h1:STGInJhPcAflrHmDO7vyit2kSq03PdL+8zQPoGALtcU=
github.com/netbirdio/easyjson v0.9.0 h1:6Nw2lghSVuy8RSkAYDhDv1thBVEmfVbKZnV7T7Z6Aus= github.com/netbirdio/easyjson v0.9.0 h1:6Nw2lghSVuy8RSkAYDhDv1thBVEmfVbKZnV7T7Z6Aus=
github.com/netbirdio/easyjson v0.9.0/go.mod h1:1+xMtQp2MRNVL/V1bOzuP3aP8VNwRW55fQUto+XFtTU= github.com/netbirdio/easyjson v0.9.0/go.mod h1:1+xMtQp2MRNVL/V1bOzuP3aP8VNwRW55fQUto+XFtTU=
github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944 h1:TDtJKmM6Sf8uYFx/dMeqNOL90KUoRscdfpFZ3Im89uk=
github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944/go.mod h1:sHA6TRxjQ6RLbnI+3R4DZo2Eseg/iKiPRfNmcuNySVQ=
github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51 h1:Ov4qdafATOgGMB1wbSuh+0aAHcwz9hdvB6VZjh1mVMI= github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51 h1:Ov4qdafATOgGMB1wbSuh+0aAHcwz9hdvB6VZjh1mVMI=
github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51/go.mod h1:ZSIbPdBn5hePO8CpF1PekH2SfpTxg1PDhEwtbqZS7R8= github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51/go.mod h1:ZSIbPdBn5hePO8CpF1PekH2SfpTxg1PDhEwtbqZS7R8=
github.com/netbirdio/management-integrations/integrations v0.0.0-20260416123949-2355d972be42 h1:F3zS5fT9xzD1OFLfcdAE+3FfyiwjGukF1hvj0jErgs8= github.com/netbirdio/management-integrations/integrations v0.0.0-20260416123949-2355d972be42 h1:F3zS5fT9xzD1OFLfcdAE+3FfyiwjGukF1hvj0jErgs8=
@@ -711,8 +709,8 @@ golang.org/x/crypto v0.18.0/go.mod h1:R0j02AL6hcrfOiy9T4ZYp/rcWeMxM3L6QYxlOuEG1m
golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU= golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU=
golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8= golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8=
golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk= golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk=
golang.org/x/crypto v0.50.0 h1:zO47/JPrL6vsNkINmLoo/PH1gcxpls50DNogFvB5ZGI= golang.org/x/crypto v0.49.0 h1:+Ng2ULVvLHnJ/ZFEq4KdcDd/cfjrrjjNSXNzxg0Y4U4=
golang.org/x/crypto v0.50.0/go.mod h1:3muZ7vA7PBCE6xgPX7nkzzjiUq87kRItoJQM1Yo8S+Q= golang.org/x/crypto v0.49.0/go.mod h1:ErX4dUh2UM+CFYiXZRTcMpEcN8b/1gxEuv3nODoYtCA=
golang.org/x/exp v0.0.0-20250620022241-b7579e27df2b h1:M2rDM6z3Fhozi9O7NWsxAkg/yqS/lQJ6PmkyIV3YP+o= golang.org/x/exp v0.0.0-20250620022241-b7579e27df2b h1:M2rDM6z3Fhozi9O7NWsxAkg/yqS/lQJ6PmkyIV3YP+o=
golang.org/x/exp v0.0.0-20250620022241-b7579e27df2b/go.mod h1:3//PLf8L/X+8b4vuAfHzxeRUl04Adcb341+IGKfnqS8= golang.org/x/exp v0.0.0-20250620022241-b7579e27df2b/go.mod h1:3//PLf8L/X+8b4vuAfHzxeRUl04Adcb341+IGKfnqS8=
golang.org/x/image v0.33.0 h1:LXRZRnv1+zGd5XBUVRFmYEphyyKJjQjCRiOuAP3sZfQ= golang.org/x/image v0.33.0 h1:LXRZRnv1+zGd5XBUVRFmYEphyyKJjQjCRiOuAP3sZfQ=
@@ -729,8 +727,8 @@ golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
golang.org/x/mod v0.12.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= golang.org/x/mod v0.12.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
golang.org/x/mod v0.15.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= golang.org/x/mod v0.15.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
golang.org/x/mod v0.17.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= golang.org/x/mod v0.17.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
golang.org/x/mod v0.34.0 h1:xIHgNUUnW6sYkcM5Jleh05DvLOtwc6RitGHbDk4akRI= golang.org/x/mod v0.33.0 h1:tHFzIWbBifEmbwtGz65eaWyGiGZatSrT9prnU8DbVL8=
golang.org/x/mod v0.34.0/go.mod h1:ykgH52iCZe79kzLLMhyCUzhMci+nQj+0XkbXpNYtVjY= golang.org/x/mod v0.33.0/go.mod h1:swjeQEj+6r7fODbD2cqrnje9PnziFuw4bmLbBZFrQ5w=
golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190603091049-60506f45cf65/go.mod h1:HSz+uSET+XFnRR8LxR5pz3Of3rY3CfYBVs4xY44aLks= golang.org/x/net v0.0.0-20190603091049-60506f45cf65/go.mod h1:HSz+uSET+XFnRR8LxR5pz3Of3rY3CfYBVs4xY44aLks=
@@ -749,8 +747,8 @@ golang.org/x/net v0.15.0/go.mod h1:idbUs1IY1+zTqbi8yxTbhexhEEk5ur9LInksu6HrEpk=
golang.org/x/net v0.20.0/go.mod h1:z8BVo6PvndSri0LbOE3hAn0apkU+1YvI6E70E9jsnvY= golang.org/x/net v0.20.0/go.mod h1:z8BVo6PvndSri0LbOE3hAn0apkU+1YvI6E70E9jsnvY=
golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44= golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44=
golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM= golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM=
golang.org/x/net v0.53.0 h1:d+qAbo5L0orcWAr0a9JweQpjXF19LMXJE8Ey7hwOdUA= golang.org/x/net v0.52.0 h1:He/TN1l0e4mmR3QqHMT2Xab3Aj3L9qjbhRm78/6jrW0=
golang.org/x/net v0.53.0/go.mod h1:JvMuJH7rrdiCfbeHoo3fCQU24Lf5JJwT9W3sJFulfgs= golang.org/x/net v0.52.0/go.mod h1:R1MAz7uMZxVMualyPXb+VaqGSa3LIaUqk0eEt3w36Sw=
golang.org/x/oauth2 v0.8.0/go.mod h1:yr7u4HXZRm1R1kBWqr/xKNqewf0plRYoB7sla+BCIXE= golang.org/x/oauth2 v0.8.0/go.mod h1:yr7u4HXZRm1R1kBWqr/xKNqewf0plRYoB7sla+BCIXE=
golang.org/x/oauth2 v0.36.0 h1:peZ/1z27fi9hUOFCAZaHyrpWG5lwe0RJEEEeH0ThlIs= golang.org/x/oauth2 v0.36.0 h1:peZ/1z27fi9hUOFCAZaHyrpWG5lwe0RJEEEeH0ThlIs=
golang.org/x/oauth2 v0.36.0/go.mod h1:YDBUJMTkDnJS+A4BP4eZBjCqtokkg1hODuPjwiGPO7Q= golang.org/x/oauth2 v0.36.0/go.mod h1:YDBUJMTkDnJS+A4BP4eZBjCqtokkg1hODuPjwiGPO7Q=
@@ -801,8 +799,8 @@ golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.19.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.19.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.43.0 h1:Rlag2XtaFTxp19wS8MXlJwTvoh8ArU6ezoyFsMyCTNI= golang.org/x/sys v0.42.0 h1:omrd2nAlyT5ESRdCLYdm3+fMfNFE/+Rf4bDIQImRJeo=
golang.org/x/sys v0.43.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw= golang.org/x/sys v0.42.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw=
golang.org/x/telemetry v0.0.0-20240228155512-f48c80bd79b2/go.mod h1:TeRTkGYfJXctD9OcfyVLyj2J3IxLnKwHJR8f4D8a3YE= golang.org/x/telemetry v0.0.0-20240228155512-f48c80bd79b2/go.mod h1:TeRTkGYfJXctD9OcfyVLyj2J3IxLnKwHJR8f4D8a3YE=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
@@ -815,8 +813,8 @@ golang.org/x/term v0.16.0/go.mod h1:yn7UURbUtPyrVJPGPq404EukNFxcm/foM+bV/bfcDsY=
golang.org/x/term v0.17.0/go.mod h1:lLRBjIVuehSbZlaOtGMbcMncT+aqLLLmKrsjNrUguwk= golang.org/x/term v0.17.0/go.mod h1:lLRBjIVuehSbZlaOtGMbcMncT+aqLLLmKrsjNrUguwk=
golang.org/x/term v0.20.0/go.mod h1:8UkIAJTvZgivsXaD6/pH6U9ecQzZ45awqEOzuCvwpFY= golang.org/x/term v0.20.0/go.mod h1:8UkIAJTvZgivsXaD6/pH6U9ecQzZ45awqEOzuCvwpFY=
golang.org/x/term v0.27.0/go.mod h1:iMsnZpn0cago0GOrHO2+Y7u7JPn5AylBrcoWkElMTSM= golang.org/x/term v0.27.0/go.mod h1:iMsnZpn0cago0GOrHO2+Y7u7JPn5AylBrcoWkElMTSM=
golang.org/x/term v0.42.0 h1:UiKe+zDFmJobeJ5ggPwOshJIVt6/Ft0rcfrXZDLWAWY= golang.org/x/term v0.41.0 h1:QCgPso/Q3RTJx2Th4bDLqML4W6iJiaXFq2/ftQF13YU=
golang.org/x/term v0.42.0/go.mod h1:Dq/D+snpsbazcBG5+F9Q1n2rXV8Ma+71xEjTRufARgY= golang.org/x/term v0.41.0/go.mod h1:3pfBgksrReYfZ5lvYM0kSO0LIkAl4Yl2bXOkKP7Ec2A=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
@@ -828,8 +826,8 @@ golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ= golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ=
golang.org/x/text v0.36.0 h1:JfKh3XmcRPqZPKevfXVpI1wXPTqbkE5f7JA92a55Yxg= golang.org/x/text v0.35.0 h1:JOVx6vVDFokkpaq1AEptVzLTpDe9KGpj5tR4/X+ybL8=
golang.org/x/text v0.36.0/go.mod h1:NIdBknypM8iqVmPiuco0Dh6P5Jcdk8lJL0CUebqK164= golang.org/x/text v0.35.0/go.mod h1:khi/HExzZJ2pGnjenulevKNX1W67CUy0AsXcNubPGCA=
golang.org/x/time v0.15.0 h1:bbrp8t3bGUeFOx08pvsMYRTCVSMk89u4tKbNOZbp88U= golang.org/x/time v0.15.0 h1:bbrp8t3bGUeFOx08pvsMYRTCVSMk89u4tKbNOZbp88U=
golang.org/x/time v0.15.0/go.mod h1:Y4YMaQmXwGQZoFaVFk4YpCt4FLQMYKZe9oeV/f4MSno= golang.org/x/time v0.15.0/go.mod h1:Y4YMaQmXwGQZoFaVFk4YpCt4FLQMYKZe9oeV/f4MSno=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
@@ -843,8 +841,8 @@ golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc
golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU= golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU=
golang.org/x/tools v0.13.0/go.mod h1:HvlwmtVNQAhOuCjW7xxvovg8wbNq7LwfXh/k7wXUl58= golang.org/x/tools v0.13.0/go.mod h1:HvlwmtVNQAhOuCjW7xxvovg8wbNq7LwfXh/k7wXUl58=
golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d/go.mod h1:aiJjzUbINMkxbQROHiO6hDPo2LHcIPhhQsa9DLh0yGk= golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d/go.mod h1:aiJjzUbINMkxbQROHiO6hDPo2LHcIPhhQsa9DLh0yGk=
golang.org/x/tools v0.43.0 h1:12BdW9CeB3Z+J/I/wj34VMl8X+fEXBxVR90JeMX5E7s= golang.org/x/tools v0.42.0 h1:uNgphsn75Tdz5Ji2q36v/nsFSfR/9BRFvqhGBaJGd5k=
golang.org/x/tools v0.43.0/go.mod h1:uHkMso649BX2cZK6+RpuIPXS3ho2hZo4FVwfoy1vIk0= golang.org/x/tools v0.42.0/go.mod h1:Ma6lCIwGZvHK6XtgbswSoWroEkhugApmsXyrUmBhfr0=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
@@ -917,5 +915,3 @@ gotest.tools/v3 v3.5.1 h1:EENdUnS3pdur5nybKYIh2Vfgc8IUNBjxDPSjtiJcOzU=
gotest.tools/v3 v3.5.1/go.mod h1:isy3WKz7GK6uNw/sbHzfKBLvlvXwUyV06n6brMxxopU= gotest.tools/v3 v3.5.1/go.mod h1:isy3WKz7GK6uNw/sbHzfKBLvlvXwUyV06n6brMxxopU=
gvisor.dev/gvisor v0.0.0-20260219192049-0f2374377e89 h1:mGJaeA61P8dEHTqdvAgc70ZIV3QoUoJcXCRyyjO26OA= gvisor.dev/gvisor v0.0.0-20260219192049-0f2374377e89 h1:mGJaeA61P8dEHTqdvAgc70ZIV3QoUoJcXCRyyjO26OA=
gvisor.dev/gvisor v0.0.0-20260219192049-0f2374377e89/go.mod h1:QkHjoMIBaYtpVufgwv3keYAbln78mBoCuShZrPrer1Q= gvisor.dev/gvisor v0.0.0-20260219192049-0f2374377e89/go.mod h1:QkHjoMIBaYtpVufgwv3keYAbln78mBoCuShZrPrer1Q=
rsc.io/qr v0.2.0 h1:6vBLea5/NRMVTz8V66gipeLycZMl/+UlFmk8DvqQ6WY=
rsc.io/qr v0.2.0/go.mod h1:IF+uZjkb9fqyeF/4tlBoynqmQxUoPfWEKh921coOuXs=

View File

@@ -89,33 +89,21 @@ func (p *Provider) ListConnectors(ctx context.Context) ([]*ConnectorConfig, erro
} }
// UpdateConnector updates an existing connector in Dex storage. // UpdateConnector updates an existing connector in Dex storage.
// It overlays user-mutable config fields (issuer, clientID, clientSecret, // It merges incoming updates with existing values to prevent data loss on partial updates.
// redirectURI) onto the stored connector config, and updates the connector name
// when cfg.Name is set. Empty fields on cfg leave stored values unchanged, so
// partial updates preserve create-time defaults such as scopes, claimMapping,
// and userIDKey.
func (p *Provider) UpdateConnector(ctx context.Context, cfg *ConnectorConfig) error { func (p *Provider) UpdateConnector(ctx context.Context, cfg *ConnectorConfig) error {
if err := p.storage.UpdateConnector(ctx, cfg.ID, func(old storage.Connector) (storage.Connector, error) { if err := p.storage.UpdateConnector(ctx, cfg.ID, func(old storage.Connector) (storage.Connector, error) {
if cfg.Type != "" && cfg.Type != inferIdentityProviderType(old.Type, cfg.ID, nil) { oldCfg, err := p.parseStorageConnector(old)
return storage.Connector{}, errors.New("connector type change not allowed")
}
configData, err := overlayConnectorConfig(old.Config, cfg)
if err != nil { if err != nil {
return storage.Connector{}, fmt.Errorf("failed to overlay connector config: %w", err) return storage.Connector{}, fmt.Errorf("failed to parse existing connector: %w", err)
} }
name := cfg.Name mergeConnectorConfig(cfg, oldCfg)
if name == "" {
name = old.Name
}
return storage.Connector{ storageConn, err := p.buildStorageConnector(cfg)
ID: cfg.ID, if err != nil {
Type: old.Type, return storage.Connector{}, fmt.Errorf("failed to build connector: %w", err)
Name: name, }
Config: configData, return storageConn, nil
}, nil
}); err != nil { }); err != nil {
return fmt.Errorf("failed to update connector: %w", err) return fmt.Errorf("failed to update connector: %w", err)
} }
@@ -124,27 +112,23 @@ func (p *Provider) UpdateConnector(ctx context.Context, cfg *ConnectorConfig) er
return nil return nil
} }
// overlayConnectorConfig writes only the user-mutable fields onto the existing // mergeConnectorConfig preserves existing values for empty fields in the update.
// stored config, preserving every other field (scopes, claimMapping, userIDKey, func mergeConnectorConfig(cfg, oldCfg *ConnectorConfig) {
// insecure flags, etc.). Empty fields on cfg leave the existing value alone. if cfg.ClientSecret == "" {
func overlayConnectorConfig(oldConfig []byte, cfg *ConnectorConfig) ([]byte, error) { cfg.ClientSecret = oldCfg.ClientSecret
var m map[string]any
if err := decodeConnectorConfig(oldConfig, &m); err != nil {
return nil, err
} }
if cfg.Issuer != "" { if cfg.RedirectURI == "" {
m["issuer"] = cfg.Issuer cfg.RedirectURI = oldCfg.RedirectURI
} }
if cfg.ClientID != "" { if cfg.Issuer == "" && cfg.Type == oldCfg.Type {
m["clientID"] = cfg.ClientID cfg.Issuer = oldCfg.Issuer
} }
if cfg.ClientSecret != "" { if cfg.ClientID == "" {
m["clientSecret"] = cfg.ClientSecret cfg.ClientID = oldCfg.ClientID
} }
if cfg.RedirectURI != "" { if cfg.Name == "" {
m["redirectURI"] = cfg.RedirectURI cfg.Name = oldCfg.Name
} }
return encodeConnectorConfig(m)
} }
// DeleteConnector removes a connector from Dex storage. // DeleteConnector removes a connector from Dex storage.
@@ -232,10 +216,6 @@ func buildOIDCConnectorConfig(cfg *ConnectorConfig, redirectURI string) ([]byte,
oidcConfig["getUserInfo"] = true oidcConfig["getUserInfo"] = true
case "entra": case "entra":
oidcConfig["claimMapping"] = map[string]string{"email": "preferred_username"} oidcConfig["claimMapping"] = map[string]string{"email": "preferred_username"}
// Use the Entra Object ID (oid) instead of the default OIDC sub claim.
// Entra issues sub as a per-app pairwise identifier that does not match
// the stable Object ID.
oidcConfig["userIDKey"] = "oid"
case "okta": case "okta":
oidcConfig["scopes"] = []string{"openid", "profile", "email", "groups"} oidcConfig["scopes"] = []string{"openid", "profile", "email", "groups"}
case "pocketid": case "pocketid":

View File

@@ -1,205 +0,0 @@
package dex
import (
"context"
"encoding/json"
"log/slog"
"os"
"path/filepath"
"testing"
"github.com/dexidp/dex/storage"
"github.com/dexidp/dex/storage/sql"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func newTestProvider(t *testing.T) (*Provider, func()) {
t.Helper()
tmpDir, err := os.MkdirTemp("", "dex-connector-test-*")
require.NoError(t, err)
logger := slog.New(slog.NewTextHandler(os.Stderr, nil))
s, err := (&sql.SQLite3{File: filepath.Join(tmpDir, "dex.db")}).Open(logger)
require.NoError(t, err)
return &Provider{storage: s, logger: logger}, func() {
_ = s.Close()
_ = os.RemoveAll(tmpDir)
}
}
func TestBuildOIDCConnectorConfig_EntraSetsUserIDKey(t *testing.T) {
cfg := &ConnectorConfig{
ID: "entra-test",
Name: "Entra",
Type: "entra",
Issuer: "https://login.microsoftonline.com/tid/v2.0",
ClientID: "client-id",
ClientSecret: "client-secret",
}
data, err := buildOIDCConnectorConfig(cfg, "https://example.com/oauth2/callback")
require.NoError(t, err)
var m map[string]any
require.NoError(t, json.Unmarshal(data, &m))
assert.Equal(t, "oid", m["userIDKey"], "entra connectors must default userIDKey to oid")
assert.Equal(t, map[string]any{"email": "preferred_username"}, m["claimMapping"])
}
func TestBuildOIDCConnectorConfig_NonEntraDoesNotSetUserIDKey(t *testing.T) {
// ensures the Entra userIDKey override does not leak into other OIDC providers,
// which already use a stable sub claim.
for _, typ := range []string{"oidc", "zitadel", "okta", "pocketid", "authentik", "keycloak", "adfs"} {
t.Run(typ, func(t *testing.T) {
data, err := buildOIDCConnectorConfig(&ConnectorConfig{Type: typ}, "https://example.com/oauth2/callback")
require.NoError(t, err)
var m map[string]any
require.NoError(t, json.Unmarshal(data, &m))
_, ok := m["userIDKey"]
assert.False(t, ok, "%s connectors must not have userIDKey set", typ)
})
}
}
func TestUpdateConnector_PreservesCreateTimeDefaults(t *testing.T) {
ctx := context.Background()
p, cleanup := newTestProvider(t)
defer cleanup()
created, err := p.CreateConnector(ctx, &ConnectorConfig{
ID: "entra-test",
Name: "Entra",
Type: "entra",
Issuer: "https://login.microsoftonline.com/tid/v2.0",
ClientID: "client-id",
ClientSecret: "old-secret",
RedirectURI: "https://example.com/oauth2/callback",
})
require.NoError(t, err)
require.Equal(t, "entra-test", created.ID)
// Rotate only the client secret.
err = p.UpdateConnector(ctx, &ConnectorConfig{
ID: "entra-test",
Type: "entra",
ClientSecret: "new-secret",
})
require.NoError(t, err)
conn, err := p.storage.GetConnector(ctx, "entra-test")
require.NoError(t, err)
var m map[string]any
require.NoError(t, json.Unmarshal(conn.Config, &m))
assert.Equal(t, "new-secret", m["clientSecret"], "clientSecret should be rotated")
assert.Equal(t, "client-id", m["clientID"], "clientID must survive (overlay should leave it alone)")
assert.Equal(t, "https://login.microsoftonline.com/tid/v2.0", m["issuer"])
assert.Equal(t, "oid", m["userIDKey"], "userIDKey must survive update")
assert.Equal(t, map[string]any{"email": "preferred_username"}, m["claimMapping"], "claimMapping must survive update")
}
func TestUpdateConnector_DoesNotAddUserIDKeyToExistingConnector(t *testing.T) {
ctx := context.Background()
p, cleanup := newTestProvider(t)
defer cleanup()
// Seed a connector directly into storage without userIDKey
preFixConfig, err := json.Marshal(map[string]any{
"issuer": "https://login.microsoftonline.com/tid/v2.0",
"clientID": "client-id",
"clientSecret": "old-secret",
"redirectURI": "https://example.com/oauth2/callback",
"scopes": []string{"openid", "profile", "email"},
"claimMapping": map[string]string{"email": "preferred_username"},
})
require.NoError(t, err)
require.NoError(t, p.storage.CreateConnector(ctx, storage.Connector{
ID: "entra-prefix",
Type: "oidc",
Name: "Entra",
Config: preFixConfig,
}))
// Rotate client secret via UpdateConnector.
err = p.UpdateConnector(ctx, &ConnectorConfig{
ID: "entra-prefix",
Type: "entra",
ClientSecret: "new-secret",
})
require.NoError(t, err)
conn, err := p.storage.GetConnector(ctx, "entra-prefix")
require.NoError(t, err)
var m map[string]any
require.NoError(t, json.Unmarshal(conn.Config, &m))
assert.Equal(t, "new-secret", m["clientSecret"])
_, has := m["userIDKey"]
assert.False(t, has, "userIDKey must not be auto-added to a connector that did not have it before")
}
func TestUpdateConnector_RejectsTypeChange(t *testing.T) {
ctx := context.Background()
p, cleanup := newTestProvider(t)
defer cleanup()
_, err := p.CreateConnector(ctx, &ConnectorConfig{
ID: "entra-test",
Name: "Entra",
Type: "entra",
Issuer: "https://login.microsoftonline.com/tid/v2.0",
ClientID: "client-id",
ClientSecret: "secret",
RedirectURI: "https://example.com/oauth2/callback",
})
require.NoError(t, err)
// Attempt to switch the connector to okta.
err = p.UpdateConnector(ctx, &ConnectorConfig{
ID: "entra-test",
Type: "okta",
})
require.Error(t, err)
assert.Contains(t, err.Error(), "connector type change not allowed")
// stored connector type/config unchanged after the rejected update.
conn, err := p.storage.GetConnector(ctx, "entra-test")
require.NoError(t, err)
assert.Equal(t, "oidc", conn.Type)
var m map[string]any
require.NoError(t, json.Unmarshal(conn.Config, &m))
assert.Equal(t, "oid", m["userIDKey"])
}
func TestUpdateConnector_AllowsSameTypeUpdate(t *testing.T) {
ctx := context.Background()
p, cleanup := newTestProvider(t)
defer cleanup()
_, err := p.CreateConnector(ctx, &ConnectorConfig{
ID: "entra-test",
Name: "Entra",
Type: "entra",
Issuer: "https://login.microsoftonline.com/old/v2.0",
ClientID: "client-id",
ClientSecret: "secret",
RedirectURI: "https://example.com/oauth2/callback",
})
require.NoError(t, err)
err = p.UpdateConnector(ctx, &ConnectorConfig{
ID: "entra-test",
Type: "entra",
Issuer: "https://login.microsoftonline.com/new/v2.0",
})
require.NoError(t, err)
conn, err := p.storage.GetConnector(ctx, "entra-test")
require.NoError(t, err)
var m map[string]any
require.NoError(t, json.Unmarshal(conn.Config, &m))
assert.Equal(t, "https://login.microsoftonline.com/new/v2.0", m["issuer"])
}

View File

@@ -11,9 +11,9 @@ import (
// Manager defines the interface for proxy operations // Manager defines the interface for proxy operations
type Manager interface { type Manager interface {
Connect(ctx context.Context, proxyID, sessionID, clusterAddress, ipAddress string, capabilities *Capabilities) (*Proxy, error) Connect(ctx context.Context, proxyID, clusterAddress, ipAddress string, capabilities *Capabilities) error
Disconnect(ctx context.Context, proxyID, sessionID string) error Disconnect(ctx context.Context, proxyID string) error
Heartbeat(ctx context.Context, p *Proxy) error Heartbeat(ctx context.Context, proxyID, clusterAddress, ipAddress string) error
GetActiveClusterAddresses(ctx context.Context) ([]string, error) GetActiveClusterAddresses(ctx context.Context) ([]string, error)
GetActiveClusters(ctx context.Context) ([]Cluster, error) GetActiveClusters(ctx context.Context) ([]Cluster, error)
ClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool ClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool

View File

@@ -13,8 +13,7 @@ import (
// store defines the interface for proxy persistence operations // store defines the interface for proxy persistence operations
type store interface { type store interface {
SaveProxy(ctx context.Context, p *proxy.Proxy) error SaveProxy(ctx context.Context, p *proxy.Proxy) error
DisconnectProxy(ctx context.Context, proxyID, sessionID string) error UpdateProxyHeartbeat(ctx context.Context, proxyID, clusterAddress, ipAddress string) error
UpdateProxyHeartbeat(ctx context.Context, p *proxy.Proxy) error
GetActiveProxyClusterAddresses(ctx context.Context) ([]string, error) GetActiveProxyClusterAddresses(ctx context.Context) ([]string, error)
GetActiveProxyClusters(ctx context.Context) ([]proxy.Cluster, error) GetActiveProxyClusters(ctx context.Context) ([]proxy.Cluster, error)
GetClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool GetClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool
@@ -44,7 +43,7 @@ func NewManager(store store, meter metric.Meter) (*Manager, error) {
// Connect registers a new proxy connection in the database. // Connect registers a new proxy connection in the database.
// capabilities may be nil for old proxies that do not report them. // capabilities may be nil for old proxies that do not report them.
func (m Manager) Connect(ctx context.Context, proxyID, sessionID, clusterAddress, ipAddress string, capabilities *proxy.Capabilities) (*proxy.Proxy, error) { func (m Manager) Connect(ctx context.Context, proxyID, clusterAddress, ipAddress string, capabilities *proxy.Capabilities) error {
now := time.Now() now := time.Now()
var caps proxy.Capabilities var caps proxy.Capabilities
if capabilities != nil { if capabilities != nil {
@@ -52,7 +51,6 @@ func (m Manager) Connect(ctx context.Context, proxyID, sessionID, clusterAddress
} }
p := &proxy.Proxy{ p := &proxy.Proxy{
ID: proxyID, ID: proxyID,
SessionID: sessionID,
ClusterAddress: clusterAddress, ClusterAddress: clusterAddress,
IPAddress: ipAddress, IPAddress: ipAddress,
LastSeen: now, LastSeen: now,
@@ -63,42 +61,48 @@ func (m Manager) Connect(ctx context.Context, proxyID, sessionID, clusterAddress
if err := m.store.SaveProxy(ctx, p); err != nil { if err := m.store.SaveProxy(ctx, p); err != nil {
log.WithContext(ctx).Errorf("failed to register proxy %s: %v", proxyID, err) log.WithContext(ctx).Errorf("failed to register proxy %s: %v", proxyID, err)
return nil, err
}
log.WithContext(ctx).WithFields(log.Fields{
"proxyID": proxyID,
"sessionID": sessionID,
"clusterAddress": clusterAddress,
"ipAddress": ipAddress,
}).Info("proxy connected")
return p, nil
}
// Disconnect marks a proxy as disconnected in the database.
func (m Manager) Disconnect(ctx context.Context, proxyID, sessionID string) error {
if err := m.store.DisconnectProxy(ctx, proxyID, sessionID); err != nil {
log.WithContext(ctx).Errorf("failed to disconnect proxy %s session %s: %v", proxyID, sessionID, err)
return err return err
} }
log.WithContext(ctx).WithFields(log.Fields{ log.WithContext(ctx).WithFields(log.Fields{
"proxyID": proxyID, "proxyID": proxyID,
"sessionID": sessionID, "clusterAddress": clusterAddress,
"ipAddress": ipAddress,
}).Info("proxy connected")
return nil
}
// Disconnect marks a proxy as disconnected in the database
func (m Manager) Disconnect(ctx context.Context, proxyID string) error {
now := time.Now()
p := &proxy.Proxy{
ID: proxyID,
Status: "disconnected",
DisconnectedAt: &now,
LastSeen: now,
}
if err := m.store.SaveProxy(ctx, p); err != nil {
log.WithContext(ctx).Errorf("failed to disconnect proxy %s: %v", proxyID, err)
return err
}
log.WithContext(ctx).WithFields(log.Fields{
"proxyID": proxyID,
}).Info("proxy disconnected") }).Info("proxy disconnected")
return nil return nil
} }
// Heartbeat updates the proxy's last seen timestamp. // Heartbeat updates the proxy's last seen timestamp
func (m Manager) Heartbeat(ctx context.Context, p *proxy.Proxy) error { func (m Manager) Heartbeat(ctx context.Context, proxyID, clusterAddress, ipAddress string) error {
if err := m.store.UpdateProxyHeartbeat(ctx, p); err != nil { if err := m.store.UpdateProxyHeartbeat(ctx, proxyID, clusterAddress, ipAddress); err != nil {
log.WithContext(ctx).Debugf("failed to update proxy %s heartbeat: %v", p.ID, err) log.WithContext(ctx).Debugf("failed to update proxy %s heartbeat: %v", proxyID, err)
return err return err
} }
log.WithContext(ctx).Tracef("updated heartbeat for proxy %s session %s", p.ID, p.SessionID) log.WithContext(ctx).Tracef("updated heartbeat for proxy %s", proxyID)
m.metrics.IncrementProxyHeartbeatCount() m.metrics.IncrementProxyHeartbeatCount()
return nil return nil
} }

View File

@@ -93,32 +93,31 @@ func (mr *MockManagerMockRecorder) ClusterSupportsCrowdSec(ctx, clusterAddr inte
} }
// Connect mocks base method. // Connect mocks base method.
func (m *MockManager) Connect(ctx context.Context, proxyID, sessionID, clusterAddress, ipAddress string, capabilities *Capabilities) (*Proxy, error) { func (m *MockManager) Connect(ctx context.Context, proxyID, clusterAddress, ipAddress string, capabilities *Capabilities) error {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Connect", ctx, proxyID, sessionID, clusterAddress, ipAddress, capabilities) ret := m.ctrl.Call(m, "Connect", ctx, proxyID, clusterAddress, ipAddress, capabilities)
ret0, _ := ret[0].(*Proxy) ret0, _ := ret[0].(error)
ret1, _ := ret[1].(error) return ret0
return ret0, ret1
} }
// Connect indicates an expected call of Connect. // Connect indicates an expected call of Connect.
func (mr *MockManagerMockRecorder) Connect(ctx, proxyID, sessionID, clusterAddress, ipAddress, capabilities interface{}) *gomock.Call { func (mr *MockManagerMockRecorder) Connect(ctx, proxyID, clusterAddress, ipAddress, capabilities interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper() mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Connect", reflect.TypeOf((*MockManager)(nil).Connect), ctx, proxyID, sessionID, clusterAddress, ipAddress, capabilities) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Connect", reflect.TypeOf((*MockManager)(nil).Connect), ctx, proxyID, clusterAddress, ipAddress, capabilities)
} }
// Disconnect mocks base method. // Disconnect mocks base method.
func (m *MockManager) Disconnect(ctx context.Context, proxyID, sessionID string) error { func (m *MockManager) Disconnect(ctx context.Context, proxyID string) error {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Disconnect", ctx, proxyID, sessionID) ret := m.ctrl.Call(m, "Disconnect", ctx, proxyID)
ret0, _ := ret[0].(error) ret0, _ := ret[0].(error)
return ret0 return ret0
} }
// Disconnect indicates an expected call of Disconnect. // Disconnect indicates an expected call of Disconnect.
func (mr *MockManagerMockRecorder) Disconnect(ctx, proxyID, sessionID interface{}) *gomock.Call { func (mr *MockManagerMockRecorder) Disconnect(ctx, proxyID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper() mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Disconnect", reflect.TypeOf((*MockManager)(nil).Disconnect), ctx, proxyID, sessionID) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Disconnect", reflect.TypeOf((*MockManager)(nil).Disconnect), ctx, proxyID)
} }
// GetActiveClusterAddresses mocks base method. // GetActiveClusterAddresses mocks base method.
@@ -152,17 +151,17 @@ func (mr *MockManagerMockRecorder) GetActiveClusters(ctx interface{}) *gomock.Ca
} }
// Heartbeat mocks base method. // Heartbeat mocks base method.
func (m *MockManager) Heartbeat(ctx context.Context, p *Proxy) error { func (m *MockManager) Heartbeat(ctx context.Context, proxyID, clusterAddress, ipAddress string) error {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Heartbeat", ctx, p) ret := m.ctrl.Call(m, "Heartbeat", ctx, proxyID, clusterAddress, ipAddress)
ret0, _ := ret[0].(error) ret0, _ := ret[0].(error)
return ret0 return ret0
} }
// Heartbeat indicates an expected call of Heartbeat. // Heartbeat indicates an expected call of Heartbeat.
func (mr *MockManagerMockRecorder) Heartbeat(ctx, p interface{}) *gomock.Call { func (mr *MockManagerMockRecorder) Heartbeat(ctx, proxyID, clusterAddress, ipAddress interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper() mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Heartbeat", reflect.TypeOf((*MockManager)(nil).Heartbeat), ctx, p) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Heartbeat", reflect.TypeOf((*MockManager)(nil).Heartbeat), ctx, proxyID, clusterAddress, ipAddress)
} }
// MockController is a mock of Controller interface. // MockController is a mock of Controller interface.

View File

@@ -18,7 +18,6 @@ type Capabilities struct {
// Proxy represents a reverse proxy instance // Proxy represents a reverse proxy instance
type Proxy struct { type Proxy struct {
ID string `gorm:"primaryKey;type:varchar(255)"` ID string `gorm:"primaryKey;type:varchar(255)"`
SessionID string `gorm:"type:varchar(36)"`
ClusterAddress string `gorm:"type:varchar(255);not null;index:idx_proxy_cluster_status"` ClusterAddress string `gorm:"type:varchar(255);not null;index:idx_proxy_cluster_status"`
IPAddress string `gorm:"type:varchar(45)"` IPAddress string `gorm:"type:varchar(45)"`
LastSeen time.Time `gorm:"not null;index:idx_proxy_last_seen"` LastSeen time.Time `gorm:"not null;index:idx_proxy_last_seen"`

View File

@@ -11,14 +11,11 @@ import (
"fmt" "fmt"
"net/http" "net/http"
"net/url" "net/url"
"os"
"strconv"
"strings" "strings"
"sync" "sync"
"time" "time"
"github.com/coreos/go-oidc/v3/oidc" "github.com/coreos/go-oidc/v3/oidc"
"github.com/google/uuid"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.org/x/oauth2" "golang.org/x/oauth2"
"google.golang.org/grpc/codes" "google.golang.org/grpc/codes"
@@ -84,44 +81,14 @@ type ProxyServiceServer struct {
// Store for PKCE verifiers // Store for PKCE verifiers
pkceVerifierStore *PKCEVerifierStore pkceVerifierStore *PKCEVerifierStore
// tokenTTL is the lifetime of one-time tokens generated for proxy
// authentication. Defaults to defaultProxyTokenTTL when zero.
tokenTTL time.Duration
// snapshotBatchSize is the number of mappings per gRPC message during
// initial snapshot delivery. Configurable via NB_PROXY_SNAPSHOT_BATCH_SIZE.
snapshotBatchSize int
cancel context.CancelFunc cancel context.CancelFunc
} }
const pkceVerifierTTL = 10 * time.Minute const pkceVerifierTTL = 10 * time.Minute
const defaultProxyTokenTTL = 5 * time.Minute
const defaultSnapshotBatchSize = 500
func snapshotBatchSizeFromEnv() int {
if v := os.Getenv("NB_PROXY_SNAPSHOT_BATCH_SIZE"); v != "" {
if n, err := strconv.Atoi(v); err == nil && n > 0 {
return n
}
}
return defaultSnapshotBatchSize
}
// proxyTokenTTL returns the configured token TTL or the default when unset.
func (s *ProxyServiceServer) proxyTokenTTL() time.Duration {
if s.tokenTTL > 0 {
return s.tokenTTL
}
return defaultProxyTokenTTL
}
// proxyConnection represents a connected proxy // proxyConnection represents a connected proxy
type proxyConnection struct { type proxyConnection struct {
proxyID string proxyID string
sessionID string
address string address string
capabilities *proto.ProxyCapabilities capabilities *proto.ProxyCapabilities
stream proto.ProxyService_GetMappingUpdateServer stream proto.ProxyService_GetMappingUpdateServer
@@ -141,7 +108,6 @@ func NewProxyServiceServer(accessLogMgr accesslogs.Manager, tokenStore *OneTimeT
peersManager: peersManager, peersManager: peersManager,
usersManager: usersManager, usersManager: usersManager,
proxyManager: proxyMgr, proxyManager: proxyMgr,
snapshotBatchSize: snapshotBatchSizeFromEnv(),
cancel: cancel, cancel: cancel,
} }
go s.cleanupStaleProxies(ctx) go s.cleanupStaleProxies(ctx)
@@ -200,22 +166,9 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest
return status.Errorf(codes.InvalidArgument, "proxy address is invalid") return status.Errorf(codes.InvalidArgument, "proxy address is invalid")
} }
sessionID := uuid.NewString()
if old, loaded := s.connectedProxies.Load(proxyID); loaded {
oldConn := old.(*proxyConnection)
log.WithFields(log.Fields{
"proxy_id": proxyID,
"old_session_id": oldConn.sessionID,
"new_session_id": sessionID,
}).Info("Superseding existing proxy connection")
oldConn.cancel()
}
connCtx, cancel := context.WithCancel(ctx) connCtx, cancel := context.WithCancel(ctx)
conn := &proxyConnection{ conn := &proxyConnection{
proxyID: proxyID, proxyID: proxyID,
sessionID: sessionID,
address: proxyAddress, address: proxyAddress,
capabilities: req.GetCapabilities(), capabilities: req.GetCapabilities(),
stream: stream, stream: stream,
@@ -224,6 +177,11 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest
cancel: cancel, cancel: cancel,
} }
s.connectedProxies.Store(proxyID, conn)
if err := s.proxyController.RegisterProxyToCluster(ctx, conn.address, proxyID); err != nil {
log.WithContext(ctx).Warnf("Failed to register proxy %s in cluster: %v", proxyID, err)
}
// Register proxy in database with capabilities // Register proxy in database with capabilities
var caps *proxy.Capabilities var caps *proxy.Capabilities
if c := req.GetCapabilities(); c != nil { if c := req.GetCapabilities(); c != nil {
@@ -233,84 +191,65 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest
SupportsCrowdsec: c.SupportsCrowdsec, SupportsCrowdsec: c.SupportsCrowdsec,
} }
} }
proxyRecord, err := s.proxyManager.Connect(ctx, proxyID, sessionID, proxyAddress, peerInfo, caps) if err := s.proxyManager.Connect(ctx, proxyID, proxyAddress, peerInfo, caps); err != nil {
if err != nil {
log.WithContext(ctx).Warnf("failed to register proxy %s in database: %v", proxyID, err) log.WithContext(ctx).Warnf("failed to register proxy %s in database: %v", proxyID, err)
cancel() s.connectedProxies.Delete(proxyID)
if unregErr := s.proxyController.UnregisterProxyFromCluster(ctx, conn.address, proxyID); unregErr != nil {
log.WithContext(ctx).Debugf("cleanup after Connect failure for proxy %s: %v", proxyID, unregErr)
}
return status.Errorf(codes.Internal, "register proxy in database: %v", err) return status.Errorf(codes.Internal, "register proxy in database: %v", err)
} }
s.connectedProxies.Store(proxyID, conn) log.WithFields(log.Fields{
if err := s.proxyController.RegisterProxyToCluster(ctx, conn.address, proxyID); err != nil { "proxy_id": proxyID,
log.WithContext(ctx).Warnf("Failed to register proxy %s in cluster: %v", proxyID, err) "address": proxyAddress,
"cluster_addr": proxyAddress,
"total_proxies": len(s.GetConnectedProxies()),
}).Info("Proxy registered in cluster")
defer func() {
if err := s.proxyManager.Disconnect(context.Background(), proxyID); err != nil {
log.Warnf("Failed to mark proxy %s as disconnected: %v", proxyID, err)
} }
if err := s.sendSnapshot(ctx, conn); err != nil { s.connectedProxies.Delete(proxyID)
if s.connectedProxies.CompareAndDelete(proxyID, conn) { if err := s.proxyController.UnregisterProxyFromCluster(context.Background(), conn.address, proxyID); err != nil {
if unregErr := s.proxyController.UnregisterProxyFromCluster(context.Background(), conn.address, proxyID); unregErr != nil { log.Warnf("Failed to unregister proxy %s from cluster: %v", proxyID, err)
log.WithContext(ctx).Debugf("cleanup after snapshot failure for proxy %s: %v", proxyID, unregErr)
}
} }
cancel() cancel()
if disconnErr := s.proxyManager.Disconnect(context.Background(), proxyID, sessionID); disconnErr != nil { log.Infof("Proxy %s disconnected", proxyID)
log.WithContext(ctx).Debugf("cleanup after snapshot failure for proxy %s: %v", proxyID, disconnErr) }()
}
if err := s.sendSnapshot(ctx, conn); err != nil {
return fmt.Errorf("send snapshot to proxy %s: %w", proxyID, err) return fmt.Errorf("send snapshot to proxy %s: %w", proxyID, err)
} }
errChan := make(chan error, 2) errChan := make(chan error, 2)
go s.sender(conn, errChan) go s.sender(conn, errChan)
log.WithFields(log.Fields{ // Start heartbeat goroutine
"proxy_id": proxyID, go s.heartbeat(connCtx, proxyID, proxyAddress, peerInfo)
"session_id": sessionID,
"address": proxyAddress,
"cluster_addr": proxyAddress,
"total_proxies": len(s.GetConnectedProxies()),
}).Info("Proxy registered in cluster")
defer func() {
if !s.connectedProxies.CompareAndDelete(proxyID, conn) {
log.Infof("Proxy %s session %s: skipping cleanup, superseded by new connection", proxyID, sessionID)
cancel()
return
}
if err := s.proxyController.UnregisterProxyFromCluster(context.Background(), conn.address, proxyID); err != nil {
log.Warnf("Failed to unregister proxy %s from cluster: %v", proxyID, err)
}
if err := s.proxyManager.Disconnect(context.Background(), proxyID, sessionID); err != nil {
log.Warnf("Failed to mark proxy %s as disconnected: %v", proxyID, err)
}
cancel()
log.Infof("Proxy %s session %s disconnected", proxyID, sessionID)
}()
go s.heartbeat(connCtx, proxyRecord)
select { select {
case err := <-errChan: case err := <-errChan:
log.WithContext(ctx).Warnf("Failed to send update: %v", err)
return fmt.Errorf("send update to proxy %s: %w", proxyID, err) return fmt.Errorf("send update to proxy %s: %w", proxyID, err)
case <-connCtx.Done(): case <-connCtx.Done():
log.WithContext(ctx).Infof("Proxy %s context canceled", proxyID)
return connCtx.Err() return connCtx.Err()
} }
} }
// heartbeat updates the proxy's last_seen timestamp every minute // heartbeat updates the proxy's last_seen timestamp every minute
func (s *ProxyServiceServer) heartbeat(ctx context.Context, p *proxy.Proxy) { func (s *ProxyServiceServer) heartbeat(ctx context.Context, proxyID, clusterAddress, ipAddress string) {
ticker := time.NewTicker(1 * time.Minute) ticker := time.NewTicker(1 * time.Minute)
defer ticker.Stop() defer ticker.Stop()
for { for {
select { select {
case <-ticker.C: case <-ticker.C:
if err := s.proxyManager.Heartbeat(ctx, p); err != nil { if err := s.proxyManager.Heartbeat(ctx, proxyID, clusterAddress, ipAddress); err != nil {
log.WithContext(ctx).Debugf("Failed to update proxy %s heartbeat: %v", p.ID, err) log.WithContext(ctx).Debugf("Failed to update proxy %s heartbeat: %v", proxyID, err)
} }
case <-ctx.Done(): case <-ctx.Done():
log.WithContext(ctx).Infof("proxy %s heartbeat stopped: context canceled", p.ID)
return return
} }
} }
@@ -328,27 +267,22 @@ func (s *ProxyServiceServer) sendSnapshot(ctx context.Context, conn *proxyConnec
return err return err
} }
// Send mappings in batches to reduce per-message gRPC overhead while
// staying well within the default 4 MB message size limit.
for i := 0; i < len(mappings); i += s.snapshotBatchSize {
end := i + s.snapshotBatchSize
if end > len(mappings) {
end = len(mappings)
}
if err := conn.stream.Send(&proto.GetMappingUpdateResponse{
Mapping: mappings[i:end],
InitialSyncComplete: end == len(mappings),
}); err != nil {
return fmt.Errorf("send snapshot batch: %w", err)
}
}
if len(mappings) == 0 { if len(mappings) == 0 {
if err := conn.stream.Send(&proto.GetMappingUpdateResponse{ if err := conn.stream.Send(&proto.GetMappingUpdateResponse{
InitialSyncComplete: true, InitialSyncComplete: true,
}); err != nil { }); err != nil {
return fmt.Errorf("send snapshot completion: %w", err) return fmt.Errorf("send snapshot completion: %w", err)
} }
return nil
}
for i, m := range mappings {
if err := conn.stream.Send(&proto.GetMappingUpdateResponse{
Mapping: []*proto.ProxyMapping{m},
InitialSyncComplete: i == len(mappings)-1,
}); err != nil {
return fmt.Errorf("send proxy mapping: %w", err)
}
} }
return nil return nil
@@ -366,9 +300,13 @@ func (s *ProxyServiceServer) snapshotServiceMappings(ctx context.Context, conn *
continue continue
} }
token, err := s.tokenStore.GenerateToken(service.AccountID, service.ID, s.proxyTokenTTL()) token, err := s.tokenStore.GenerateToken(service.AccountID, service.ID, 5*time.Minute)
if err != nil { if err != nil {
return nil, fmt.Errorf("generate auth token for service %s: %w", service.ID, err) log.WithFields(log.Fields{
"service": service.Name,
"account": service.AccountID,
}).WithError(err).Error("failed to generate auth token for snapshot")
continue
} }
m := service.ToProtoMapping(rpservice.Create, token, s.GetOIDCValidationConfig()) m := service.ToProtoMapping(rpservice.Create, token, s.GetOIDCValidationConfig())
@@ -448,16 +386,13 @@ func (s *ProxyServiceServer) SendServiceUpdate(update *proto.GetMappingUpdateRes
conn := value.(*proxyConnection) conn := value.(*proxyConnection)
resp := s.perProxyMessage(update, conn.proxyID) resp := s.perProxyMessage(update, conn.proxyID)
if resp == nil { if resp == nil {
log.Warnf("Token generation failed for proxy %s, disconnecting to force resync", conn.proxyID)
conn.cancel()
return true return true
} }
select { select {
case conn.sendChan <- resp: case conn.sendChan <- resp:
log.Debugf("Sent service update to proxy server %s", conn.proxyID) log.Debugf("Sent service update to proxy server %s", conn.proxyID)
default: default:
log.Warnf("Send channel full for proxy %s, disconnecting to force resync", conn.proxyID) log.Warnf("Failed to send service update to proxy server %s (channel full)", conn.proxyID)
conn.cancel()
} }
return true return true
}) })
@@ -537,16 +472,13 @@ func (s *ProxyServiceServer) SendServiceUpdateToCluster(ctx context.Context, upd
} }
msg := s.perProxyMessage(updateResponse, proxyID) msg := s.perProxyMessage(updateResponse, proxyID)
if msg == nil { if msg == nil {
log.WithContext(ctx).Warnf("Token generation failed for proxy %s in cluster %s, disconnecting to force resync", proxyID, clusterAddr)
conn.cancel()
continue continue
} }
select { select {
case conn.sendChan <- msg: case conn.sendChan <- msg:
log.WithContext(ctx).Debugf("Sent service update with id %s to proxy %s in cluster %s", update.Id, proxyID, clusterAddr) log.WithContext(ctx).Debugf("Sent service update with id %s to proxy %s in cluster %s", update.Id, proxyID, clusterAddr)
default: default:
log.WithContext(ctx).Warnf("Send channel full for proxy %s in cluster %s, disconnecting to force resync", proxyID, clusterAddr) log.WithContext(ctx).Warnf("Failed to send service update to proxy %s in cluster %s (channel full)", proxyID, clusterAddr)
conn.cancel()
} }
} }
} }
@@ -572,8 +504,7 @@ func proxyAcceptsMapping(conn *proxyConnection, mapping *proto.ProxyMapping) boo
// perProxyMessage returns a copy of update with a fresh one-time token for // perProxyMessage returns a copy of update with a fresh one-time token for
// create/update operations. For delete operations the original mapping is // create/update operations. For delete operations the original mapping is
// used unchanged because proxies do not need to authenticate for removal. // used unchanged because proxies do not need to authenticate for removal.
// Returns nil if token generation fails; the caller must disconnect the // Returns nil if token generation fails (the proxy should be skipped).
// proxy so it can resync via a fresh snapshot on reconnect.
func (s *ProxyServiceServer) perProxyMessage(update *proto.GetMappingUpdateResponse, proxyID string) *proto.GetMappingUpdateResponse { func (s *ProxyServiceServer) perProxyMessage(update *proto.GetMappingUpdateResponse, proxyID string) *proto.GetMappingUpdateResponse {
resp := make([]*proto.ProxyMapping, 0, len(update.Mapping)) resp := make([]*proto.ProxyMapping, 0, len(update.Mapping))
for _, mapping := range update.Mapping { for _, mapping := range update.Mapping {
@@ -582,7 +513,7 @@ func (s *ProxyServiceServer) perProxyMessage(update *proto.GetMappingUpdateRespo
continue continue
} }
token, err := s.tokenStore.GenerateToken(mapping.AccountId, mapping.Id, s.proxyTokenTTL()) token, err := s.tokenStore.GenerateToken(mapping.AccountId, mapping.Id, 5*time.Minute)
if err != nil { if err != nil {
log.Warnf("Failed to generate token for proxy %s: %v", proxyID, err) log.Warnf("Failed to generate token for proxy %s: %v", proxyID, err)
return nil return nil

View File

@@ -1,174 +0,0 @@
package grpc
import (
"context"
"fmt"
"testing"
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/grpc"
"google.golang.org/grpc/metadata"
rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
"github.com/netbirdio/netbird/shared/management/proto"
)
// recordingStream captures all messages sent via Send so tests can inspect
// batching behaviour without a real gRPC transport.
type recordingStream struct {
grpc.ServerStream
messages []*proto.GetMappingUpdateResponse
}
func (s *recordingStream) Send(m *proto.GetMappingUpdateResponse) error {
s.messages = append(s.messages, m)
return nil
}
func (s *recordingStream) Context() context.Context { return context.Background() }
func (s *recordingStream) SetHeader(metadata.MD) error { return nil }
func (s *recordingStream) SendHeader(metadata.MD) error { return nil }
func (s *recordingStream) SetTrailer(metadata.MD) {}
func (s *recordingStream) SendMsg(any) error { return nil }
func (s *recordingStream) RecvMsg(any) error { return nil }
// makeServices creates n enabled services assigned to the given cluster.
func makeServices(n int, cluster string) []*rpservice.Service {
services := make([]*rpservice.Service, n)
for i := range n {
services[i] = &rpservice.Service{
ID: fmt.Sprintf("svc-%d", i),
AccountID: "acct-1",
Name: fmt.Sprintf("svc-%d", i),
Domain: fmt.Sprintf("svc-%d.example.com", i),
ProxyCluster: cluster,
Enabled: true,
Targets: []*rpservice.Target{
{TargetType: rpservice.TargetTypeHost, TargetId: "host-1"},
},
}
}
return services
}
func newSnapshotTestServer(t *testing.T, batchSize int) *ProxyServiceServer {
t.Helper()
s := &ProxyServiceServer{
tokenStore: NewOneTimeTokenStore(context.Background(), testCacheStore(t)),
snapshotBatchSize: batchSize,
}
s.SetProxyController(newTestProxyController())
return s
}
func TestSendSnapshot_BatchesMappings(t *testing.T) {
const cluster = "cluster.example.com"
const batchSize = 3
const totalServices = 7 // 3 + 3 + 1
ctrl := gomock.NewController(t)
mgr := rpservice.NewMockManager(ctrl)
mgr.EXPECT().GetGlobalServices(gomock.Any()).Return(makeServices(totalServices, cluster), nil)
s := newSnapshotTestServer(t, batchSize)
s.serviceManager = mgr
stream := &recordingStream{}
conn := &proxyConnection{
proxyID: "proxy-a",
address: cluster,
stream: stream,
}
err := s.sendSnapshot(context.Background(), conn)
require.NoError(t, err)
// Expect ceil(7/3) = 3 messages
require.Len(t, stream.messages, 3, "should send ceil(totalServices/batchSize) messages")
assert.Len(t, stream.messages[0].Mapping, 3)
assert.False(t, stream.messages[0].InitialSyncComplete, "first batch should not be sync-complete")
assert.Len(t, stream.messages[1].Mapping, 3)
assert.False(t, stream.messages[1].InitialSyncComplete, "middle batch should not be sync-complete")
assert.Len(t, stream.messages[2].Mapping, 1)
assert.True(t, stream.messages[2].InitialSyncComplete, "last batch must be sync-complete")
// Verify all service IDs are present exactly once
seen := make(map[string]bool)
for _, msg := range stream.messages {
for _, m := range msg.Mapping {
assert.False(t, seen[m.Id], "duplicate service ID %s", m.Id)
seen[m.Id] = true
}
}
assert.Len(t, seen, totalServices)
}
func TestSendSnapshot_ExactBatchMultiple(t *testing.T) {
const cluster = "cluster.example.com"
const batchSize = 3
const totalServices = 6 // exactly 2 batches
ctrl := gomock.NewController(t)
mgr := rpservice.NewMockManager(ctrl)
mgr.EXPECT().GetGlobalServices(gomock.Any()).Return(makeServices(totalServices, cluster), nil)
s := newSnapshotTestServer(t, batchSize)
s.serviceManager = mgr
stream := &recordingStream{}
conn := &proxyConnection{proxyID: "proxy-a", address: cluster, stream: stream}
require.NoError(t, s.sendSnapshot(context.Background(), conn))
require.Len(t, stream.messages, 2)
assert.Len(t, stream.messages[0].Mapping, 3)
assert.False(t, stream.messages[0].InitialSyncComplete)
assert.Len(t, stream.messages[1].Mapping, 3)
assert.True(t, stream.messages[1].InitialSyncComplete)
}
func TestSendSnapshot_SingleBatch(t *testing.T) {
const cluster = "cluster.example.com"
const batchSize = 100
const totalServices = 5
ctrl := gomock.NewController(t)
mgr := rpservice.NewMockManager(ctrl)
mgr.EXPECT().GetGlobalServices(gomock.Any()).Return(makeServices(totalServices, cluster), nil)
s := newSnapshotTestServer(t, batchSize)
s.serviceManager = mgr
stream := &recordingStream{}
conn := &proxyConnection{proxyID: "proxy-a", address: cluster, stream: stream}
require.NoError(t, s.sendSnapshot(context.Background(), conn))
require.Len(t, stream.messages, 1, "all mappings should fit in one batch")
assert.Len(t, stream.messages[0].Mapping, totalServices)
assert.True(t, stream.messages[0].InitialSyncComplete)
}
func TestSendSnapshot_EmptySnapshot(t *testing.T) {
const cluster = "cluster.example.com"
ctrl := gomock.NewController(t)
mgr := rpservice.NewMockManager(ctrl)
mgr.EXPECT().GetGlobalServices(gomock.Any()).Return(nil, nil)
s := newSnapshotTestServer(t, 500)
s.serviceManager = mgr
stream := &recordingStream{}
conn := &proxyConnection{proxyID: "proxy-a", address: cluster, stream: stream}
require.NoError(t, s.sendSnapshot(context.Background(), conn))
require.Len(t, stream.messages, 1, "empty snapshot must still send sync-complete")
assert.Empty(t, stream.messages[0].Mapping)
assert.True(t, stream.messages[0].InitialSyncComplete)
}

View File

@@ -85,14 +85,11 @@ func registerFakeProxy(s *ProxyServiceServer, proxyID, clusterAddr string) chan
// registerFakeProxyWithCaps adds a fake proxy connection with explicit capabilities. // registerFakeProxyWithCaps adds a fake proxy connection with explicit capabilities.
func registerFakeProxyWithCaps(s *ProxyServiceServer, proxyID, clusterAddr string, caps *proto.ProxyCapabilities) chan *proto.GetMappingUpdateResponse { func registerFakeProxyWithCaps(s *ProxyServiceServer, proxyID, clusterAddr string, caps *proto.ProxyCapabilities) chan *proto.GetMappingUpdateResponse {
ch := make(chan *proto.GetMappingUpdateResponse, 10) ch := make(chan *proto.GetMappingUpdateResponse, 10)
ctx, cancel := context.WithCancel(context.Background())
conn := &proxyConnection{ conn := &proxyConnection{
proxyID: proxyID, proxyID: proxyID,
address: clusterAddr, address: clusterAddr,
capabilities: caps, capabilities: caps,
sendChan: ch, sendChan: ch,
ctx: ctx,
cancel: cancel,
} }
s.connectedProxies.Store(proxyID, conn) s.connectedProxies.Store(proxyID, conn)

View File

@@ -1040,13 +1040,6 @@ func (am *DefaultAccountManager) lookupCache(ctx context.Context, accountUsers m
for attempt := 1; attempt <= maxAttempts; attempt++ { for attempt := 1; attempt <= maxAttempts; attempt++ {
if am.isCacheFresh(ctx, accountUsers, data) { if am.isCacheFresh(ctx, accountUsers, data) {
// Catch the silent vacuous-fresh case: empty accountUsers map + empty cache data
// → isCacheFresh returns true without iterating, skipping refreshCache,
// returning empty data. This matters when the account is mostly integration
// (SCIM) users and the InternalCache has been flushed.
if len(accountUsers) == 0 && len(data) == 0 {
log.WithContext(ctx).Warnf("lookupCache VACUOUS FRESH: accountUsers map is empty AND cache is empty for account %s — returning empty data without triggering loadAccount", accountID)
}
return data, nil return data, nil
} }

View File

@@ -7,8 +7,11 @@ import (
"github.com/gorilla/mux" "github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/account"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/geolocation" "github.com/netbirdio/netbird/management/server/geolocation"
"github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util" "github.com/netbirdio/netbird/shared/management/http/util"
"github.com/netbirdio/netbird/shared/management/status" "github.com/netbirdio/netbird/shared/management/status"
@@ -42,6 +45,11 @@ func newGeolocationsHandlerHandler(accountManager account.Manager, geolocationMa
// getAllCountries retrieves a list of all countries // getAllCountries retrieves a list of all countries
func (l *geolocationsHandler) getAllCountries(w http.ResponseWriter, r *http.Request) { func (l *geolocationsHandler) getAllCountries(w http.ResponseWriter, r *http.Request) {
if err := l.authenticateUser(r); err != nil {
util.WriteError(r.Context(), err, w)
return
}
if l.geolocationManager == nil { if l.geolocationManager == nil {
// TODO: update error message to include geo db self hosted doc link when ready // TODO: update error message to include geo db self hosted doc link when ready
util.WriteError(r.Context(), status.Errorf(status.PreconditionFailed, "Geo location database is not initialized"), w) util.WriteError(r.Context(), status.Errorf(status.PreconditionFailed, "Geo location database is not initialized"), w)
@@ -63,6 +71,11 @@ func (l *geolocationsHandler) getAllCountries(w http.ResponseWriter, r *http.Req
// getCitiesByCountry retrieves a list of cities based on the given country code // getCitiesByCountry retrieves a list of cities based on the given country code
func (l *geolocationsHandler) getCitiesByCountry(w http.ResponseWriter, r *http.Request) { func (l *geolocationsHandler) getCitiesByCountry(w http.ResponseWriter, r *http.Request) {
if err := l.authenticateUser(r); err != nil {
util.WriteError(r.Context(), err, w)
return
}
vars := mux.Vars(r) vars := mux.Vars(r)
countryCode := vars["country"] countryCode := vars["country"]
if !countryCodeRegex.MatchString(countryCode) { if !countryCodeRegex.MatchString(countryCode) {
@@ -89,6 +102,27 @@ func (l *geolocationsHandler) getCitiesByCountry(w http.ResponseWriter, r *http.
util.WriteJSONObject(r.Context(), w, cities) util.WriteJSONObject(r.Context(), w, cities)
} }
func (l *geolocationsHandler) authenticateUser(r *http.Request) error {
ctx := r.Context()
userAuth, err := nbcontext.GetUserAuthFromContext(ctx)
if err != nil {
return err
}
accountID, userID := userAuth.AccountId, userAuth.UserId
allowed, err := l.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Policies, operations.Read)
if err != nil {
return status.NewPermissionValidationError(err)
}
if !allowed {
return status.NewPermissionDeniedError()
}
return nil
}
func toCountryResponse(country geolocation.Country) api.Country { func toCountryResponse(country geolocation.Country) api.Country {
return api.Country{ return api.Country{
CountryName: country.CountryName, CountryName: country.CountryName,

View File

@@ -5437,35 +5437,13 @@ func (s *SqlStore) SaveProxy(ctx context.Context, p *proxy.Proxy) error {
return nil return nil
} }
// DisconnectProxy marks a proxy as disconnected only if the session ID matches. // UpdateProxyHeartbeat updates the last_seen timestamp for a proxy or creates a new entry if it doesn't exist
// This prevents a slow-to-close old session from overwriting a newer reconnection. func (s *SqlStore) UpdateProxyHeartbeat(ctx context.Context, proxyID, clusterAddress, ipAddress string) error {
func (s *SqlStore) DisconnectProxy(ctx context.Context, proxyID, sessionID string) error {
now := time.Now()
result := s.db.
Model(&proxy.Proxy{}).
Where("id = ? AND session_id = ?", proxyID, sessionID).
Updates(map[string]any{
"status": "disconnected",
"disconnected_at": now,
"last_seen": now,
})
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to disconnect proxy %s session %s: %v", proxyID, sessionID, result.Error)
return status.Errorf(status.Internal, "failed to disconnect proxy")
}
if result.RowsAffected == 0 {
log.WithContext(ctx).Debugf("proxy %s session %s: no row updated (superseded by newer session)", proxyID, sessionID)
}
return nil
}
// UpdateProxyHeartbeat updates the last_seen timestamp for the proxy's current session.
func (s *SqlStore) UpdateProxyHeartbeat(ctx context.Context, p *proxy.Proxy) error {
now := time.Now() now := time.Now()
result := s.db. result := s.db.
Model(&proxy.Proxy{}). Model(&proxy.Proxy{}).
Where("id = ? AND session_id = ?", p.ID, p.SessionID). Where("id = ? AND status = ?", proxyID, "connected").
Update("last_seen", now) Update("last_seen", now)
if result.Error != nil { if result.Error != nil {
@@ -5474,11 +5452,17 @@ func (s *SqlStore) UpdateProxyHeartbeat(ctx context.Context, p *proxy.Proxy) err
} }
if result.RowsAffected == 0 { if result.RowsAffected == 0 {
p.LastSeen = now p := &proxy.Proxy{
p.ConnectedAt = &now ID: proxyID,
p.Status = "connected" ClusterAddress: clusterAddress,
if err := s.db.Create(p).Error; err != nil { IPAddress: ipAddress,
log.WithContext(ctx).Debugf("proxy %s session %s: heartbeat fallback insert skipped: %v", p.ID, p.SessionID, err) LastSeen: now,
ConnectedAt: &now,
Status: "connected",
}
if err := s.db.Save(p).Error; err != nil {
log.WithContext(ctx).Errorf("failed to create proxy on heartbeat: %v", err)
return status.Errorf(status.Internal, "failed to create proxy on heartbeat")
} }
} }

View File

@@ -284,8 +284,7 @@ type Store interface {
DeleteServiceTargets(ctx context.Context, accountID string, serviceID string) error DeleteServiceTargets(ctx context.Context, accountID string, serviceID string) error
SaveProxy(ctx context.Context, proxy *proxy.Proxy) error SaveProxy(ctx context.Context, proxy *proxy.Proxy) error
DisconnectProxy(ctx context.Context, proxyID, sessionID string) error UpdateProxyHeartbeat(ctx context.Context, proxyID, clusterAddress, ipAddress string) error
UpdateProxyHeartbeat(ctx context.Context, p *proxy.Proxy) error
GetActiveProxyClusterAddresses(ctx context.Context) ([]string, error) GetActiveProxyClusterAddresses(ctx context.Context) ([]string, error)
GetActiveProxyClusters(ctx context.Context) ([]proxy.Cluster, error) GetActiveProxyClusters(ctx context.Context) ([]proxy.Cluster, error)
GetClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool GetClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool

View File

@@ -178,7 +178,6 @@ func (mr *MockStoreMockRecorder) GetClusterSupportsCrowdSec(ctx, clusterAddr int
mr.mock.ctrl.T.Helper() mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetClusterSupportsCrowdSec", reflect.TypeOf((*MockStore)(nil).GetClusterSupportsCrowdSec), ctx, clusterAddr) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetClusterSupportsCrowdSec", reflect.TypeOf((*MockStore)(nil).GetClusterSupportsCrowdSec), ctx, clusterAddr)
} }
// Close mocks base method. // Close mocks base method.
func (m *MockStore) Close(ctx context.Context) error { func (m *MockStore) Close(ctx context.Context) error {
m.ctrl.T.Helper() m.ctrl.T.Helper()
@@ -2800,20 +2799,6 @@ func (mr *MockStoreMockRecorder) SaveProxy(ctx, proxy interface{}) *gomock.Call
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SaveProxy", reflect.TypeOf((*MockStore)(nil).SaveProxy), ctx, proxy) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SaveProxy", reflect.TypeOf((*MockStore)(nil).SaveProxy), ctx, proxy)
} }
// DisconnectProxy mocks base method.
func (m *MockStore) DisconnectProxy(ctx context.Context, proxyID, sessionID string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DisconnectProxy", ctx, proxyID, sessionID)
ret0, _ := ret[0].(error)
return ret0
}
// DisconnectProxy indicates an expected call of DisconnectProxy.
func (mr *MockStoreMockRecorder) DisconnectProxy(ctx, proxyID, sessionID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DisconnectProxy", reflect.TypeOf((*MockStore)(nil).DisconnectProxy), ctx, proxyID, sessionID)
}
// SaveProxyAccessToken mocks base method. // SaveProxyAccessToken mocks base method.
func (m *MockStore) SaveProxyAccessToken(ctx context.Context, token *types2.ProxyAccessToken) error { func (m *MockStore) SaveProxyAccessToken(ctx context.Context, token *types2.ProxyAccessToken) error {
m.ctrl.T.Helper() m.ctrl.T.Helper()
@@ -3010,17 +2995,17 @@ func (mr *MockStoreMockRecorder) UpdateGroups(ctx, accountID, groups interface{}
} }
// UpdateProxyHeartbeat mocks base method. // UpdateProxyHeartbeat mocks base method.
func (m *MockStore) UpdateProxyHeartbeat(ctx context.Context, p *proxy.Proxy) error { func (m *MockStore) UpdateProxyHeartbeat(ctx context.Context, proxyID, clusterAddress, ipAddress string) error {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "UpdateProxyHeartbeat", ctx, p) ret := m.ctrl.Call(m, "UpdateProxyHeartbeat", ctx, proxyID, clusterAddress, ipAddress)
ret0, _ := ret[0].(error) ret0, _ := ret[0].(error)
return ret0 return ret0
} }
// UpdateProxyHeartbeat indicates an expected call of UpdateProxyHeartbeat. // UpdateProxyHeartbeat indicates an expected call of UpdateProxyHeartbeat.
func (mr *MockStoreMockRecorder) UpdateProxyHeartbeat(ctx, p interface{}) *gomock.Call { func (mr *MockStoreMockRecorder) UpdateProxyHeartbeat(ctx, proxyID, clusterAddress, ipAddress interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper() mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateProxyHeartbeat", reflect.TypeOf((*MockStore)(nil).UpdateProxyHeartbeat), ctx, p) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateProxyHeartbeat", reflect.TypeOf((*MockStore)(nil).UpdateProxyHeartbeat), ctx, proxyID, clusterAddress, ipAddress)
} }
// UpdateService mocks base method. // UpdateService mocks base method.

View File

@@ -144,11 +144,8 @@ func TestValidateInviteToken_ModifiedToken(t *testing.T) {
_, plainToken, err := GenerateInviteToken() _, plainToken, err := GenerateInviteToken()
require.NoError(t, err) require.NoError(t, err)
replacement := "X" // Modify one character in the secret part
if plainToken[5] == 'X' { modifiedToken := plainToken[:5] + "X" + plainToken[6:]
replacement = "Y"
}
modifiedToken := plainToken[:5] + replacement + plainToken[6:]
err = ValidateInviteToken(modifiedToken) err = ValidateInviteToken(modifiedToken)
require.Error(t, err) require.Error(t, err)
} }

View File

@@ -1061,28 +1061,15 @@ func (am *DefaultAccountManager) BuildUserInfosForAccount(ctx context.Context, a
queriedUsers = append(queriedUsers, usersFromIntegration...) queriedUsers = append(queriedUsers, usersFromIntegration...)
} }
idpManagerNil := isNil(am.idpManager)
idpManagerEmbedded := !idpManagerNil && IsEmbeddedIdp(am.idpManager)
userInfosMap := make(map[string]*types.UserInfo) userInfosMap := make(map[string]*types.UserInfo)
// in case of self-hosted, or IDP doesn't return anything, we will return the locally stored userInfo // in case of self-hosted, or IDP doesn't return anything, we will return the locally stored userInfo
if len(queriedUsers) == 0 { if len(queriedUsers) == 0 {
var earlyReturnEmpty int
var earlyReturnEmptySamples []string
for _, accountUser := range accountUsers { for _, accountUser := range accountUsers {
info, err := accountUser.ToUserInfo(nil) info, err := accountUser.ToUserInfo(nil)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if !accountUser.IsServiceUser && (info.Email == "" || info.Name == "") {
earlyReturnEmpty++
if len(earlyReturnEmptySamples) < 50 {
earlyReturnEmptySamples = append(earlyReturnEmptySamples,
fmt.Sprintf("%s(issued=%s,db.email=%q,db.name=%q)",
accountUser.Id, accountUser.Issued, accountUser.Email, accountUser.Name))
}
}
// Try to decode Dex user ID to extract the IdP ID (connector ID) // Try to decode Dex user ID to extract the IdP ID (connector ID)
if _, connectorID, decodeErr := dex.DecodeDexUserID(accountUser.Id); decodeErr == nil && connectorID != "" { if _, connectorID, decodeErr := dex.DecodeDexUserID(accountUser.Id); decodeErr == nil && connectorID != "" {
info.IdPID = connectorID info.IdPID = connectorID
@@ -1090,61 +1077,17 @@ func (am *DefaultAccountManager) BuildUserInfosForAccount(ctx context.Context, a
userInfosMap[accountUser.Id] = info userInfosMap[accountUser.Id] = info
} }
log.WithContext(ctx).Warnf("BuildUserInfosForAccount EARLY RETURN: queriedUsers empty, returning %d users with DB-only data (idpManagerNil=%v, idpManagerEmbedded=%v). %d non-service users have empty email/name in DB. Samples: %v",
len(accountUsers), idpManagerNil, idpManagerEmbedded, earlyReturnEmpty, earlyReturnEmptySamples)
// Same canonical-truth final scan, also on the early-return path
var finalEmptyEmail, finalEmptyName int
var finalEmptySamples []string
for id, info := range userInfosMap {
if info.IsServiceUser {
continue
}
if info.Email == "" {
finalEmptyEmail++
}
if info.Name == "" {
finalEmptyName++
}
if (info.Email == "" || info.Name == "") && len(finalEmptySamples) < 200 {
finalEmptySamples = append(finalEmptySamples,
fmt.Sprintf("%s(email=%q,name=%q,issued=%s)", id, info.Email, info.Name, info.Issued))
}
}
if finalEmptyEmail > 0 || finalEmptyName > 0 {
log.WithContext(ctx).Warnf("BuildUserInfosForAccount FINAL (early-return path): returning %d UserInfo entries — %d with empty email, %d with empty name. Samples: %v",
len(userInfosMap), finalEmptyEmail, finalEmptyName, finalEmptySamples)
}
return userInfosMap, nil return userInfosMap, nil
} }
var cacheHitEmpty, fallbackMiss int
var cacheHitEmptySamples, fallbackMissSamples []string
for _, localUser := range accountUsers { for _, localUser := range accountUsers {
var info *types.UserInfo var info *types.UserInfo
if queriedUser, contains := findUserInIDPUserdata(localUser.Id, queriedUsers); contains { if queriedUser, contains := findUserInIDPUserdata(localUser.Id, queriedUsers); contains {
if !localUser.IsServiceUser && (queriedUser.Email == "" || queriedUser.Name == "") {
cacheHitEmpty++
if len(cacheHitEmptySamples) < 50 {
cacheHitEmptySamples = append(cacheHitEmptySamples,
fmt.Sprintf("%s(cache.email=%q,cache.name=%q,db.email=%q,db.name=%q)",
localUser.Id, queriedUser.Email, queriedUser.Name, localUser.Email, localUser.Name))
}
}
info, err = localUser.ToUserInfo(queriedUser) info, err = localUser.ToUserInfo(queriedUser)
if err != nil { if err != nil {
return nil, err return nil, err
} }
} else { } else {
if !localUser.IsServiceUser {
fallbackMiss++
if len(fallbackMissSamples) < 50 {
fallbackMissSamples = append(fallbackMissSamples,
fmt.Sprintf("%s(issued=%s,db.email=%q,db.name=%q)",
localUser.Id, localUser.Issued, localUser.Email, localUser.Name))
}
}
name := "" name := ""
if localUser.IsServiceUser { if localUser.IsServiceUser {
name = localUser.ServiceUserName name = localUser.ServiceUserName
@@ -1168,40 +1111,6 @@ func (am *DefaultAccountManager) BuildUserInfosForAccount(ctx context.Context, a
userInfosMap[info.ID] = info userInfosMap[info.ID] = info
} }
if cacheHitEmpty > 0 {
log.WithContext(ctx).Warnf("BuildUserInfosForAccount: %d users found in cache with empty email or name (cache pollution). Samples: %v",
cacheHitEmpty, cacheHitEmptySamples)
}
if fallbackMiss > 0 {
log.WithContext(ctx).Warnf("BuildUserInfosForAccount: %d non-service users missed both caches (will get empty Name in API response from fallback). Samples: %v",
fallbackMiss, fallbackMissSamples)
}
// Canonical-truth log: scan what we are actually about to return to the handler.
// This catches empties from any code path (early-return, cache-hit-empty, fallback-miss,
// or anything we haven't identified yet).
var finalEmptyEmail, finalEmptyName int
var finalEmptySamples []string
for id, info := range userInfosMap {
if info.IsServiceUser {
continue
}
if info.Email == "" {
finalEmptyEmail++
}
if info.Name == "" {
finalEmptyName++
}
if (info.Email == "" || info.Name == "") && len(finalEmptySamples) < 200 {
finalEmptySamples = append(finalEmptySamples,
fmt.Sprintf("%s(email=%q,name=%q,issued=%s)", id, info.Email, info.Name, info.Issued))
}
}
if finalEmptyEmail > 0 || finalEmptyName > 0 {
log.WithContext(ctx).Warnf("BuildUserInfosForAccount FINAL: returning %d UserInfo entries — %d with empty email, %d with empty name. Samples: %v",
len(userInfosMap), finalEmptyEmail, finalEmptyName, finalEmptySamples)
}
return userInfosMap, nil return userInfosMap, nil
} }

View File

@@ -201,15 +201,15 @@ func (m *testAccessLogManager) GetAllAccessLogs(_ context.Context, _, _ string,
// testProxyManager is a mock implementation of proxy.Manager for testing. // testProxyManager is a mock implementation of proxy.Manager for testing.
type testProxyManager struct{} type testProxyManager struct{}
func (m *testProxyManager) Connect(_ context.Context, proxyID, sessionID, _, _ string, _ *nbproxy.Capabilities) (*nbproxy.Proxy, error) { func (m *testProxyManager) Connect(_ context.Context, _, _, _ string, _ *nbproxy.Capabilities) error {
return &nbproxy.Proxy{ID: proxyID, SessionID: sessionID, Status: "connected"}, nil
}
func (m *testProxyManager) Disconnect(_ context.Context, _, _ string) error {
return nil return nil
} }
func (m *testProxyManager) Heartbeat(_ context.Context, _ *nbproxy.Proxy) error { func (m *testProxyManager) Disconnect(_ context.Context, _ string) error {
return nil
}
func (m *testProxyManager) Heartbeat(_ context.Context, _, _, _ string) error {
return nil return nil
} }
@@ -364,16 +364,14 @@ func TestIntegration_ProxyConnection_HappyPath(t *testing.T) {
}) })
require.NoError(t, err) require.NoError(t, err)
// Receive all mappings from the snapshot - server sends each mapping individually
mappingsByID := make(map[string]*proto.ProxyMapping) mappingsByID := make(map[string]*proto.ProxyMapping)
for { for i := 0; i < 2; i++ {
msg, err := stream.Recv() msg, err := stream.Recv()
require.NoError(t, err) require.NoError(t, err)
for _, m := range msg.GetMapping() { for _, m := range msg.GetMapping() {
mappingsByID[m.GetId()] = m mappingsByID[m.GetId()] = m
} }
if msg.GetInitialSyncComplete() {
break
}
} }
// Should receive 2 mappings total // Should receive 2 mappings total
@@ -413,14 +411,12 @@ func TestIntegration_ProxyConnection_SendsClusterAddress(t *testing.T) {
}) })
require.NoError(t, err) require.NoError(t, err)
// Receive all mappings - server sends each mapping individually
mappings := make([]*proto.ProxyMapping, 0) mappings := make([]*proto.ProxyMapping, 0)
for { for i := 0; i < 2; i++ {
msg, err := stream.Recv() msg, err := stream.Recv()
require.NoError(t, err) require.NoError(t, err)
mappings = append(mappings, msg.GetMapping()...) mappings = append(mappings, msg.GetMapping()...)
if msg.GetInitialSyncComplete() {
break
}
} }
// Should receive the 2 mappings matching the cluster // Should receive the 2 mappings matching the cluster
@@ -444,15 +440,13 @@ func TestIntegration_ProxyConnection_Reconnect_ReceivesSameConfig(t *testing.T)
clusterAddress := "test.proxy.io" clusterAddress := "test.proxy.io"
proxyID := "test-proxy-reconnect" proxyID := "test-proxy-reconnect"
receiveMappings := func(stream proto.ProxyService_GetMappingUpdateClient) []*proto.ProxyMapping { // Helper to receive all mappings from a stream
receiveMappings := func(stream proto.ProxyService_GetMappingUpdateClient, count int) []*proto.ProxyMapping {
var mappings []*proto.ProxyMapping var mappings []*proto.ProxyMapping
for { for i := 0; i < count; i++ {
msg, err := stream.Recv() msg, err := stream.Recv()
require.NoError(t, err) require.NoError(t, err)
mappings = append(mappings, msg.GetMapping()...) mappings = append(mappings, msg.GetMapping()...)
if msg.GetInitialSyncComplete() {
break
}
} }
return mappings return mappings
} }
@@ -466,7 +460,7 @@ func TestIntegration_ProxyConnection_Reconnect_ReceivesSameConfig(t *testing.T)
}) })
require.NoError(t, err) require.NoError(t, err)
firstMappings := receiveMappings(stream1) firstMappings := receiveMappings(stream1, 2)
cancel1() cancel1()
time.Sleep(100 * time.Millisecond) time.Sleep(100 * time.Millisecond)
@@ -482,7 +476,7 @@ func TestIntegration_ProxyConnection_Reconnect_ReceivesSameConfig(t *testing.T)
}) })
require.NoError(t, err) require.NoError(t, err)
secondMappings := receiveMappings(stream2) secondMappings := receiveMappings(stream2, 2)
// Should receive the same mappings // Should receive the same mappings
assert.Equal(t, len(firstMappings), len(secondMappings), assert.Equal(t, len(firstMappings), len(secondMappings),
@@ -548,14 +542,12 @@ func TestIntegration_ProxyConnection_ReconnectDoesNotDuplicateState(t *testing.T
} }
} }
// Helper to receive and apply all mappings
receiveAndApply := func(stream proto.ProxyService_GetMappingUpdateClient) { receiveAndApply := func(stream proto.ProxyService_GetMappingUpdateClient) {
for { for i := 0; i < 2; i++ {
msg, err := stream.Recv() msg, err := stream.Recv()
require.NoError(t, err) require.NoError(t, err)
applyMappings(msg.GetMapping()) applyMappings(msg.GetMapping())
if msg.GetInitialSyncComplete() {
break
}
} }
} }
@@ -644,14 +636,12 @@ func TestIntegration_ProxyConnection_MultipleProxiesReceiveUpdates(t *testing.T)
}) })
require.NoError(t, err) require.NoError(t, err)
// Receive all mappings - server sends each mapping individually
count := 0 count := 0
for { for i := 0; i < 2; i++ {
msg, err := stream.Recv() msg, err := stream.Recv()
require.NoError(t, err) require.NoError(t, err)
count += len(msg.GetMapping()) count += len(msg.GetMapping())
if msg.GetInitialSyncComplete() {
break
}
} }
mu.Lock() mu.Lock()
@@ -666,78 +656,3 @@ func TestIntegration_ProxyConnection_MultipleProxiesReceiveUpdates(t *testing.T)
assert.Equal(t, 2, count, "Proxy %s should receive 2 mappings", proxyID) assert.Equal(t, 2, count, "Proxy %s should receive 2 mappings", proxyID)
} }
} }
// TestIntegration_ProxyConnection_FastReconnectDoesNotLoseState verifies that
// when a proxy reconnects before the old stream's cleanup runs, the new
// connection is NOT removed by the stale defer.
func TestIntegration_ProxyConnection_FastReconnectDoesNotLoseState(t *testing.T) {
setup := setupIntegrationTest(t)
defer setup.cleanup()
clusterAddress := "test.proxy.io"
proxyID := "test-proxy-race"
conn, err := grpc.NewClient(setup.grpcAddr, grpc.WithTransportCredentials(insecure.NewCredentials()))
require.NoError(t, err)
defer conn.Close()
client := proto.NewProxyServiceClient(conn)
ctx1, cancel1 := context.WithCancel(context.Background())
stream1, err := client.GetMappingUpdate(ctx1, &proto.GetMappingUpdateRequest{
ProxyId: proxyID,
Version: "test-v1",
Address: clusterAddress,
})
require.NoError(t, err)
for {
msg, err := stream1.Recv()
require.NoError(t, err)
if msg.GetInitialSyncComplete() {
break
}
}
require.Contains(t, setup.proxyService.GetConnectedProxies(), proxyID,
"proxy should be registered after first connection")
ctx2, cancel2 := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel2()
stream2, err := client.GetMappingUpdate(ctx2, &proto.GetMappingUpdateRequest{
ProxyId: proxyID,
Version: "test-v1",
Address: clusterAddress,
})
require.NoError(t, err)
for {
msg, err := stream2.Recv()
require.NoError(t, err)
if msg.GetInitialSyncComplete() {
break
}
}
cancel1()
time.Sleep(200 * time.Millisecond)
assert.Contains(t, setup.proxyService.GetConnectedProxies(), proxyID,
"proxy should still be registered after old connection cleanup — old defer must not remove new connection")
setup.proxyService.SendServiceUpdate(&proto.GetMappingUpdateResponse{
Mapping: []*proto.ProxyMapping{{
Type: proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED,
Id: "rp-1",
AccountId: "test-account-1",
Domain: "app1.test.proxy.io",
}},
})
msg, err := stream2.Recv()
require.NoError(t, err, "new stream should still receive updates")
require.NotEmpty(t, msg.GetMapping(), "update should contain the mapping")
assert.Equal(t, "rp-1", msg.GetMapping()[0].GetId())
}

View File

@@ -943,8 +943,6 @@ func (s *Server) newManagementMappingWorker(ctx context.Context, client proto.Pr
operation := func() error { operation := func() error {
s.Logger.Debug("connecting to management mapping stream") s.Logger.Debug("connecting to management mapping stream")
initialSyncDone = false
if s.healthChecker != nil { if s.healthChecker != nil {
s.healthChecker.SetManagementConnected(false) s.healthChecker.SetManagementConnected(false)
} }
@@ -1002,11 +1000,6 @@ func (s *Server) handleMappingStream(ctx context.Context, mappingClient proto.Pr
return ctx.Err() return ctx.Err()
} }
var snapshotIDs map[types.ServiceID]struct{}
if !*initialSyncDone {
snapshotIDs = make(map[types.ServiceID]struct{})
}
for { for {
// Check for context completion to gracefully shutdown. // Check for context completion to gracefully shutdown.
select { select {
@@ -1027,13 +1020,7 @@ func (s *Server) handleMappingStream(ctx context.Context, mappingClient proto.Pr
s.processMappings(ctx, msg.GetMapping()) s.processMappings(ctx, msg.GetMapping())
s.Logger.Debug("Processing mapping update completed") s.Logger.Debug("Processing mapping update completed")
if !*initialSyncDone { if !*initialSyncDone && msg.GetInitialSyncComplete() {
for _, m := range msg.GetMapping() {
snapshotIDs[types.ServiceID(m.GetId())] = struct{}{}
}
if msg.GetInitialSyncComplete() {
s.reconcileSnapshot(ctx, snapshotIDs)
snapshotIDs = nil
if s.healthChecker != nil { if s.healthChecker != nil {
s.healthChecker.SetInitialSyncComplete() s.healthChecker.SetInitialSyncComplete()
} }
@@ -1042,28 +1029,6 @@ func (s *Server) handleMappingStream(ctx context.Context, mappingClient proto.Pr
} }
} }
} }
}
}
// reconcileSnapshot removes local mappings that are absent from the snapshot.
// This ensures services deleted while the proxy was disconnected get cleaned up.
func (s *Server) reconcileSnapshot(ctx context.Context, snapshotIDs map[types.ServiceID]struct{}) {
s.portMu.RLock()
var stale []*proto.ProxyMapping
for svcID, mapping := range s.lastMappings {
if _, ok := snapshotIDs[svcID]; !ok {
stale = append(stale, mapping)
}
}
s.portMu.RUnlock()
for _, mapping := range stale {
s.Logger.WithFields(log.Fields{
"service_id": mapping.GetId(),
"domain": mapping.GetDomain(),
}).Info("Removing stale mapping absent from snapshot")
s.removeMapping(ctx, mapping)
}
} }
func (s *Server) processMappings(ctx context.Context, mappings []*proto.ProxyMapping) { func (s *Server) processMappings(ctx context.Context, mappings []*proto.ProxyMapping) {

View File

@@ -1,227 +0,0 @@
package proxy
import (
"context"
"io"
"testing"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/proxy/internal/health"
"github.com/netbirdio/netbird/proxy/internal/types"
"github.com/netbirdio/netbird/shared/management/proto"
)
// collectStaleIDs mirrors the stale-detection logic in reconcileSnapshot
// so we can verify it without triggering removeMapping (which requires full
// server wiring). This keeps the test focused on the detection algorithm.
func collectStaleIDs(lastMappings map[types.ServiceID]*proto.ProxyMapping, snapshotIDs map[types.ServiceID]struct{}) []types.ServiceID {
var stale []types.ServiceID
for svcID := range lastMappings {
if _, ok := snapshotIDs[svcID]; !ok {
stale = append(stale, svcID)
}
}
return stale
}
// TestStaleDetection_PartialOverlap verifies that only services absent from
// the snapshot are flagged as stale.
func TestStaleDetection_PartialOverlap(t *testing.T) {
local := map[types.ServiceID]*proto.ProxyMapping{
"svc-1": {Id: "svc-1"},
"svc-2": {Id: "svc-2"},
"svc-stale-a": {Id: "svc-stale-a"},
"svc-stale-b": {Id: "svc-stale-b"},
}
snapshot := map[types.ServiceID]struct{}{
"svc-1": {},
"svc-2": {},
"svc-3": {}, // new service, not in local
}
stale := collectStaleIDs(local, snapshot)
assert.Len(t, stale, 2)
staleSet := make(map[types.ServiceID]struct{})
for _, id := range stale {
staleSet[id] = struct{}{}
}
assert.Contains(t, staleSet, types.ServiceID("svc-stale-a"))
assert.Contains(t, staleSet, types.ServiceID("svc-stale-b"))
}
// TestStaleDetection_AllStale verifies an empty snapshot flags everything.
func TestStaleDetection_AllStale(t *testing.T) {
local := map[types.ServiceID]*proto.ProxyMapping{
"svc-1": {Id: "svc-1"},
"svc-2": {Id: "svc-2"},
}
stale := collectStaleIDs(local, map[types.ServiceID]struct{}{})
assert.Len(t, stale, 2)
}
// TestStaleDetection_NoneStale verifies full overlap produces no stale entries.
func TestStaleDetection_NoneStale(t *testing.T) {
local := map[types.ServiceID]*proto.ProxyMapping{
"svc-1": {Id: "svc-1"},
"svc-2": {Id: "svc-2"},
}
snapshot := map[types.ServiceID]struct{}{
"svc-1": {},
"svc-2": {},
}
stale := collectStaleIDs(local, snapshot)
assert.Empty(t, stale)
}
// TestStaleDetection_EmptyLocal verifies no stale entries when local is empty.
func TestStaleDetection_EmptyLocal(t *testing.T) {
stale := collectStaleIDs(
map[types.ServiceID]*proto.ProxyMapping{},
map[types.ServiceID]struct{}{"svc-1": {}},
)
assert.Empty(t, stale)
}
// TestReconcileSnapshot_NoStale verifies reconciliation is a no-op when all
// local mappings are present in the snapshot (removeMapping is never called).
func TestReconcileSnapshot_NoStale(t *testing.T) {
s := &Server{
Logger: log.StandardLogger(),
lastMappings: make(map[types.ServiceID]*proto.ProxyMapping),
}
s.lastMappings["svc-1"] = &proto.ProxyMapping{Id: "svc-1"}
s.lastMappings["svc-2"] = &proto.ProxyMapping{Id: "svc-2"}
snapshotIDs := map[types.ServiceID]struct{}{
"svc-1": {},
"svc-2": {},
}
// This should not panic — no stale entries means removeMapping is never called.
s.reconcileSnapshot(context.Background(), snapshotIDs)
assert.Len(t, s.lastMappings, 2, "no mappings should be removed when all are in snapshot")
}
// TestReconcileSnapshot_EmptyLocal verifies reconciliation is a no-op with
// no local mappings.
func TestReconcileSnapshot_EmptyLocal(t *testing.T) {
s := &Server{
Logger: log.StandardLogger(),
lastMappings: make(map[types.ServiceID]*proto.ProxyMapping),
}
s.reconcileSnapshot(context.Background(), map[types.ServiceID]struct{}{"svc-1": {}})
assert.Empty(t, s.lastMappings)
}
// --- handleMappingStream tests for batched snapshot ID accumulation ---
// TestHandleMappingStream_BatchedSnapshotSyncComplete verifies that sync is
// marked done only after the final InitialSyncComplete message, even when
// the snapshot arrives in multiple batches.
func TestHandleMappingStream_BatchedSnapshotSyncComplete(t *testing.T) {
checker := health.NewChecker(nil, nil)
s := &Server{
Logger: log.StandardLogger(),
healthChecker: checker,
routerReady: closedChan(),
lastMappings: make(map[types.ServiceID]*proto.ProxyMapping),
}
stream := &mockMappingStream{
messages: []*proto.GetMappingUpdateResponse{
{}, // batch 1: no sync-complete
{}, // batch 2: no sync-complete
{InitialSyncComplete: true}, // batch 3: sync done
},
}
syncDone := false
err := s.handleMappingStream(context.Background(), stream, &syncDone)
assert.NoError(t, err)
assert.True(t, syncDone, "sync should be marked done after final batch")
}
// TestHandleMappingStream_PostSyncDoesNotReconcile verifies that messages
// arriving after InitialSyncComplete do not trigger a second reconciliation.
func TestHandleMappingStream_PostSyncDoesNotReconcile(t *testing.T) {
s := &Server{
Logger: log.StandardLogger(),
routerReady: closedChan(),
lastMappings: make(map[types.ServiceID]*proto.ProxyMapping),
}
// Simulate state left over from a previous sync.
s.lastMappings["svc-1"] = &proto.ProxyMapping{Id: "svc-1", AccountId: "acct-1"}
s.lastMappings["svc-2"] = &proto.ProxyMapping{Id: "svc-2", AccountId: "acct-1"}
stream := &mockMappingStream{
messages: []*proto.GetMappingUpdateResponse{
{}, // post-sync empty message — must not reconcile
},
}
syncDone := true // sync already completed in a previous stream
err := s.handleMappingStream(context.Background(), stream, &syncDone)
require.NoError(t, err)
assert.Len(t, s.lastMappings, 2,
"post-sync messages must not trigger reconciliation — all entries should survive")
}
// TestHandleMappingStream_ImmediateEOF_NoReconciliation verifies that if the
// stream closes before sync completes, no reconciliation occurs.
func TestHandleMappingStream_ImmediateEOF_NoReconciliation(t *testing.T) {
s := &Server{
Logger: log.StandardLogger(),
routerReady: closedChan(),
lastMappings: make(map[types.ServiceID]*proto.ProxyMapping),
}
s.lastMappings["svc-stale"] = &proto.ProxyMapping{Id: "svc-stale", AccountId: "acct-1"}
stream := &mockMappingStream{} // no messages → immediate EOF
syncDone := false
err := s.handleMappingStream(context.Background(), stream, &syncDone)
assert.NoError(t, err)
assert.False(t, syncDone, "sync should not be marked done on immediate EOF")
_, hasStale := s.lastMappings["svc-stale"]
assert.True(t, hasStale, "stale mapping should remain when sync never completed")
}
// mockErrRecvStream returns an error on the second Recv to verify
// handleMappingStream returns without completing sync.
type mockErrRecvStream struct {
mockMappingStream
calls int
}
func (m *mockErrRecvStream) Recv() (*proto.GetMappingUpdateResponse, error) {
m.calls++
if m.calls == 1 {
return &proto.GetMappingUpdateResponse{}, nil
}
return nil, io.ErrUnexpectedEOF
}
func TestHandleMappingStream_ErrorMidSync_NoReconciliation(t *testing.T) {
s := &Server{
Logger: log.StandardLogger(),
routerReady: closedChan(),
lastMappings: make(map[types.ServiceID]*proto.ProxyMapping),
}
s.lastMappings["svc-stale"] = &proto.ProxyMapping{Id: "svc-stale", AccountId: "acct-1"}
syncDone := false
err := s.handleMappingStream(context.Background(), &mockErrRecvStream{}, &syncDone)
assert.Error(t, err)
assert.False(t, syncDone)
_, hasStale := s.lastMappings["svc-stale"]
assert.True(t, hasStale, "stale mapping should remain when sync was interrupted by error")
}

View File

@@ -3713,6 +3713,177 @@ components:
example: "https://invoice.stripe.com/i/acct_1M2DaBKina4I2KUb/test_YWNjdF8xTTJEdVBLaW5hM0kyS1ViLF1SeFpQdEJZd3lUOGNEajNqeWdrdXY2RFM4aHcyCnpsLDEzMjg3GTgyNQ02000JoIHc1X?s=db" example: "https://invoice.stripe.com/i/acct_1M2DaBKina4I2KUb/test_YWNjdF8xTTJEdVBLaW5hM0kyS1ViLF1SeFpQdEJZd3lUOGNEajNqeWdrdXY2RFM4aHcyCnpsLDEzMjg3GTgyNQ02000JoIHc1X?s=db"
required: required:
- url - url
MSPStatusResponse:
type: object
properties:
id:
type: string
description: Tenant account ID (present only for tenants)
example: ch8i4ug6lnn4g9hqv7m0
parent:
type: string
description: Parent MSP account ID (present only for tenants)
example: ch8i4ug6lnn4g9hqv7m1
activated_at:
type: string
description: MSP or Tenant activation timestamp in RFC3339 format
example: "2024-01-01T00:00:00Z"
invited_at:
type: string
description: Tenant invitation timestamp in RFC3339 format (present only for tenants)
example: "2024-01-01T00:00:00Z"
status:
type: string
description: Tenant status (present only for tenants)
enum: ["existing", "invited", "pending", "active"]
example: active
name:
type: string
description: MSP name (present only for MSP accounts)
example: "My MSP"
domain:
type: string
description: MSP domain (present only for MSP accounts)
example: "msp.com"
has_reseller:
type: boolean
description: Whether the MSP has a reseller (present only for MSP accounts)
default: false
example: false
is_reseller:
type: boolean
description: Whether the account is a reseller
default: false
example: false
parent_name:
type: string
description: Parent MSP name (present only for tenants)
example: "My MSP"
parent_domain:
type: string
description: Parent MSP domain (present only for tenants)
example: "msp.com"
parent_owner_name:
type: string
description: Parent MSP owner name
example: "John Doe"
parent_owner_email:
type: string
description: Parent MSP owner email
example: "john@msp.com"
ResellerStatusResponse:
type: object
properties:
activated_at:
type: string
description: Reseller activation timestamp in RFC3339 format
example: "2024-01-01T00:00:00Z"
name:
type: string
description: Reseller name
example: "My Reseller"
domain:
type: string
description: Reseller domain
example: "reseller.com"
parent_owner_name:
type: string
description: Reseller owner name
example: "John Doe"
parent_owner_email:
type: string
description: Reseller owner email
example: "john@reseller.com"
ResellerMSPResponse:
type: object
description: An MSP account managed (or invited) by a reseller.
properties:
id:
type: string
description: The MSP account ID.
example: ch8i4ug6lnn4g9hqv7m0
name:
type: string
description: Display name of the MSP.
example: "Partner MSP"
domain:
type: string
description: The MSP account domain.
example: "partner-msp.com"
status:
type: string
description: |
Lifecycle status of the reseller↔MSP relationship.
* `existing` — MSP exists in the system but has not yet been invited.
* `invited` — Reseller has invited the MSP; MSP user has not accepted.
* `active` — MSP user accepted; reseller manages the account.
enum:
- existing
- invited
- active
example: "active"
tenant_number:
type: integer
description: Number of manageable (active or pending) tenants under this MSP.
example: 12
reseller_customer_id:
type: string
description: Reseller's internal customer reference for this MSP.
example: "CUST-12345"
activated_at:
type: string
format: date-time
description: When the MSP account was activated.
example: "2026-04-01T12:00:00Z"
invited_at:
type: string
format: date-time
description: When the reseller invited the MSP (set only if invited).
example: "2026-04-15T09:30:00Z"
required:
- id
- name
- domain
- status
- tenant_number
GetResellerMSPsResponse:
type: array
items:
$ref: "#/components/schemas/ResellerMSPResponse"
CreateResellerMSPRequest:
type: object
properties:
name:
type: string
description: The name for the MSP
example: "New Partner MSP"
domain:
type: string
description: The domain for the MSP
example: "new-partner.com"
priceID:
type: string
description: Stripe price ID to set up managed subscription for the MSP
example: "price_1234"
reseller_customer_id:
type: string
description: Reseller's internal customer reference for this MSP
example: "CUST-12345"
required:
- name
- domain
UpdateResellerMSPRequest:
type: object
description: Mutable fields the reseller can update on a managed MSP.
properties:
name:
type: string
description: New display name for the MSP. Whitespace is trimmed; empty/whitespace-only values are rejected.
example: "Renamed MSP"
reseller_customer_id:
type: string
description: New reseller customer reference. Pass empty string to clear.
example: "CUST-67890"
CreateTenantRequest: CreateTenantRequest:
type: object type: object
properties: properties:
@@ -8845,6 +9016,241 @@ paths:
$ref: "#/components/responses/requires_authentication" $ref: "#/components/responses/requires_authentication"
"500": "500":
$ref: "#/components/responses/internal_error" $ref: "#/components/responses/internal_error"
/api/integrations/msp:
get:
summary: Get MSP or Tenant status
description: Returns the MSP, Tenant, or Reseller status of the authenticated account
tags:
- MSP
responses:
"200":
description: MSP or Tenant status response
content:
application/json:
schema:
$ref: "#/components/schemas/MSPStatusResponse"
"401":
$ref: "#/components/responses/requires_authentication"
"500":
$ref: "#/components/responses/internal_error"
post:
summary: Create MSP account
description: Activates the authenticated account as an MSP
tags:
- MSP
requestBody:
required: true
content:
application/json:
schema:
type: object
properties:
invite:
type: string
description: The invite code
example: "705860a1-27a3-4976-bf63-c5cd2fc1582b"
required:
- invite
responses:
"200":
description: MSP account created or already exists
"400":
$ref: "#/components/responses/bad_request"
"401":
$ref: "#/components/responses/requires_authentication"
"403":
$ref: "#/components/responses/forbidden"
"412":
description: MSP account requirements not met
"500":
$ref: "#/components/responses/internal_error"
/api/integrations/msp/reseller:
get:
summary: Get Reseller status
description: Returns the reseller status of the authenticated account
tags:
- MSP
responses:
"200":
description: Reseller status response
content:
application/json:
schema:
$ref: "#/components/schemas/ResellerStatusResponse"
"401":
$ref: "#/components/responses/requires_authentication"
"500":
$ref: "#/components/responses/internal_error"
/api/integrations/msp/reseller/msps:
get:
summary: List MSPs under reseller
tags:
- MSP
responses:
"200":
description: List of MSPs managed by the reseller
content:
application/json:
schema:
$ref: "#/components/schemas/GetResellerMSPsResponse"
"401":
$ref: "#/components/responses/requires_authentication"
"403":
$ref: "#/components/responses/forbidden"
"500":
$ref: "#/components/responses/internal_error"
post:
summary: Create MSP under reseller
description: Creates a new MSP account managed by the reseller. No domain validation required.
tags:
- MSP
requestBody:
required: true
content:
application/json:
schema:
$ref: "#/components/schemas/CreateResellerMSPRequest"
responses:
"200":
description: MSP created successfully
content:
application/json:
schema:
$ref: "#/components/schemas/ResellerMSPResponse"
"400":
$ref: "#/components/responses/bad_request"
"401":
$ref: "#/components/responses/requires_authentication"
"403":
$ref: "#/components/responses/forbidden"
"409":
description: MSP already exists for this domain
"500":
$ref: "#/components/responses/internal_error"
/api/integrations/msp/reseller/msps/{id}:
put:
summary: Update MSP fields managed by reseller
description: Update editable reseller-level fields on an MSP (e.g. reseller_customer_id)
tags:
- MSP
parameters:
- in: path
name: id
required: true
schema:
type: string
description: The MSP account ID to update
requestBody:
required: true
content:
application/json:
schema:
$ref: "#/components/schemas/UpdateResellerMSPRequest"
responses:
"200":
description: MSP updated successfully
content:
application/json:
schema:
$ref: "#/components/schemas/ResellerMSPResponse"
"400":
$ref: "#/components/responses/bad_request"
"401":
$ref: "#/components/responses/requires_authentication"
"403":
$ref: "#/components/responses/forbidden"
"404":
description: MSP not found or not managed by this reseller
"500":
$ref: "#/components/responses/internal_error"
delete:
summary: Unlink MSP from reseller
tags:
- MSP
parameters:
- in: path
name: id
required: true
schema:
type: string
description: The MSP account ID to unlink
responses:
"200":
description: MSP unlinked successfully
"401":
$ref: "#/components/responses/requires_authentication"
"403":
$ref: "#/components/responses/forbidden"
"404":
description: MSP not found or not managed by this reseller
"500":
$ref: "#/components/responses/internal_error"
/api/integrations/msp/reseller/msps/{id}/invite:
post:
summary: Invite existing MSP to reseller
description: Sends an invitation to an existing MSP to join the reseller
tags:
- MSP
parameters:
- in: path
name: id
required: true
schema:
type: string
description: The MSP account ID to invite
responses:
"200":
description: Invitation sent successfully
"401":
$ref: "#/components/responses/requires_authentication"
"403":
$ref: "#/components/responses/forbidden"
"404":
description: MSP not found
"412":
description: MSP is already managed by a reseller
"500":
$ref: "#/components/responses/internal_error"
put:
summary: Accept or decline reseller invitation
description: MSP owner accepts or declines an invitation from a reseller
tags:
- MSP
parameters:
- in: path
name: id
required: true
schema:
type: string
description: The MSP account ID
requestBody:
required: true
content:
application/json:
schema:
type: object
properties:
value:
type: string
description: Accept or decline the invitation
enum:
- accept
- decline
required:
- value
responses:
"200":
description: Invitation response processed
"400":
$ref: "#/components/responses/bad_request"
"401":
$ref: "#/components/responses/requires_authentication"
"403":
$ref: "#/components/responses/forbidden"
"404":
description: MSP not found or no pending invitation
"500":
$ref: "#/components/responses/internal_error"
/api/integrations/msp/tenants: /api/integrations/msp/tenants:
get: get:
summary: Get MSP tenants summary: Get MSP tenants
@@ -9082,6 +9488,165 @@ paths:
description: The tenant was not found description: The tenant was not found
"500": "500":
$ref: "#/components/responses/internal_error" $ref: "#/components/responses/internal_error"
/api/integrations/msp/reseller/msps/{id}/subscription:
post:
summary: Create MSP subscription under reseller
description: Creates a managed Stripe subscription for an MSP managed by the reseller
tags:
- MSP
parameters:
- in: path
name: id
required: true
schema:
type: string
description: The MSP account ID
requestBody:
required: true
content:
application/json:
schema:
type: object
properties:
priceID:
type: string
description: The Stripe price ID for the plan
required:
- priceID
responses:
"200":
description: Subscription created successfully
"400":
$ref: "#/components/responses/bad_request"
"401":
$ref: "#/components/responses/requires_authentication"
"403":
$ref: "#/components/responses/forbidden"
"404":
description: MSP not found or not managed by this reseller
"412":
description: MSP already has an active subscription or reseller has no subscription
"500":
$ref: "#/components/responses/internal_error"
put:
summary: Change MSP subscription plan
description: Changes the plan of an MSP subscription managed by the reseller
tags:
- MSP
parameters:
- in: path
name: id
required: true
schema:
type: string
description: The MSP account ID
requestBody:
required: true
content:
application/json:
schema:
type: object
properties:
priceID:
type: string
description: The new Stripe price ID
required:
- priceID
responses:
"200":
description: Subscription updated successfully
"400":
$ref: "#/components/responses/bad_request"
"401":
$ref: "#/components/responses/requires_authentication"
"403":
$ref: "#/components/responses/forbidden"
"404":
description: MSP not found or not managed by this reseller
"412":
description: Subscription was recently updated
"500":
$ref: "#/components/responses/internal_error"
/api/integrations/msp/reseller/invoices:
get:
summary: List reseller invoices
description: Returns paid invoices for the reseller account
tags:
- MSP
responses:
"200":
description: List of paid invoices
content:
application/json:
schema:
type: array
items:
$ref: "#/components/schemas/InvoiceResponse"
"401":
$ref: "#/components/responses/requires_authentication"
"403":
$ref: "#/components/responses/forbidden"
"404":
description: No invoices found
"500":
$ref: "#/components/responses/internal_error"
/api/integrations/msp/reseller/invoices/{invoiceId}/pdf:
get:
summary: Get reseller invoice PDF
description: Returns the Stripe hosted URL for a reseller invoice
tags:
- MSP
parameters:
- in: path
name: invoiceId
required: true
schema:
type: string
description: The Stripe invoice ID
responses:
"200":
description: Invoice PDF URL
content:
application/json:
schema:
$ref: "#/components/schemas/InvoicePDFResponse"
"401":
$ref: "#/components/responses/requires_authentication"
"403":
$ref: "#/components/responses/forbidden"
"404":
description: Invoice not found
"500":
$ref: "#/components/responses/internal_error"
/api/integrations/msp/reseller/invoices/{invoiceId}/csv:
get:
summary: Get reseller invoice CSV
description: Returns the invoice data as CSV
tags:
- MSP
parameters:
- in: path
name: invoiceId
required: true
schema:
type: string
description: The Stripe invoice ID
responses:
"200":
description: CSV file
content:
text/csv:
schema:
type: string
format: binary
"401":
$ref: "#/components/responses/requires_authentication"
"403":
$ref: "#/components/responses/forbidden"
"404":
description: Invoice not found
"500":
$ref: "#/components/responses/internal_error"
/api/integrations/edr/intune: /api/integrations/edr/intune:
post: post:
tags: tags:

View File

@@ -1,6 +1,6 @@
// Package api provides primitives to interact with the openapi HTTP API. // Package api provides primitives to interact with the openapi HTTP API.
// //
// Code generated by github.com/oapi-codegen/oapi-codegen/v2 version v2.6.0 DO NOT EDIT. // Code generated by github.com/oapi-codegen/oapi-codegen/v2 version v2.7.0 DO NOT EDIT.
package api package api
import ( import (
@@ -13,8 +13,8 @@ import (
) )
const ( const (
BearerAuthScopes = "BearerAuth.Scopes" BearerAuthScopes bearerAuthContextKey = "BearerAuth.Scopes"
TokenAuthScopes = "TokenAuth.Scopes" TokenAuthScopes tokenAuthContextKey = "TokenAuth.Scopes"
) )
// Defines values for AccessRestrictionsCrowdsecMode. // Defines values for AccessRestrictionsCrowdsecMode.
@@ -511,6 +511,7 @@ func (e GroupMinimumIssued) Valid() bool {
// Defines values for IdentityProviderType. // Defines values for IdentityProviderType.
const ( const (
IdentityProviderTypeAdfs IdentityProviderType = "adfs"
IdentityProviderTypeEntra IdentityProviderType = "entra" IdentityProviderTypeEntra IdentityProviderType = "entra"
IdentityProviderTypeGoogle IdentityProviderType = "google" IdentityProviderTypeGoogle IdentityProviderType = "google"
IdentityProviderTypeMicrosoft IdentityProviderType = "microsoft" IdentityProviderTypeMicrosoft IdentityProviderType = "microsoft"
@@ -518,12 +519,13 @@ const (
IdentityProviderTypeOkta IdentityProviderType = "okta" IdentityProviderTypeOkta IdentityProviderType = "okta"
IdentityProviderTypePocketid IdentityProviderType = "pocketid" IdentityProviderTypePocketid IdentityProviderType = "pocketid"
IdentityProviderTypeZitadel IdentityProviderType = "zitadel" IdentityProviderTypeZitadel IdentityProviderType = "zitadel"
IdentityProviderTypeAdfs IdentityProviderType = "adfs"
) )
// Valid indicates whether the value is a known member of the IdentityProviderType enum. // Valid indicates whether the value is a known member of the IdentityProviderType enum.
func (e IdentityProviderType) Valid() bool { func (e IdentityProviderType) Valid() bool {
switch e { switch e {
case IdentityProviderTypeAdfs:
return true
case IdentityProviderTypeEntra: case IdentityProviderTypeEntra:
return true return true
case IdentityProviderTypeGoogle: case IdentityProviderTypeGoogle:
@@ -538,8 +540,6 @@ func (e IdentityProviderType) Valid() bool {
return true return true
case IdentityProviderTypeZitadel: case IdentityProviderTypeZitadel:
return true return true
case IdentityProviderTypeAdfs:
return true
default: default:
return false return false
} }
@@ -671,6 +671,30 @@ func (e JobResponseStatus) Valid() bool {
} }
} }
// Defines values for MSPStatusResponseStatus.
const (
MSPStatusResponseStatusActive MSPStatusResponseStatus = "active"
MSPStatusResponseStatusExisting MSPStatusResponseStatus = "existing"
MSPStatusResponseStatusInvited MSPStatusResponseStatus = "invited"
MSPStatusResponseStatusPending MSPStatusResponseStatus = "pending"
)
// Valid indicates whether the value is a known member of the MSPStatusResponseStatus enum.
func (e MSPStatusResponseStatus) Valid() bool {
switch e {
case MSPStatusResponseStatusActive:
return true
case MSPStatusResponseStatusExisting:
return true
case MSPStatusResponseStatusInvited:
return true
case MSPStatusResponseStatusPending:
return true
default:
return false
}
}
// Defines values for NameserverNsType. // Defines values for NameserverNsType.
const ( const (
NameserverNsTypeUdp NameserverNsType = "udp" NameserverNsTypeUdp NameserverNsType = "udp"
@@ -878,6 +902,27 @@ func (e PolicyRuleUpdateProtocol) Valid() bool {
} }
} }
// Defines values for ResellerMSPResponseStatus.
const (
ResellerMSPResponseStatusActive ResellerMSPResponseStatus = "active"
ResellerMSPResponseStatusExisting ResellerMSPResponseStatus = "existing"
ResellerMSPResponseStatusInvited ResellerMSPResponseStatus = "invited"
)
// Valid indicates whether the value is a known member of the ResellerMSPResponseStatus enum.
func (e ResellerMSPResponseStatus) Valid() bool {
switch e {
case ResellerMSPResponseStatusActive:
return true
case ResellerMSPResponseStatusExisting:
return true
case ResellerMSPResponseStatusInvited:
return true
default:
return false
}
}
// Defines values for ResourceType. // Defines values for ResourceType.
const ( const (
ResourceTypeDomain ResourceType = "domain" ResourceTypeDomain ResourceType = "domain"
@@ -1319,6 +1364,24 @@ func (e GetApiEventsProxyParamsStatus) Valid() bool {
} }
} }
// Defines values for PutApiIntegrationsMspResellerMspsIdInviteJSONBodyValue.
const (
PutApiIntegrationsMspResellerMspsIdInviteJSONBodyValueAccept PutApiIntegrationsMspResellerMspsIdInviteJSONBodyValue = "accept"
PutApiIntegrationsMspResellerMspsIdInviteJSONBodyValueDecline PutApiIntegrationsMspResellerMspsIdInviteJSONBodyValue = "decline"
)
// Valid indicates whether the value is a known member of the PutApiIntegrationsMspResellerMspsIdInviteJSONBodyValue enum.
func (e PutApiIntegrationsMspResellerMspsIdInviteJSONBodyValue) Valid() bool {
switch e {
case PutApiIntegrationsMspResellerMspsIdInviteJSONBodyValueAccept:
return true
case PutApiIntegrationsMspResellerMspsIdInviteJSONBodyValueDecline:
return true
default:
return false
}
}
// Defines values for PutApiIntegrationsMspTenantsIdInviteJSONBodyValue. // Defines values for PutApiIntegrationsMspTenantsIdInviteJSONBodyValue.
const ( const (
PutApiIntegrationsMspTenantsIdInviteJSONBodyValueAccept PutApiIntegrationsMspTenantsIdInviteJSONBodyValue = "accept" PutApiIntegrationsMspTenantsIdInviteJSONBodyValueAccept PutApiIntegrationsMspTenantsIdInviteJSONBodyValue = "accept"
@@ -1626,7 +1689,9 @@ type Checks struct {
// OsVersionCheck Posture check for the version of operating system // OsVersionCheck Posture check for the version of operating system
OsVersionCheck *OSVersionCheck `json:"os_version_check,omitempty"` OsVersionCheck *OSVersionCheck `json:"os_version_check,omitempty"`
// PeerNetworkRangeCheck Posture check for allow or deny access based on the peer's IP addresses. A range matches when it contains any of the peer's local network interface IPs or its public connection (NAT egress) IP, so ranges may target private subnets, public CIDRs, or single hosts via a /32 or /128. // PeerNetworkRangeCheck Posture check for allow or deny access based on the peer's IP addresses. A range matches when it
// contains any of the peer's local network interface IPs or its public connection (NAT egress) IP,
// so ranges may target private subnets, public CIDRs, or single hosts via a /32 or /128.
PeerNetworkRangeCheck *PeerNetworkRangeCheck `json:"peer_network_range_check,omitempty"` PeerNetworkRangeCheck *PeerNetworkRangeCheck `json:"peer_network_range_check,omitempty"`
// ProcessCheck Posture Check for binaries exist and are running in the peers system // ProcessCheck Posture Check for binaries exist and are running in the peers system
@@ -1738,6 +1803,21 @@ type CreateOktaScimIntegrationRequest struct {
UserGroupPrefixes *[]string `json:"user_group_prefixes,omitempty"` UserGroupPrefixes *[]string `json:"user_group_prefixes,omitempty"`
} }
// CreateResellerMSPRequest defines model for CreateResellerMSPRequest.
type CreateResellerMSPRequest struct {
// Domain The domain for the MSP
Domain string `json:"domain"`
// Name The name for the MSP
Name string `json:"name"`
// PriceID Stripe price ID to set up managed subscription for the MSP
PriceID *string `json:"priceID,omitempty"`
// ResellerCustomerId Reseller's internal customer reference for this MSP
ResellerCustomerId *string `json:"reseller_customer_id,omitempty"`
}
// CreateScimIntegrationRequest defines model for CreateScimIntegrationRequest. // CreateScimIntegrationRequest defines model for CreateScimIntegrationRequest.
type CreateScimIntegrationRequest struct { type CreateScimIntegrationRequest struct {
// ConnectorId DEX connector ID for embedded IDP setups // ConnectorId DEX connector ID for embedded IDP setups
@@ -2194,6 +2274,9 @@ type GeoLocationCheck struct {
// GeoLocationCheckAction Action to take upon policy match // GeoLocationCheckAction Action to take upon policy match
type GeoLocationCheckAction string type GeoLocationCheckAction string
// GetResellerMSPsResponse defines model for GetResellerMSPsResponse.
type GetResellerMSPsResponse = []ResellerMSPResponse
// GetTenantsResponse defines model for GetTenantsResponse. // GetTenantsResponse defines model for GetTenantsResponse.
type GetTenantsResponse = []TenantResponse type GetTenantsResponse = []TenantResponse
@@ -2618,6 +2701,51 @@ type Location struct {
CountryCode CountryCode `json:"country_code"` CountryCode CountryCode `json:"country_code"`
} }
// MSPStatusResponse defines model for MSPStatusResponse.
type MSPStatusResponse struct {
// ActivatedAt MSP or Tenant activation timestamp in RFC3339 format
ActivatedAt *string `json:"activated_at,omitempty"`
// Domain MSP domain (present only for MSP accounts)
Domain *string `json:"domain,omitempty"`
// HasReseller Whether the MSP has a reseller (present only for MSP accounts)
HasReseller *bool `json:"has_reseller,omitempty"`
// Id Tenant account ID (present only for tenants)
Id *string `json:"id,omitempty"`
// InvitedAt Tenant invitation timestamp in RFC3339 format (present only for tenants)
InvitedAt *string `json:"invited_at,omitempty"`
// IsReseller Whether the account is a reseller
IsReseller *bool `json:"is_reseller,omitempty"`
// Name MSP name (present only for MSP accounts)
Name *string `json:"name,omitempty"`
// Parent Parent MSP account ID (present only for tenants)
Parent *string `json:"parent,omitempty"`
// ParentDomain Parent MSP domain (present only for tenants)
ParentDomain *string `json:"parent_domain,omitempty"`
// ParentName Parent MSP name (present only for tenants)
ParentName *string `json:"parent_name,omitempty"`
// ParentOwnerEmail Parent MSP owner email
ParentOwnerEmail *string `json:"parent_owner_email,omitempty"`
// ParentOwnerName Parent MSP owner name
ParentOwnerName *string `json:"parent_owner_name,omitempty"`
// Status Tenant status (present only for tenants)
Status *MSPStatusResponseStatus `json:"status,omitempty"`
}
// MSPStatusResponseStatus Tenant status (present only for tenants)
type MSPStatusResponseStatus string
// MinKernelVersionCheck Posture check with the kernel version // MinKernelVersionCheck Posture check with the kernel version
type MinKernelVersionCheck struct { type MinKernelVersionCheck struct {
// MinKernelVersion Minimum acceptable version // MinKernelVersion Minimum acceptable version
@@ -3312,7 +3440,9 @@ type PeerMinimum struct {
Name string `json:"name"` Name string `json:"name"`
} }
// PeerNetworkRangeCheck Posture check for allow or deny access based on the peer's IP addresses. A range matches when it contains any of the peer's local network interface IPs or its public connection (NAT egress) IP, so ranges may target private subnets, public CIDRs, or single hosts via a /32 or /128. // PeerNetworkRangeCheck Posture check for allow or deny access based on the peer's IP addresses. A range matches when it
// contains any of the peer's local network interface IPs or its public connection (NAT egress) IP,
// so ranges may target private subnets, public CIDRs, or single hosts via a /32 or /128.
type PeerNetworkRangeCheck struct { type PeerNetworkRangeCheck struct {
// Action Action to take upon policy match // Action Action to take upon policy match
Action PeerNetworkRangeCheckAction `json:"action"` Action PeerNetworkRangeCheckAction `json:"action"`
@@ -3771,6 +3901,60 @@ type ProxyCluster struct {
ConnectedProxies int `json:"connected_proxies"` ConnectedProxies int `json:"connected_proxies"`
} }
// ResellerMSPResponse An MSP account managed (or invited) by a reseller.
type ResellerMSPResponse struct {
// ActivatedAt When the MSP account was activated.
ActivatedAt *time.Time `json:"activated_at,omitempty"`
// Domain The MSP account domain.
Domain string `json:"domain"`
// Id The MSP account ID.
Id string `json:"id"`
// InvitedAt When the reseller invited the MSP (set only if invited).
InvitedAt *time.Time `json:"invited_at,omitempty"`
// Name Display name of the MSP.
Name string `json:"name"`
// ResellerCustomerId Reseller's internal customer reference for this MSP.
ResellerCustomerId *string `json:"reseller_customer_id,omitempty"`
// Status Lifecycle status of the reseller↔MSP relationship.
// * `existing` — MSP exists in the system but has not yet been invited.
// * `invited` — Reseller has invited the MSP; MSP user has not accepted.
// * `active` — MSP user accepted; reseller manages the account.
Status ResellerMSPResponseStatus `json:"status"`
// TenantNumber Number of manageable (active or pending) tenants under this MSP.
TenantNumber int `json:"tenant_number"`
}
// ResellerMSPResponseStatus Lifecycle status of the reseller↔MSP relationship.
// * `existing` — MSP exists in the system but has not yet been invited.
// * `invited` — Reseller has invited the MSP; MSP user has not accepted.
// * `active` — MSP user accepted; reseller manages the account.
type ResellerMSPResponseStatus string
// ResellerStatusResponse defines model for ResellerStatusResponse.
type ResellerStatusResponse struct {
// ActivatedAt Reseller activation timestamp in RFC3339 format
ActivatedAt *string `json:"activated_at,omitempty"`
// Domain Reseller domain
Domain *string `json:"domain,omitempty"`
// Name Reseller name
Name *string `json:"name,omitempty"`
// ParentOwnerEmail Reseller owner email
ParentOwnerEmail *string `json:"parent_owner_email,omitempty"`
// ParentOwnerName Reseller owner name
ParentOwnerName *string `json:"parent_owner_name,omitempty"`
}
// Resource defines model for Resource. // Resource defines model for Resource.
type Resource struct { type Resource struct {
// Id ID of the resource // Id ID of the resource
@@ -4471,6 +4655,15 @@ type UpdateOktaScimIntegrationRequest struct {
UserGroupPrefixes *[]string `json:"user_group_prefixes,omitempty"` UserGroupPrefixes *[]string `json:"user_group_prefixes,omitempty"`
} }
// UpdateResellerMSPRequest Mutable fields the reseller can update on a managed MSP.
type UpdateResellerMSPRequest struct {
// Name New display name for the MSP. Whitespace is trimmed; empty/whitespace-only values are rejected.
Name *string `json:"name,omitempty"`
// ResellerCustomerId New reseller customer reference. Pass empty string to clear.
ResellerCustomerId *string `json:"reseller_customer_id,omitempty"`
}
// UpdateScimIntegrationRequest defines model for UpdateScimIntegrationRequest. // UpdateScimIntegrationRequest defines model for UpdateScimIntegrationRequest.
type UpdateScimIntegrationRequest struct { type UpdateScimIntegrationRequest struct {
// ConnectorId DEX connector ID for embedded IDP setups // ConnectorId DEX connector ID for embedded IDP setups
@@ -4761,6 +4954,12 @@ type ZoneRequest struct {
// Conflict Standard error response. Note: The exact structure of this error response is inferred from `util.WriteErrorResponse` and `util.WriteError` usage in the provided Go code, as a specific Go struct for errors was not provided. // Conflict Standard error response. Note: The exact structure of this error response is inferred from `util.WriteErrorResponse` and `util.WriteError` usage in the provided Go code, as a specific Go struct for errors was not provided.
type Conflict = ErrorResponse type Conflict = ErrorResponse
// bearerAuthContextKey is the context key for BearerAuth security scheme
type bearerAuthContextKey string
// tokenAuthContextKey is the context key for TokenAuth security scheme
type tokenAuthContextKey string
// GetApiEventsNetworkTrafficParams defines parameters for GetApiEventsNetworkTraffic. // GetApiEventsNetworkTrafficParams defines parameters for GetApiEventsNetworkTraffic.
type GetApiEventsNetworkTrafficParams struct { type GetApiEventsNetworkTrafficParams struct {
// Page Page number // Page Page number
@@ -4914,6 +5113,33 @@ type PutApiIntegrationsBillingSubscriptionJSONBody struct {
PriceID *string `json:"priceID,omitempty"` PriceID *string `json:"priceID,omitempty"`
} }
// PostApiIntegrationsMspJSONBody defines parameters for PostApiIntegrationsMsp.
type PostApiIntegrationsMspJSONBody struct {
// Invite The invite code
Invite string `json:"invite"`
}
// PutApiIntegrationsMspResellerMspsIdInviteJSONBody defines parameters for PutApiIntegrationsMspResellerMspsIdInvite.
type PutApiIntegrationsMspResellerMspsIdInviteJSONBody struct {
// Value Accept or decline the invitation
Value PutApiIntegrationsMspResellerMspsIdInviteJSONBodyValue `json:"value"`
}
// PutApiIntegrationsMspResellerMspsIdInviteJSONBodyValue defines parameters for PutApiIntegrationsMspResellerMspsIdInvite.
type PutApiIntegrationsMspResellerMspsIdInviteJSONBodyValue string
// PostApiIntegrationsMspResellerMspsIdSubscriptionJSONBody defines parameters for PostApiIntegrationsMspResellerMspsIdSubscription.
type PostApiIntegrationsMspResellerMspsIdSubscriptionJSONBody struct {
// PriceID The Stripe price ID for the plan
PriceID string `json:"priceID"`
}
// PutApiIntegrationsMspResellerMspsIdSubscriptionJSONBody defines parameters for PutApiIntegrationsMspResellerMspsIdSubscription.
type PutApiIntegrationsMspResellerMspsIdSubscriptionJSONBody struct {
// PriceID The new Stripe price ID
PriceID string `json:"priceID"`
}
// PutApiIntegrationsMspTenantsIdInviteJSONBody defines parameters for PutApiIntegrationsMspTenantsIdInvite. // PutApiIntegrationsMspTenantsIdInviteJSONBody defines parameters for PutApiIntegrationsMspTenantsIdInvite.
type PutApiIntegrationsMspTenantsIdInviteJSONBody struct { type PutApiIntegrationsMspTenantsIdInviteJSONBody struct {
// Value Accept or decline the invitation. // Value Accept or decline the invitation.
@@ -5058,6 +5284,24 @@ type CreateGoogleIntegrationJSONRequestBody = CreateGoogleIntegrationRequest
// UpdateGoogleIntegrationJSONRequestBody defines body for UpdateGoogleIntegration for application/json ContentType. // UpdateGoogleIntegrationJSONRequestBody defines body for UpdateGoogleIntegration for application/json ContentType.
type UpdateGoogleIntegrationJSONRequestBody = UpdateGoogleIntegrationRequest type UpdateGoogleIntegrationJSONRequestBody = UpdateGoogleIntegrationRequest
// PostApiIntegrationsMspJSONRequestBody defines body for PostApiIntegrationsMsp for application/json ContentType.
type PostApiIntegrationsMspJSONRequestBody PostApiIntegrationsMspJSONBody
// PostApiIntegrationsMspResellerMspsJSONRequestBody defines body for PostApiIntegrationsMspResellerMsps for application/json ContentType.
type PostApiIntegrationsMspResellerMspsJSONRequestBody = CreateResellerMSPRequest
// PutApiIntegrationsMspResellerMspsIdJSONRequestBody defines body for PutApiIntegrationsMspResellerMspsId for application/json ContentType.
type PutApiIntegrationsMspResellerMspsIdJSONRequestBody = UpdateResellerMSPRequest
// PutApiIntegrationsMspResellerMspsIdInviteJSONRequestBody defines body for PutApiIntegrationsMspResellerMspsIdInvite for application/json ContentType.
type PutApiIntegrationsMspResellerMspsIdInviteJSONRequestBody PutApiIntegrationsMspResellerMspsIdInviteJSONBody
// PostApiIntegrationsMspResellerMspsIdSubscriptionJSONRequestBody defines body for PostApiIntegrationsMspResellerMspsIdSubscription for application/json ContentType.
type PostApiIntegrationsMspResellerMspsIdSubscriptionJSONRequestBody PostApiIntegrationsMspResellerMspsIdSubscriptionJSONBody
// PutApiIntegrationsMspResellerMspsIdSubscriptionJSONRequestBody defines body for PutApiIntegrationsMspResellerMspsIdSubscription for application/json ContentType.
type PutApiIntegrationsMspResellerMspsIdSubscriptionJSONRequestBody PutApiIntegrationsMspResellerMspsIdSubscriptionJSONBody
// PostApiIntegrationsMspTenantsJSONRequestBody defines body for PostApiIntegrationsMspTenants for application/json ContentType. // PostApiIntegrationsMspTenantsJSONRequestBody defines body for PostApiIntegrationsMspTenants for application/json ContentType.
type PostApiIntegrationsMspTenantsJSONRequestBody = CreateTenantRequest type PostApiIntegrationsMspTenantsJSONRequestBody = CreateTenantRequest

View File

@@ -10,13 +10,15 @@ import (
"context" "context"
"fmt" "fmt"
"net" "net"
"sync"
"time" "time"
"github.com/google/gopacket" "github.com/google/gopacket"
"github.com/google/gopacket/layers" "github.com/google/gopacket/layers"
"github.com/google/gopacket/routing"
"github.com/libp2p/go-netroute"
"github.com/mdlayher/socket" "github.com/mdlayher/socket"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/vishvananda/netlink"
"golang.org/x/sync/errgroup" "golang.org/x/sync/errgroup"
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
@@ -35,6 +37,8 @@ type SharedSocket struct {
conn6 *socket.Conn conn6 *socket.Conn
port int port int
mtu uint16 mtu uint16
routerMux sync.RWMutex
router routing.Router
packetDemux chan rcvdPacket packetDemux chan rcvdPacket
cancel context.CancelFunc cancel context.CancelFunc
} }
@@ -78,6 +82,11 @@ func Listen(port int, filter BPFFilter, mtu uint16) (_ net.PacketConn, err error
} }
}() }()
rawSock.router, err = netroute.New()
if err != nil {
return nil, fmt.Errorf("failed to create raw socket router: %w", err)
}
rawSock.conn4, err = socket.Socket(unix.AF_INET, unix.SOCK_RAW, unix.IPPROTO_UDP, "raw_udp4", nil) rawSock.conn4, err = socket.Socket(unix.AF_INET, unix.SOCK_RAW, unix.IPPROTO_UDP, "raw_udp4", nil)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to create ipv4 raw socket: %w", err) return nil, fmt.Errorf("failed to create ipv4 raw socket: %w", err)
@@ -118,26 +127,31 @@ func Listen(port int, filter BPFFilter, mtu uint16) (_ net.PacketConn, err error
go rawSock.read(rawSock.conn6.Recvfrom) go rawSock.read(rawSock.conn6.Recvfrom)
} }
go rawSock.updateRouter()
return rawSock, nil return rawSock, nil
} }
// resolveSrc returns the source IP the kernel will pick for a packet sent to // updateRouter updates the listener routing table client
// dst by these raw sockets, mirroring the fwmark the kernel will see on send. // this is needed to avoid outdated information across different client networks
func (s *SharedSocket) resolveSrc(dst net.IP) (net.IP, error) { func (s *SharedSocket) updateRouter() {
opts := &netlink.RouteGetOptions{} ticker := time.NewTicker(15 * time.Second)
if nbnet.AdvancedRouting() { defer ticker.Stop()
opts.Mark = nbnet.ControlPlaneMark for {
} select {
routes, err := netlink.RouteGetWithOptions(dst, opts) case <-s.ctx.Done():
return
case <-ticker.C:
router, err := netroute.New()
if err != nil { if err != nil {
return nil, fmt.Errorf("route get %s: %w", dst, err) log.Errorf("Failed to create and update packet router for stunListener: %s", err)
continue
} }
for _, r := range routes { s.routerMux.Lock()
if r.Src != nil { s.router = router
return r.Src, nil s.routerMux.Unlock()
} }
} }
return nil, fmt.Errorf("no source IP for %s", dst)
} }
// LocalAddr returns the local address, preferring IPv4 for backward compatibility. // LocalAddr returns the local address, preferring IPv4 for backward compatibility.
@@ -296,15 +310,15 @@ func (s *SharedSocket) WriteTo(buf []byte, rAddr net.Addr) (n int, err error) {
DstPort: layers.UDPPort(rUDPAddr.Port), DstPort: layers.UDPPort(rUDPAddr.Port),
} }
src, err := s.resolveSrc(rUDPAddr.IP) s.routerMux.RLock()
defer s.routerMux.RUnlock()
_, _, src, err := s.router.Route(rUDPAddr.IP)
if err != nil { if err != nil {
return 0, fmt.Errorf("resolve source for %s: %w", rUDPAddr.IP, err) return 0, fmt.Errorf("got an error while checking route, err: %w", err)
} }
rSockAddr, conn, nwLayer := s.getWriterObjects(src, rUDPAddr.IP) rSockAddr, conn, nwLayer := s.getWriterObjects(src, rUDPAddr.IP)
if conn == nil {
return 0, fmt.Errorf("no raw socket for %s", rUDPAddr.IP)
}
if err := udp.SetNetworkLayerForChecksum(nwLayer); err != nil { if err := udp.SetNetworkLayerForChecksum(nwLayer); err != nil {
return -1, fmt.Errorf("failed to set network layer for checksum: %w", err) return -1, fmt.Errorf("failed to set network layer for checksum: %w", err)