Use a 1:1 mapping of netbird client to netbird account

- Add debug endpoint for monitoring netbird clients
- Add types package with AccountID type
- Refactor netbird roundtrip to key clients by AccountID
- Multiple domains can share the same client per account
- Add status notifier for tunnel connection updates
- Add OIDC flags to CLI
- Add tests for netbird client management
This commit is contained in:
Viktor Liu
2026-02-04 14:39:52 +08:00
parent 18cd0f1480
commit ca33849f31
20 changed files with 2270 additions and 182 deletions

View File

@@ -0,0 +1,166 @@
package cmd
import (
"fmt"
"strconv"
"github.com/spf13/cobra"
"github.com/netbirdio/netbird/proxy/internal/debug"
)
var (
debugAddr string
jsonOutput bool
// status filters
statusFilterByIPs []string
statusFilterByNames []string
statusFilterByStatus string
statusFilterByConnectionType string
)
var debugCmd = &cobra.Command{
Use: "debug",
Short: "Debug commands for inspecting proxy state",
Long: "Debug commands for inspecting the reverse proxy state via the debug HTTP endpoint.",
}
var debugHealthCmd = &cobra.Command{
Use: "health",
Short: "Show proxy health status",
RunE: runDebugHealth,
SilenceUsage: true,
}
var debugClientsCmd = &cobra.Command{
Use: "clients",
Aliases: []string{"list"},
Short: "List all connected clients",
RunE: runDebugClients,
SilenceUsage: true,
}
var debugStatusCmd = &cobra.Command{
Use: "status <account-id>",
Short: "Show client status",
Args: cobra.ExactArgs(1),
RunE: runDebugStatus,
SilenceUsage: true,
}
var debugSyncCmd = &cobra.Command{
Use: "sync-response <account-id>",
Short: "Show client sync response",
Args: cobra.ExactArgs(1),
RunE: runDebugSync,
SilenceUsage: true,
}
var pingTimeout string
var debugPingCmd = &cobra.Command{
Use: "ping <account-id> <host> [port]",
Short: "TCP ping through a client",
Long: "Perform a TCP ping through a client's network to test connectivity.\nPort defaults to 80 if not specified.",
Args: cobra.RangeArgs(2, 3),
RunE: runDebugPing,
SilenceUsage: true,
}
var debugLogLevelCmd = &cobra.Command{
Use: "loglevel <account-id> <level>",
Short: "Set client log level",
Long: "Set the log level for a client (trace, debug, info, warn, error).",
Args: cobra.ExactArgs(2),
RunE: runDebugLogLevel,
SilenceUsage: true,
}
var debugStartCmd = &cobra.Command{
Use: "start <account-id>",
Short: "Start a client",
Args: cobra.ExactArgs(1),
RunE: runDebugStart,
SilenceUsage: true,
}
var debugStopCmd = &cobra.Command{
Use: "stop <account-id>",
Short: "Stop a client",
Args: cobra.ExactArgs(1),
RunE: runDebugStop,
SilenceUsage: true,
}
func init() {
debugCmd.PersistentFlags().StringVar(&debugAddr, "addr", envStringOrDefault("NB_PROXY_DEBUG_ADDRESS", "localhost:8444"), "Debug endpoint address")
debugCmd.PersistentFlags().BoolVar(&jsonOutput, "json", false, "Output JSON instead of pretty format")
debugStatusCmd.Flags().StringSliceVar(&statusFilterByIPs, "filter-by-ips", nil, "Filter by peer IPs (comma-separated)")
debugStatusCmd.Flags().StringSliceVar(&statusFilterByNames, "filter-by-names", nil, "Filter by peer names (comma-separated)")
debugStatusCmd.Flags().StringVar(&statusFilterByStatus, "filter-by-status", "", "Filter by status (idle|connecting|connected)")
debugStatusCmd.Flags().StringVar(&statusFilterByConnectionType, "filter-by-connection-type", "", "Filter by connection type (P2P|Relayed)")
debugPingCmd.Flags().StringVar(&pingTimeout, "timeout", "", "Ping timeout (e.g., 10s)")
debugCmd.AddCommand(debugHealthCmd)
debugCmd.AddCommand(debugClientsCmd)
debugCmd.AddCommand(debugStatusCmd)
debugCmd.AddCommand(debugSyncCmd)
debugCmd.AddCommand(debugPingCmd)
debugCmd.AddCommand(debugLogLevelCmd)
debugCmd.AddCommand(debugStartCmd)
debugCmd.AddCommand(debugStopCmd)
rootCmd.AddCommand(debugCmd)
}
func getDebugClient(cmd *cobra.Command) *debug.Client {
return debug.NewClient(debugAddr, jsonOutput, cmd.OutOrStdout())
}
func runDebugHealth(cmd *cobra.Command, _ []string) error {
return getDebugClient(cmd).Health(cmd.Context())
}
func runDebugClients(cmd *cobra.Command, _ []string) error {
return getDebugClient(cmd).ListClients(cmd.Context())
}
func runDebugStatus(cmd *cobra.Command, args []string) error {
return getDebugClient(cmd).ClientStatus(cmd.Context(), args[0], debug.StatusFilters{
IPs: statusFilterByIPs,
Names: statusFilterByNames,
Status: statusFilterByStatus,
ConnectionType: statusFilterByConnectionType,
})
}
func runDebugSync(cmd *cobra.Command, args []string) error {
return getDebugClient(cmd).ClientSyncResponse(cmd.Context(), args[0])
}
func runDebugPing(cmd *cobra.Command, args []string) error {
port := 80
if len(args) > 2 {
p, err := strconv.Atoi(args[2])
if err != nil {
return fmt.Errorf("invalid port: %w", err)
}
port = p
}
return getDebugClient(cmd).PingTCP(cmd.Context(), args[0], args[1], port, pingTimeout)
}
func runDebugLogLevel(cmd *cobra.Command, args []string) error {
return getDebugClient(cmd).SetLogLevel(cmd.Context(), args[0], args[1])
}
func runDebugStart(cmd *cobra.Command, args []string) error {
return getDebugClient(cmd).StartClient(cmd.Context(), args[0])
}
func runDebugStop(cmd *cobra.Command, args []string) error {
return getDebugClient(cmd).StopClient(cmd.Context(), args[0])
}

137
proxy/cmd/proxy/cmd/root.go Normal file
View File

@@ -0,0 +1,137 @@
package cmd
import (
"context"
"os"
"strconv"
"strings"
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
"golang.org/x/crypto/acme"
"github.com/netbirdio/netbird/proxy"
"github.com/netbirdio/netbird/util"
)
const DefaultManagementURL = "https://api.netbird.io:443"
var (
Version = "dev"
Commit = "unknown"
BuildDate = "unknown"
GoVersion = "unknown"
)
var (
debugLogs bool
mgmtAddr string
addr string
proxyURL string
certDir string
acmeCerts bool
acmeAddr string
acmeDir string
debugEndpoint bool
debugEndpointAddr string
oidcClientID string
oidcClientSecret string
oidcEndpoint string
oidcScopes string
)
var rootCmd = &cobra.Command{
Use: "proxy",
Short: "NetBird reverse proxy server",
Long: "NetBird reverse proxy server for proxying traffic to NetBird networks.",
Version: Version,
RunE: runServer,
}
func init() {
rootCmd.PersistentFlags().BoolVar(&debugLogs, "debug", envBoolOrDefault("NB_PROXY_DEBUG_LOGS", false), "Enable debug logs")
rootCmd.Flags().StringVar(&mgmtAddr, "mgmt", envStringOrDefault("NB_PROXY_MANAGEMENT_ADDRESS", DefaultManagementURL), "Management address to connect to")
rootCmd.Flags().StringVar(&addr, "addr", envStringOrDefault("NB_PROXY_ADDRESS", ":443"), "Reverse proxy address to listen on")
rootCmd.Flags().StringVar(&proxyURL, "url", envStringOrDefault("NB_PROXY_URL", ""), "The URL at which this proxy will be reached")
rootCmd.Flags().StringVar(&certDir, "cert-dir", envStringOrDefault("NB_PROXY_CERTIFICATE_DIRECTORY", "./certs"), "Directory to store certificates")
rootCmd.Flags().BoolVar(&acmeCerts, "acme-certs", envBoolOrDefault("NB_PROXY_ACME_CERTIFICATES", false), "Generate ACME certificates using HTTP-01 challenges")
rootCmd.Flags().StringVar(&acmeAddr, "acme-addr", envStringOrDefault("NB_PROXY_ACME_ADDRESS", ":80"), "HTTP address for ACME HTTP-01 challenges")
rootCmd.Flags().StringVar(&acmeDir, "acme-dir", envStringOrDefault("NB_PROXY_ACME_DIRECTORY", acme.LetsEncryptURL), "URL of ACME challenge directory")
rootCmd.Flags().BoolVar(&debugEndpoint, "debug-endpoint", envBoolOrDefault("NB_PROXY_DEBUG_ENDPOINT", false), "Enable debug HTTP endpoint")
rootCmd.Flags().StringVar(&debugEndpointAddr, "debug-endpoint-addr", envStringOrDefault("NB_PROXY_DEBUG_ENDPOINT_ADDRESS", "localhost:8444"), "Address for the debug HTTP endpoint")
rootCmd.Flags().StringVar(&oidcClientID, "oidc-id", envStringOrDefault("NB_PROXY_OIDC_CLIENT_ID", "netbird-proxy"), "The OAuth2 Client ID for OIDC User Authentication")
rootCmd.Flags().StringVar(&oidcClientSecret, "oidc-secret", envStringOrDefault("NB_PROXY_OIDC_CLIENT_SECRET", ""), "The OAuth2 Client Secret for OIDC User Authentication")
rootCmd.Flags().StringVar(&oidcEndpoint, "oidc-endpoint", envStringOrDefault("NB_PROXY_OIDC_ENDPOINT", ""), "The OIDC Endpoint for OIDC User Authentication")
rootCmd.Flags().StringVar(&oidcScopes, "oidc-scopes", envStringOrDefault("NB_PROXY_OIDC_SCOPES", "openid,profile,email"), "The OAuth2 scopes for OIDC User Authentication, comma separated")
}
// Execute runs the root command.
func Execute() {
if err := rootCmd.Execute(); err != nil {
os.Exit(1)
}
}
// SetVersionInfo sets version information for the CLI.
func SetVersionInfo(version, commit, buildDate, goVersion string) {
Version = version
Commit = commit
BuildDate = buildDate
GoVersion = goVersion
rootCmd.Version = version
rootCmd.SetVersionTemplate("Version: {{.Version}}, Commit: " + Commit + ", BuildDate: " + BuildDate + ", Go: " + GoVersion + "\n")
}
func runServer(cmd *cobra.Command, args []string) error {
level := "error"
if debugLogs {
level = "debug"
}
logger := log.New()
_ = util.InitLogger(logger, level, util.LogConsole)
log.Infof("configured log level: %s", level)
srv := proxy.Server{
Logger: logger,
Version: Version,
ManagementAddress: mgmtAddr,
ProxyURL: proxyURL,
CertificateDirectory: certDir,
GenerateACMECertificates: acmeCerts,
ACMEChallengeAddress: acmeAddr,
ACMEDirectory: acmeDir,
DebugEndpointEnabled: debugEndpoint,
DebugEndpointAddress: debugEndpointAddr,
OIDCClientId: oidcClientID,
OIDCClientSecret: oidcClientSecret,
OIDCEndpoint: oidcEndpoint,
OIDCScopes: strings.Split(oidcScopes, ","),
}
if err := srv.ListenAndServe(context.TODO(), addr); err != nil {
log.Fatal(err)
}
return nil
}
func envBoolOrDefault(key string, def bool) bool {
v, exists := os.LookupEnv(key)
if !exists {
return def
}
parsed, err := strconv.ParseBool(v)
if err != nil {
return def
}
return parsed
}
func envStringOrDefault(key string, def string) string {
v, exists := os.LookupEnv(key)
if !exists {
return def
}
return v
}

View File

@@ -1,22 +1,11 @@
package main
import (
"context"
"flag"
"fmt"
"os"
"runtime"
"strings"
"github.com/netbirdio/netbird/util"
log "github.com/sirupsen/logrus"
"golang.org/x/crypto/acme"
"github.com/netbirdio/netbird/proxy"
"github.com/netbirdio/netbird/proxy/cmd/proxy/cmd"
)
const DefaultManagementURL = "https://api.netbird.io:443"
var (
// Version is the application version (set via ldflags during build)
Version = "dev"
@@ -31,78 +20,7 @@ var (
GoVersion = runtime.Version()
)
func envBoolOrDefault(key string, def bool) bool {
v, exists := os.LookupEnv(key)
if !exists {
return def
}
return v == strings.ToLower("true")
}
func envStringOrDefault(key string, def string) string {
v, exists := os.LookupEnv(key)
if !exists {
return def
}
return v
}
func main() {
var (
version, debug bool
mgmtAddr, addr, url, certDir string
acmeCerts bool
acmeAddr, acmeDir string
oidcId, oidcSecret, oidcEndpoint, oidcScopes string
)
flag.BoolVar(&version, "v", false, "Print version and exit")
flag.BoolVar(&debug, "debug", envBoolOrDefault("NB_PROXY_DEBUG_LOGS", false), "Enable debug logs")
flag.StringVar(&mgmtAddr, "mgmt", envStringOrDefault("NB_PROXY_MANAGEMENT_ADDRESS", DefaultManagementURL), "Management address to connect to.")
flag.StringVar(&addr, "addr", envStringOrDefault("NB_PROXY_ADDRESS", ":443"), "Reverse proxy address to listen on.")
flag.StringVar(&url, "url", envStringOrDefault("NB_PROXY_URL", "proxy.netbird.io"), "The URL at which this proxy will be reached, where CNAME records for proxied endpoints will be directed.")
flag.StringVar(&certDir, "cert-dir", envStringOrDefault("NB_PROXY_CERTIFICATE_DIRECTORY", "./certs"), "Directory to store ")
flag.BoolVar(&acmeCerts, "acme-certs", envBoolOrDefault("NB_PROXY_ACME_CERTIFICATES", false), "Generate ACME certificates using HTTP-01 challenges.")
flag.StringVar(&acmeAddr, "acme-addr", envStringOrDefault("NB_PROXY_ACME_ADDRESS", ":80"), "HTTP address to listen on, used for ACME HTTP-01 certificate generation.")
flag.StringVar(&acmeDir, "acme-dir", envStringOrDefault("NB_PROXY_ACME_DIRECTORY", acme.LetsEncryptURL), "URL of ACME challenge directory.")
flag.StringVar(&oidcId, "oidc-id", envStringOrDefault("NB_PROXY_OIDC_CLIENT_ID", "netbird-proxy"), "The OAuth2 Client ID for OIDC User Authentication")
flag.StringVar(&oidcSecret, "oidc-secret", envStringOrDefault("NB_PROXY_OIDC_CLIENT_SECRET", ""), "The OAuth2 Client Secret for OIDC User Authentication")
flag.StringVar(&oidcEndpoint, "oidc-endpoint", envStringOrDefault("NB_PROXY_OIDC_ENDPOINT", ""), "The OIDC Endpoint for OIDC User Authentication")
flag.StringVar(&oidcScopes, "oidc-scopes", envStringOrDefault("NB_PROXY_OIDC_SCOPES", "openid,profile,email"), "The OAuth2 scopes for OIDC User Authentication, comma separated")
flag.Parse()
if version {
fmt.Printf("Version: %s, Commit: %s, BuildDate: %s, Go: %s", Version, Commit, BuildDate, GoVersion)
os.Exit(0)
}
// Configure logrus.
level := "error"
if debug {
level = "debug"
}
logger := log.New()
_ = util.InitLogger(logger, level, util.LogConsole)
log.Infof("configured log level: %s", level)
srv := proxy.Server{
Logger: logger,
Version: Version,
ManagementAddress: mgmtAddr,
ProxyURL: url,
CertificateDirectory: certDir,
GenerateACMECertificates: acmeCerts,
ACMEChallengeAddress: acmeAddr,
ACMEDirectory: acmeDir,
OIDCClientId: oidcId,
OIDCClientSecret: oidcSecret,
OIDCEndpoint: oidcEndpoint,
OIDCScopes: strings.Split(oidcScopes, ","),
}
if err := srv.ListenAndServe(context.TODO(), addr); err != nil {
log.Fatal(err)
}
cmd.SetVersionInfo(Version, Commit, BuildDate, GoVersion)
cmd.Execute()
}

View File

@@ -42,7 +42,7 @@ func (l *Logger) Middleware(next http.Handler) http.Handler {
entry := logEntry{
ID: xid.New().String(),
ServiceId: capturedData.GetServiceId(),
AccountID: capturedData.GetAccountId(),
AccountID: string(capturedData.GetAccountId()),
Host: host,
Path: r.URL.Path,
DurationMs: duration.Milliseconds(),

View File

@@ -0,0 +1,307 @@
// Package debug provides HTTP debug endpoints and CLI client for the proxy server.
package debug
import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"strings"
"time"
)
// StatusFilters contains filter options for status queries.
type StatusFilters struct {
IPs []string
Names []string
Status string
ConnectionType string
}
// Client provides CLI access to debug endpoints.
type Client struct {
baseURL string
jsonOutput bool
httpClient *http.Client
out io.Writer
}
// NewClient creates a new debug client.
func NewClient(baseURL string, jsonOutput bool, out io.Writer) *Client {
if !strings.HasPrefix(baseURL, "http://") && !strings.HasPrefix(baseURL, "https://") {
baseURL = "http://" + baseURL
}
baseURL = strings.TrimSuffix(baseURL, "/")
return &Client{
baseURL: baseURL,
jsonOutput: jsonOutput,
out: out,
httpClient: &http.Client{
Timeout: 30 * time.Second,
},
}
}
// Health fetches the health status.
func (c *Client) Health(ctx context.Context) error {
return c.fetchAndPrint(ctx, "/debug/health", c.printHealth)
}
func (c *Client) printHealth(data map[string]any) {
_, _ = fmt.Fprintf(c.out, "Status: %v\n", data["status"])
_, _ = fmt.Fprintf(c.out, "Uptime: %v\n", data["uptime"])
}
// ListClients fetches the list of all clients.
func (c *Client) ListClients(ctx context.Context) error {
return c.fetchAndPrint(ctx, "/debug/clients", c.printClients)
}
func (c *Client) printClients(data map[string]any) {
_, _ = fmt.Fprintf(c.out, "Uptime: %v\n", data["uptime"])
_, _ = fmt.Fprintf(c.out, "Clients: %v\n\n", data["client_count"])
clients, ok := data["clients"].([]any)
if !ok || len(clients) == 0 {
_, _ = fmt.Fprintln(c.out, "No clients connected.")
return
}
_, _ = fmt.Fprintf(c.out, "%-38s %-12s %-40s %s\n", "ACCOUNT ID", "AGE", "DOMAINS", "HAS CLIENT")
_, _ = fmt.Fprintln(c.out, strings.Repeat("-", 110))
for _, item := range clients {
c.printClientRow(item)
}
}
func (c *Client) printClientRow(item any) {
client, ok := item.(map[string]any)
if !ok {
return
}
domains := c.extractDomains(client)
hasClient := "no"
if hc, ok := client["has_client"].(bool); ok && hc {
hasClient = "yes"
}
_, _ = fmt.Fprintf(c.out, "%-38s %-12v %s %s\n",
client["account_id"],
client["age"],
domains,
hasClient,
)
}
func (c *Client) extractDomains(client map[string]any) string {
d, ok := client["domains"].([]any)
if !ok || len(d) == 0 {
return "-"
}
parts := make([]string, len(d))
for i, domain := range d {
parts[i] = fmt.Sprint(domain)
}
return strings.Join(parts, ", ")
}
// ClientStatus fetches the status of a specific client.
func (c *Client) ClientStatus(ctx context.Context, accountID string, filters StatusFilters) error {
params := url.Values{}
if len(filters.IPs) > 0 {
params.Set("filter-by-ips", strings.Join(filters.IPs, ","))
}
if len(filters.Names) > 0 {
params.Set("filter-by-names", strings.Join(filters.Names, ","))
}
if filters.Status != "" {
params.Set("filter-by-status", filters.Status)
}
if filters.ConnectionType != "" {
params.Set("filter-by-connection-type", filters.ConnectionType)
}
path := "/debug/clients/" + url.PathEscape(accountID)
if len(params) > 0 {
path += "?" + params.Encode()
}
return c.fetchAndPrint(ctx, path, c.printClientStatus)
}
func (c *Client) printClientStatus(data map[string]any) {
_, _ = fmt.Fprintf(c.out, "Account: %v\n\n", data["account_id"])
if status, ok := data["status"].(string); ok {
_, _ = fmt.Fprint(c.out, status)
}
}
// ClientSyncResponse fetches the sync response of a specific client.
func (c *Client) ClientSyncResponse(ctx context.Context, accountID string) error {
path := "/debug/clients/" + url.PathEscape(accountID) + "/syncresponse"
return c.fetchAndPrintJSON(ctx, path)
}
// PingTCP performs a TCP ping through a client.
func (c *Client) PingTCP(ctx context.Context, accountID, host string, port int, timeout string) error {
params := url.Values{}
params.Set("host", host)
params.Set("port", fmt.Sprintf("%d", port))
if timeout != "" {
params.Set("timeout", timeout)
}
path := fmt.Sprintf("/debug/clients/%s/pingtcp?%s", url.PathEscape(accountID), params.Encode())
return c.fetchAndPrint(ctx, path, c.printPingResult)
}
func (c *Client) printPingResult(data map[string]any) {
success, _ := data["success"].(bool)
if success {
_, _ = fmt.Fprintf(c.out, "Success: %v:%v\n", data["host"], data["port"])
_, _ = fmt.Fprintf(c.out, "Latency: %v\n", data["latency"])
} else {
_, _ = fmt.Fprintf(c.out, "Failed: %v:%v\n", data["host"], data["port"])
c.printError(data)
}
}
// SetLogLevel sets the log level of a specific client.
func (c *Client) SetLogLevel(ctx context.Context, accountID, level string) error {
params := url.Values{}
params.Set("level", level)
path := fmt.Sprintf("/debug/clients/%s/loglevel?%s", url.PathEscape(accountID), params.Encode())
return c.fetchAndPrint(ctx, path, c.printLogLevelResult)
}
func (c *Client) printLogLevelResult(data map[string]any) {
success, _ := data["success"].(bool)
if success {
_, _ = fmt.Fprintf(c.out, "Log level set to: %v\n", data["level"])
} else {
_, _ = fmt.Fprintln(c.out, "Failed to set log level")
c.printError(data)
}
}
// StartClient starts a specific client.
func (c *Client) StartClient(ctx context.Context, accountID string) error {
path := "/debug/clients/" + url.PathEscape(accountID) + "/start"
return c.fetchAndPrint(ctx, path, c.printStartResult)
}
func (c *Client) printStartResult(data map[string]any) {
success, _ := data["success"].(bool)
if success {
_, _ = fmt.Fprintln(c.out, "Client started")
} else {
_, _ = fmt.Fprintln(c.out, "Failed to start client")
c.printError(data)
}
}
// StopClient stops a specific client.
func (c *Client) StopClient(ctx context.Context, accountID string) error {
path := "/debug/clients/" + url.PathEscape(accountID) + "/stop"
return c.fetchAndPrint(ctx, path, c.printStopResult)
}
func (c *Client) printStopResult(data map[string]any) {
success, _ := data["success"].(bool)
if success {
_, _ = fmt.Fprintln(c.out, "Client stopped")
} else {
_, _ = fmt.Fprintln(c.out, "Failed to stop client")
c.printError(data)
}
}
func (c *Client) printError(data map[string]any) {
if errMsg, ok := data["error"].(string); ok {
_, _ = fmt.Fprintf(c.out, "Error: %s\n", errMsg)
}
}
func (c *Client) fetchAndPrint(ctx context.Context, path string, printer func(map[string]any)) error {
data, raw, err := c.fetch(ctx, path)
if err != nil {
return err
}
if c.jsonOutput {
return c.writeJSON(data)
}
if data != nil {
printer(data)
return nil
}
_, _ = fmt.Fprintln(c.out, string(raw))
return nil
}
func (c *Client) fetchAndPrintJSON(ctx context.Context, path string) error {
data, raw, err := c.fetch(ctx, path)
if err != nil {
return err
}
if data != nil {
return c.writeJSON(data)
}
_, _ = fmt.Fprintln(c.out, string(raw))
return nil
}
func (c *Client) writeJSON(data map[string]any) error {
enc := json.NewEncoder(c.out)
enc.SetIndent("", " ")
return enc.Encode(data)
}
func (c *Client) fetch(ctx context.Context, path string) (map[string]any, []byte, error) {
fullURL := c.baseURL + path
if !strings.Contains(path, "format=json") {
if strings.Contains(path, "?") {
fullURL += "&format=json"
} else {
fullURL += "?format=json"
}
}
req, err := http.NewRequestWithContext(ctx, http.MethodGet, fullURL, nil)
if err != nil {
return nil, nil, fmt.Errorf("create request: %w", err)
}
resp, err := c.httpClient.Do(req)
if err != nil {
return nil, nil, fmt.Errorf("request failed: %w", err)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, nil, fmt.Errorf("read response: %w", err)
}
if resp.StatusCode >= 400 {
return nil, nil, fmt.Errorf("server error (%d): %s", resp.StatusCode, strings.TrimSpace(string(body)))
}
var data map[string]any
if err := json.Unmarshal(body, &data); err != nil {
return nil, body, nil
}
return data, body, nil
}

View File

@@ -0,0 +1,589 @@
// Package debug provides HTTP debug endpoints for the proxy server.
package debug
import (
"context"
"embed"
"encoding/json"
"fmt"
"html/template"
"net/http"
"strconv"
"strings"
"sync"
"time"
log "github.com/sirupsen/logrus"
"google.golang.org/protobuf/encoding/protojson"
nbembed "github.com/netbirdio/netbird/client/embed"
nbstatus "github.com/netbirdio/netbird/client/status"
"github.com/netbirdio/netbird/proxy/internal/roundtrip"
"github.com/netbirdio/netbird/proxy/internal/types"
"github.com/netbirdio/netbird/version"
)
//go:embed templates/*.html
var templateFS embed.FS
const defaultPingTimeout = 10 * time.Second
// formatDuration formats a duration with 2 decimal places using appropriate units.
func formatDuration(d time.Duration) string {
switch {
case d >= time.Hour:
return fmt.Sprintf("%.2fh", d.Hours())
case d >= time.Minute:
return fmt.Sprintf("%.2fm", d.Minutes())
case d >= time.Second:
return fmt.Sprintf("%.2fs", d.Seconds())
case d >= time.Millisecond:
return fmt.Sprintf("%.2fms", float64(d.Microseconds())/1000)
case d >= time.Microsecond:
return fmt.Sprintf("%.2fµs", float64(d.Nanoseconds())/1000)
default:
return fmt.Sprintf("%dns", d.Nanoseconds())
}
}
// ClientProvider provides access to NetBird clients.
type ClientProvider interface {
GetClient(accountID types.AccountID) (*nbembed.Client, bool)
ListClientsForDebug() map[types.AccountID]roundtrip.ClientDebugInfo
}
// Handler provides HTTP debug endpoints.
type Handler struct {
provider ClientProvider
logger *log.Logger
startTime time.Time
templates *template.Template
templateMu sync.RWMutex
}
// NewHandler creates a new debug handler.
func NewHandler(provider ClientProvider, logger *log.Logger) *Handler {
if logger == nil {
logger = log.StandardLogger()
}
h := &Handler{
provider: provider,
logger: logger,
startTime: time.Now(),
}
if err := h.loadTemplates(); err != nil {
logger.Errorf("failed to load embedded templates: %v", err)
}
return h
}
func (h *Handler) loadTemplates() error {
tmpl, err := template.ParseFS(templateFS, "templates/*.html")
if err != nil {
return fmt.Errorf("parse embedded templates: %w", err)
}
h.templateMu.Lock()
h.templates = tmpl
h.templateMu.Unlock()
return nil
}
func (h *Handler) getTemplates() *template.Template {
h.templateMu.RLock()
defer h.templateMu.RUnlock()
return h.templates
}
// ServeHTTP handles debug requests.
func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
path := r.URL.Path
wantJSON := r.URL.Query().Get("format") == "json" || strings.HasSuffix(path, "/json")
path = strings.TrimSuffix(path, "/json")
switch path {
case "/debug", "/debug/":
h.handleIndex(w, r, wantJSON)
case "/debug/clients":
h.handleListClients(w, r, wantJSON)
case "/debug/health":
h.handleHealth(w, r, wantJSON)
default:
if h.handleClientRoutes(w, r, path, wantJSON) {
return
}
http.NotFound(w, r)
}
}
func (h *Handler) handleClientRoutes(w http.ResponseWriter, r *http.Request, path string, wantJSON bool) bool {
if !strings.HasPrefix(path, "/debug/clients/") {
return false
}
rest := strings.TrimPrefix(path, "/debug/clients/")
parts := strings.SplitN(rest, "/", 2)
accountID := types.AccountID(parts[0])
if len(parts) == 1 {
h.handleClientStatus(w, r, accountID, wantJSON)
return true
}
switch parts[1] {
case "syncresponse":
h.handleClientSyncResponse(w, r, accountID, wantJSON)
case "tools":
h.handleClientTools(w, r, accountID)
case "pingtcp":
h.handlePingTCP(w, r, accountID)
case "loglevel":
h.handleLogLevel(w, r, accountID)
case "start":
h.handleClientStart(w, r, accountID)
case "stop":
h.handleClientStop(w, r, accountID)
default:
return false
}
return true
}
type indexData struct {
Version string
Uptime string
ClientCount int
TotalDomains int
Clients []clientData
}
type clientData struct {
AccountID string
Domains string
Age string
Status string
}
func (h *Handler) handleIndex(w http.ResponseWriter, _ *http.Request, wantJSON bool) {
clients := h.provider.ListClientsForDebug()
totalDomains := 0
for _, info := range clients {
totalDomains += info.DomainCount
}
if wantJSON {
clientsJSON := make([]map[string]interface{}, 0, len(clients))
for _, info := range clients {
clientsJSON = append(clientsJSON, map[string]interface{}{
"account_id": info.AccountID,
"domain_count": info.DomainCount,
"domains": info.Domains,
"has_client": info.HasClient,
"created_at": info.CreatedAt,
"age": time.Since(info.CreatedAt).Round(time.Second).String(),
})
}
h.writeJSON(w, map[string]interface{}{
"version": version.NetbirdVersion(),
"uptime": time.Since(h.startTime).Round(time.Second).String(),
"client_count": len(clients),
"total_domains": totalDomains,
"clients": clientsJSON,
})
return
}
data := indexData{
Version: version.NetbirdVersion(),
Uptime: time.Since(h.startTime).Round(time.Second).String(),
ClientCount: len(clients),
TotalDomains: totalDomains,
Clients: make([]clientData, 0, len(clients)),
}
for _, info := range clients {
domains := info.Domains.SafeString()
if domains == "" {
domains = "-"
}
status := "No client"
if info.HasClient {
status = "Active"
}
data.Clients = append(data.Clients, clientData{
AccountID: string(info.AccountID),
Domains: domains,
Age: time.Since(info.CreatedAt).Round(time.Second).String(),
Status: status,
})
}
h.renderTemplate(w, "index", data)
}
type clientsData struct {
Uptime string
Clients []clientData
}
func (h *Handler) handleListClients(w http.ResponseWriter, _ *http.Request, wantJSON bool) {
clients := h.provider.ListClientsForDebug()
if wantJSON {
clientsJSON := make([]map[string]interface{}, 0, len(clients))
for _, info := range clients {
clientsJSON = append(clientsJSON, map[string]interface{}{
"account_id": info.AccountID,
"domain_count": info.DomainCount,
"domains": info.Domains,
"has_client": info.HasClient,
"created_at": info.CreatedAt,
"age": time.Since(info.CreatedAt).Round(time.Second).String(),
})
}
h.writeJSON(w, map[string]interface{}{
"uptime": time.Since(h.startTime).Round(time.Second).String(),
"client_count": len(clients),
"clients": clientsJSON,
})
return
}
data := clientsData{
Uptime: time.Since(h.startTime).Round(time.Second).String(),
Clients: make([]clientData, 0, len(clients)),
}
for _, info := range clients {
domains := info.Domains.SafeString()
if domains == "" {
domains = "-"
}
status := "No client"
if info.HasClient {
status = "Active"
}
data.Clients = append(data.Clients, clientData{
AccountID: string(info.AccountID),
Domains: domains,
Age: time.Since(info.CreatedAt).Round(time.Second).String(),
Status: status,
})
}
h.renderTemplate(w, "clients", data)
}
type clientDetailData struct {
AccountID string
ActiveTab string
Content string
}
func (h *Handler) handleClientStatus(w http.ResponseWriter, r *http.Request, accountID types.AccountID, wantJSON bool) {
client, ok := h.provider.GetClient(accountID)
if !ok {
http.Error(w, "Client not found: "+string(accountID), http.StatusNotFound)
return
}
fullStatus, err := client.Status()
if err != nil {
http.Error(w, "Error getting status: "+err.Error(), http.StatusInternalServerError)
return
}
// Parse filter parameters
query := r.URL.Query()
statusFilter := query.Get("filter-by-status")
connectionTypeFilter := query.Get("filter-by-connection-type")
var prefixNamesFilter []string
var prefixNamesFilterMap map[string]struct{}
if names := query.Get("filter-by-names"); names != "" {
prefixNamesFilter = strings.Split(names, ",")
prefixNamesFilterMap = make(map[string]struct{})
for _, name := range prefixNamesFilter {
prefixNamesFilterMap[strings.ToLower(strings.TrimSpace(name))] = struct{}{}
}
}
var ipsFilterMap map[string]struct{}
if ips := query.Get("filter-by-ips"); ips != "" {
ipsFilterMap = make(map[string]struct{})
for _, ip := range strings.Split(ips, ",") {
ipsFilterMap[strings.TrimSpace(ip)] = struct{}{}
}
}
pbStatus := nbstatus.ToProtoFullStatus(fullStatus)
overview := nbstatus.ConvertToStatusOutputOverview(
pbStatus,
false,
version.NetbirdVersion(),
statusFilter,
prefixNamesFilter,
prefixNamesFilterMap,
ipsFilterMap,
connectionTypeFilter,
"",
)
if wantJSON {
h.writeJSON(w, map[string]interface{}{
"account_id": accountID,
"status": overview.FullDetailSummary(),
})
return
}
data := clientDetailData{
AccountID: string(accountID),
ActiveTab: "status",
Content: overview.FullDetailSummary(),
}
h.renderTemplate(w, "clientDetail", data)
}
func (h *Handler) handleClientSyncResponse(w http.ResponseWriter, _ *http.Request, accountID types.AccountID, wantJSON bool) {
client, ok := h.provider.GetClient(accountID)
if !ok {
http.Error(w, "Client not found: "+string(accountID), http.StatusNotFound)
return
}
syncResp, err := client.GetLatestSyncResponse()
if err != nil {
http.Error(w, "Error getting sync response: "+err.Error(), http.StatusInternalServerError)
return
}
if syncResp == nil {
http.Error(w, "No sync response available for client: "+string(accountID), http.StatusNotFound)
return
}
opts := protojson.MarshalOptions{
EmitUnpopulated: true,
UseProtoNames: true,
Indent: " ",
AllowPartial: true,
}
jsonBytes, err := opts.Marshal(syncResp)
if err != nil {
http.Error(w, "Error marshaling sync response: "+err.Error(), http.StatusInternalServerError)
return
}
if wantJSON {
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write(jsonBytes)
return
}
data := clientDetailData{
AccountID: string(accountID),
ActiveTab: "syncresponse",
Content: string(jsonBytes),
}
h.renderTemplate(w, "clientDetail", data)
}
type toolsData struct {
AccountID string
}
func (h *Handler) handleClientTools(w http.ResponseWriter, _ *http.Request, accountID types.AccountID) {
_, ok := h.provider.GetClient(accountID)
if !ok {
http.Error(w, "Client not found: "+string(accountID), http.StatusNotFound)
return
}
data := toolsData{
AccountID: string(accountID),
}
h.renderTemplate(w, "tools", data)
}
func (h *Handler) handlePingTCP(w http.ResponseWriter, r *http.Request, accountID types.AccountID) {
client, ok := h.provider.GetClient(accountID)
if !ok {
h.writeJSON(w, map[string]interface{}{"error": "client not found"})
return
}
host := r.URL.Query().Get("host")
portStr := r.URL.Query().Get("port")
if host == "" || portStr == "" {
h.writeJSON(w, map[string]interface{}{"error": "host and port parameters required"})
return
}
port, err := strconv.Atoi(portStr)
if err != nil || port < 1 || port > 65535 {
h.writeJSON(w, map[string]interface{}{"error": "invalid port"})
return
}
timeout := defaultPingTimeout
if t := r.URL.Query().Get("timeout"); t != "" {
if d, err := time.ParseDuration(t); err == nil {
timeout = d
}
}
ctx, cancel := context.WithTimeout(r.Context(), timeout)
defer cancel()
address := fmt.Sprintf("%s:%d", host, port)
start := time.Now()
conn, err := client.Dial(ctx, "tcp", address)
if err != nil {
h.writeJSON(w, map[string]interface{}{
"success": false,
"host": host,
"port": port,
"error": err.Error(),
})
return
}
if err := conn.Close(); err != nil {
h.logger.Debugf("close tcp ping connection: %v", err)
}
latency := time.Since(start)
h.writeJSON(w, map[string]interface{}{
"success": true,
"host": host,
"port": port,
"latency_ms": latency.Milliseconds(),
"latency": formatDuration(latency),
})
}
func (h *Handler) handleLogLevel(w http.ResponseWriter, r *http.Request, accountID types.AccountID) {
client, ok := h.provider.GetClient(accountID)
if !ok {
h.writeJSON(w, map[string]interface{}{"error": "client not found"})
return
}
level := r.URL.Query().Get("level")
if level == "" {
h.writeJSON(w, map[string]interface{}{"error": "level parameter required (trace, debug, info, warn, error)"})
return
}
if err := client.SetLogLevel(level); err != nil {
h.writeJSON(w, map[string]interface{}{
"success": false,
"error": err.Error(),
})
return
}
h.writeJSON(w, map[string]interface{}{
"success": true,
"level": level,
})
}
const clientActionTimeout = 30 * time.Second
func (h *Handler) handleClientStart(w http.ResponseWriter, r *http.Request, accountID types.AccountID) {
client, ok := h.provider.GetClient(accountID)
if !ok {
h.writeJSON(w, map[string]interface{}{"error": "client not found"})
return
}
ctx, cancel := context.WithTimeout(r.Context(), clientActionTimeout)
defer cancel()
if err := client.Start(ctx); err != nil {
h.writeJSON(w, map[string]interface{}{
"success": false,
"error": err.Error(),
})
return
}
h.writeJSON(w, map[string]interface{}{
"success": true,
"message": "client started",
})
}
func (h *Handler) handleClientStop(w http.ResponseWriter, r *http.Request, accountID types.AccountID) {
client, ok := h.provider.GetClient(accountID)
if !ok {
h.writeJSON(w, map[string]interface{}{"error": "client not found"})
return
}
ctx, cancel := context.WithTimeout(r.Context(), clientActionTimeout)
defer cancel()
if err := client.Stop(ctx); err != nil {
h.writeJSON(w, map[string]interface{}{
"success": false,
"error": err.Error(),
})
return
}
h.writeJSON(w, map[string]interface{}{
"success": true,
"message": "client stopped",
})
}
type healthData struct {
Uptime string
}
func (h *Handler) handleHealth(w http.ResponseWriter, _ *http.Request, wantJSON bool) {
if wantJSON {
h.writeJSON(w, map[string]interface{}{
"status": "ok",
"uptime": time.Since(h.startTime).Round(10 * time.Millisecond).String(),
})
return
}
data := healthData{
Uptime: time.Since(h.startTime).Round(time.Second).String(),
}
h.renderTemplate(w, "health", data)
}
func (h *Handler) renderTemplate(w http.ResponseWriter, name string, data interface{}) {
w.Header().Set("Content-Type", "text/html; charset=utf-8")
tmpl := h.getTemplates()
if tmpl == nil {
http.Error(w, "Templates not loaded", http.StatusInternalServerError)
return
}
if err := tmpl.ExecuteTemplate(w, name, data); err != nil {
h.logger.Errorf("execute template %s: %v", name, err)
http.Error(w, "Template error", http.StatusInternalServerError)
}
}
func (h *Handler) writeJSON(w http.ResponseWriter, v interface{}) {
w.Header().Set("Content-Type", "application/json")
enc := json.NewEncoder(w)
enc.SetIndent("", " ")
if err := enc.Encode(v); err != nil {
h.logger.Errorf("encode JSON response: %v", err)
}
}

View File

@@ -0,0 +1,101 @@
{{define "style"}}
body {
font-family: monospace;
margin: 20px;
background: #1a1a1a;
color: #eee;
}
a {
color: #6cf;
}
h1, h2, h3 {
color: #fff;
}
.info {
color: #aaa;
}
table {
border-collapse: collapse;
margin: 10px 0;
}
th, td {
border: 1px solid #444;
padding: 8px;
text-align: left;
}
th {
background: #333;
}
.nav {
margin-bottom: 20px;
}
.nav a {
margin-right: 15px;
padding: 8px 16px;
background: #333;
text-decoration: none;
border-radius: 4px;
}
.nav a.active {
background: #6cf;
color: #000;
}
pre {
background: #222;
padding: 15px;
border-radius: 4px;
overflow-x: auto;
white-space: pre-wrap;
}
input, select, textarea {
background: #333;
color: #eee;
border: 1px solid #555;
padding: 8px;
border-radius: 4px;
font-family: monospace;
}
input:focus, select:focus, textarea:focus {
outline: none;
border-color: #6cf;
}
button {
background: #6cf;
color: #000;
border: none;
padding: 8px 16px;
border-radius: 4px;
cursor: pointer;
font-family: monospace;
}
button:hover {
background: #5be;
}
button:disabled {
background: #555;
color: #888;
cursor: not-allowed;
}
.form-group {
margin-bottom: 15px;
}
.form-group label {
display: block;
margin-bottom: 5px;
color: #aaa;
}
.form-row {
display: flex;
gap: 10px;
align-items: flex-end;
}
.result {
margin-top: 20px;
}
.success {
color: #5f5;
}
.error {
color: #f55;
}
{{end}}

View File

@@ -0,0 +1,19 @@
{{define "clientDetail"}}
<!DOCTYPE html>
<html>
<head>
<title>Client {{.AccountID}}</title>
<style>{{template "style"}}</style>
</head>
<body>
<h1>Client: {{.AccountID}}</h1>
<div class="nav">
<a href="/debug">&larr; Back</a>
<a href="/debug/clients/{{.AccountID}}/tools"{{if eq .ActiveTab "tools"}} class="active"{{end}}>Tools</a>
<a href="/debug/clients/{{.AccountID}}"{{if eq .ActiveTab "status"}} class="active"{{end}}>Status</a>
<a href="/debug/clients/{{.AccountID}}/syncresponse"{{if eq .ActiveTab "syncresponse"}} class="active"{{end}}>Sync Response</a>
</div>
<pre>{{.Content}}</pre>
</body>
</html>
{{end}}

View File

@@ -0,0 +1,33 @@
{{define "clients"}}
<!DOCTYPE html>
<html>
<head>
<title>Clients</title>
<style>{{template "style"}}</style>
</head>
<body>
<h1>All Clients</h1>
<p class="info">Uptime: {{.Uptime}} | <a href="/debug">&larr; Back</a></p>
{{if .Clients}}
<table>
<tr>
<th>Account ID</th>
<th>Domains</th>
<th>Age</th>
<th>Status</th>
</tr>
{{range .Clients}}
<tr>
<td><a href="/debug/clients/{{.AccountID}}/tools">{{.AccountID}}</a></td>
<td>{{.Domains}}</td>
<td>{{.Age}}</td>
<td>{{.Status}}</td>
</tr>
{{end}}
</table>
{{else}}
<p>No clients connected</p>
{{end}}
</body>
</html>
{{end}}

View File

@@ -0,0 +1,14 @@
{{define "health"}}
<!DOCTYPE html>
<html>
<head>
<title>Health</title>
<style>{{template "style"}}</style>
</head>
<body>
<h1>OK</h1>
<p>Uptime: {{.Uptime}}</p>
<p><a href="/debug">&larr; Back</a></p>
</body>
</html>
{{end}}

View File

@@ -0,0 +1,40 @@
{{define "index"}}
<!DOCTYPE html>
<html>
<head>
<title>NetBird Proxy Debug</title>
<style>{{template "style"}}</style>
</head>
<body>
<h1>NetBird Proxy Debug</h1>
<p class="info">Version: {{.Version}} | Uptime: {{.Uptime}}</p>
<h2>Clients ({{.ClientCount}}) | Domains ({{.TotalDomains}})</h2>
{{if .Clients}}
<table>
<tr>
<th>Account ID</th>
<th>Domains</th>
<th>Age</th>
<th>Status</th>
</tr>
{{range .Clients}}
<tr>
<td><a href="/debug/clients/{{.AccountID}}/tools">{{.AccountID}}</a></td>
<td>{{.Domains}}</td>
<td>{{.Age}}</td>
<td>{{.Status}}</td>
</tr>
{{end}}
</table>
{{else}}
<p>No clients connected</p>
{{end}}
<h2>Endpoints</h2>
<ul>
<li><a href="/debug/clients">/debug/clients</a> - all clients detail</li>
<li><a href="/debug/health">/debug/health</a> - health check</li>
</ul>
<p class="info">Add ?format=json or /json suffix for JSON output</p>
</body>
</html>
{{end}}

View File

@@ -0,0 +1,142 @@
{{define "tools"}}
<!DOCTYPE html>
<html>
<head>
<title>Client {{.AccountID}} - Tools</title>
<style>{{template "style"}}</style>
</head>
<body>
<h1>Client: {{.AccountID}}</h1>
<div class="nav">
<a href="/debug">&larr; Back</a>
<a href="/debug/clients/{{.AccountID}}/tools" class="active">Tools</a>
<a href="/debug/clients/{{.AccountID}}">Status</a>
<a href="/debug/clients/{{.AccountID}}/syncresponse">Sync Response</a>
</div>
<h2>Client Control</h2>
<div class="form-row">
<div class="form-group">
<label>&nbsp;</label>
<button onclick="startClient()">Start</button>
</div>
<div class="form-group">
<label>&nbsp;</label>
<button onclick="stopClient()">Stop</button>
</div>
</div>
<div id="client-result" class="result"></div>
<h2>Log Level</h2>
<div class="form-row">
<div class="form-group">
<label>Level</label>
<select id="log-level" style="width: 120px;">
<option value="trace">trace</option>
<option value="debug">debug</option>
<option value="info">info</option>
<option value="warn" selected>warn</option>
<option value="error">error</option>
</select>
</div>
<div class="form-group">
<label>&nbsp;</label>
<button onclick="setLogLevel()">Set Level</button>
</div>
</div>
<div id="log-result" class="result"></div>
<h2>TCP Ping</h2>
<div class="form-row">
<div class="form-group">
<label>Host</label>
<input type="text" id="tcp-host" placeholder="100.0.0.1 or hostname.netbird.cloud" style="width: 300px;">
</div>
<div class="form-group">
<label>Port</label>
<input type="number" id="tcp-port" placeholder="80" style="width: 80px;">
</div>
<div class="form-group">
<label>&nbsp;</label>
<button onclick="doTcpPing()">Connect</button>
</div>
</div>
<div id="tcp-result" class="result"></div>
<script>
const accountID = "{{.AccountID}}";
async function startClient() {
const resultDiv = document.getElementById('client-result');
resultDiv.innerHTML = '<span class="info">Starting client...</span>';
try {
const resp = await fetch('/debug/clients/' + accountID + '/start');
const data = await resp.json();
if (data.success) {
resultDiv.innerHTML = '<span class="success">✓ ' + data.message + '</span>';
} else {
resultDiv.innerHTML = '<span class="error">✗ ' + data.error + '</span>';
}
} catch (e) {
resultDiv.innerHTML = '<span class="error">Error: ' + e.message + '</span>';
}
}
async function stopClient() {
const resultDiv = document.getElementById('client-result');
resultDiv.innerHTML = '<span class="info">Stopping client...</span>';
try {
const resp = await fetch('/debug/clients/' + accountID + '/stop');
const data = await resp.json();
if (data.success) {
resultDiv.innerHTML = '<span class="success">✓ ' + data.message + '</span>';
} else {
resultDiv.innerHTML = '<span class="error">✗ ' + data.error + '</span>';
}
} catch (e) {
resultDiv.innerHTML = '<span class="error">Error: ' + e.message + '</span>';
}
}
async function setLogLevel() {
const level = document.getElementById('log-level').value;
const resultDiv = document.getElementById('log-result');
resultDiv.innerHTML = '<span class="info">Setting log level...</span>';
try {
const resp = await fetch('/debug/clients/' + accountID + '/loglevel?level=' + level);
const data = await resp.json();
if (data.success) {
resultDiv.innerHTML = '<span class="success">✓ Log level set to: ' + data.level + '</span>';
} else {
resultDiv.innerHTML = '<span class="error">✗ ' + data.error + '</span>';
}
} catch (e) {
resultDiv.innerHTML = '<span class="error">Error: ' + e.message + '</span>';
}
}
async function doTcpPing() {
const host = document.getElementById('tcp-host').value;
const port = document.getElementById('tcp-port').value;
if (!host || !port) {
alert('Host and port required');
return;
}
const resultDiv = document.getElementById('tcp-result');
resultDiv.innerHTML = '<span class="info">Connecting...</span>';
try {
const resp = await fetch('/debug/clients/' + accountID + '/pingtcp?host=' + encodeURIComponent(host) + '&port=' + port);
const data = await resp.json();
if (data.success) {
resultDiv.innerHTML = '<span class="success">✓ ' + data.host + ':' + data.port + ' connected in ' + data.latency + '</span>';
} else {
resultDiv.innerHTML = '<span class="error">✗ ' + data.host + ':' + data.port + ': ' + data.error + '</span>';
}
} catch (e) {
resultDiv.innerHTML = '<span class="error">Error: ' + e.message + '</span>';
}
}
</script>
</body>
</html>
{{end}}

View File

@@ -3,6 +3,8 @@ package proxy
import (
"context"
"sync"
"github.com/netbirdio/netbird/proxy/internal/types"
)
type requestContextKey string
@@ -18,7 +20,7 @@ const (
type CapturedData struct {
mu sync.RWMutex
ServiceId string
AccountId string
AccountId types.AccountID
}
// SetServiceId safely sets the service ID
@@ -36,14 +38,14 @@ func (c *CapturedData) GetServiceId() string {
}
// SetAccountId safely sets the account ID
func (c *CapturedData) SetAccountId(accountId string) {
func (c *CapturedData) SetAccountId(accountId types.AccountID) {
c.mu.Lock()
defer c.mu.Unlock()
c.AccountId = accountId
}
// GetAccountId safely gets the account ID
func (c *CapturedData) GetAccountId() string {
func (c *CapturedData) GetAccountId() types.AccountID {
c.mu.RLock()
defer c.mu.RUnlock()
return c.AccountId
@@ -76,13 +78,13 @@ func ServiceIdFromContext(ctx context.Context) string {
}
return serviceId
}
func withAccountId(ctx context.Context, accountId string) context.Context {
func withAccountId(ctx context.Context, accountId types.AccountID) context.Context {
return context.WithValue(ctx, accountIdKey, accountId)
}
func AccountIdFromContext(ctx context.Context) string {
func AccountIdFromContext(ctx context.Context) types.AccountID {
v := ctx.Value(accountIdKey)
accountId, ok := v.(string)
accountId, ok := v.(types.AccountID)
if !ok {
return ""
}

View File

@@ -4,6 +4,8 @@ import (
"net/http"
"net/http/httputil"
"sync"
"github.com/netbirdio/netbird/proxy/internal/roundtrip"
)
type ReverseProxy struct {
@@ -36,8 +38,10 @@ func (p *ReverseProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// Set the serviceId in the context for later retrieval.
ctx := withServiceId(r.Context(), serviceId)
// Set the accountId in the context for later retrieval.
// Set the accountId in the context for later retrieval (for middleware).
ctx = withAccountId(ctx, accountID)
// Set the accountId in the context for the roundtripper to use.
ctx = roundtrip.WithAccountID(ctx, accountID)
// Also populate captured data if it exists (allows middleware to read after handler completes).
// This solves the problem of passing data UP the middleware chain: we put a mutable struct

View File

@@ -6,16 +6,18 @@ import (
"net/url"
"sort"
"strings"
"github.com/netbirdio/netbird/proxy/internal/types"
)
type Mapping struct {
ID string
AccountID string
AccountID types.AccountID
Host string
Paths map[string]*url.URL
}
func (p *ReverseProxy) findTargetForRequest(req *http.Request) (*url.URL, string, string, bool) {
func (p *ReverseProxy) findTargetForRequest(req *http.Request) (*url.URL, string, types.AccountID, bool) {
p.mappingsMux.RLock()
if p.mappings == nil {
p.mappingsMux.RUnlock()
@@ -27,10 +29,12 @@ func (p *ReverseProxy) findTargetForRequest(req *http.Request) (*url.URL, string
}
defer p.mappingsMux.RUnlock()
host, _, err := net.SplitHostPort(req.Host)
if err != nil {
host = req.Host
// Strip port from host if present (e.g., "external.test:8443" -> "external.test")
host := req.Host
if h, _, err := net.SplitHostPort(host); err == nil {
host = h
}
m, exists := p.mappings[host]
if !exists {
return nil, "", "", false

View File

@@ -4,18 +4,39 @@ import (
"context"
"errors"
"fmt"
"io"
"net"
"net/http"
"sync"
"time"
"github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus"
"golang.org/x/exp/maps"
nberrors "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/client/embed"
"github.com/netbirdio/netbird/proxy/internal/types"
"github.com/netbirdio/netbird/shared/management/domain"
"github.com/netbirdio/netbird/util"
)
const deviceNamePrefix = "ingress-"
const deviceNamePrefix = "ingress-proxy-"
// ErrNoAccountID is returned when a request context is missing the account ID.
var ErrNoAccountID = errors.New("no account ID in request context")
// domainInfo holds metadata about a registered domain.
type domainInfo struct {
reverseProxyID string
}
// clientEntry holds an embedded NetBird client and tracks which domains use it.
type clientEntry struct {
client *embed.Client
transport *http.Transport
domains map[domain.Domain]domainInfo
createdAt time.Time
started bool
}
type statusNotifier interface {
NotifyStatus(ctx context.Context, accountID, reverseProxyID, domain string, connected bool) error
@@ -23,147 +44,389 @@ type statusNotifier interface {
// NetBird provides an http.RoundTripper implementation
// backed by underlying NetBird connections.
// Clients are keyed by AccountID, allowing multiple domains to share the same connection.
type NetBird struct {
mgmtAddr string
proxyID string
logger *log.Logger
clientsMux sync.RWMutex
clients map[string]*embed.Client
clientsMux sync.RWMutex
clients map[types.AccountID]*clientEntry
initLogOnce sync.Once
statusNotifier statusNotifier
}
func NewNetBird(mgmtAddr string, logger *log.Logger, notifier statusNotifier) *NetBird {
// NewNetBird creates a new NetBird transport.
func NewNetBird(mgmtAddr, proxyID string, logger *log.Logger, notifier statusNotifier) *NetBird {
if logger == nil {
logger = log.StandardLogger()
}
return &NetBird{
mgmtAddr: mgmtAddr,
proxyID: proxyID,
logger: logger,
clients: make(map[string]*embed.Client),
clients: make(map[types.AccountID]*clientEntry),
statusNotifier: notifier,
}
}
func (n *NetBird) AddPeer(ctx context.Context, domain, key, accountID, reverseProxyID string) error {
// AddPeer registers a domain for an account. If the account doesn't have a client yet,
// one is created using the provided setup key. Multiple domains can share the same client.
func (n *NetBird) AddPeer(ctx context.Context, accountID types.AccountID, d domain.Domain, key, reverseProxyID string) error {
n.clientsMux.Lock()
entry, exists := n.clients[accountID]
if exists {
// Client already exists for this account, just register the domain
entry.domains[d] = domainInfo{reverseProxyID: reverseProxyID}
started := entry.started
n.clientsMux.Unlock()
n.logger.WithFields(log.Fields{
"account_id": accountID,
"domain": d,
}).Debug("registered domain with existing client")
// If client is already started, notify this domain as connected immediately
if started && n.statusNotifier != nil {
if err := n.statusNotifier.NotifyStatus(ctx, string(accountID), reverseProxyID, string(d), true); err != nil {
n.logger.WithFields(log.Fields{
"account_id": accountID,
"domain": d,
}).WithError(err).Warn("failed to notify status for existing client")
}
}
return nil
}
n.initLogOnce.Do(func() {
if err := util.InitLog(log.WarnLevel.String(), util.LogConsole); err != nil {
n.logger.WithField("account_id", accountID).Warnf("failed to initialize embedded client logging: %v", err)
}
})
wgPort := 0
client, err := embed.New(embed.Options{
DeviceName: deviceNamePrefix + domain,
DeviceName: deviceNamePrefix + n.proxyID,
ManagementURL: n.mgmtAddr,
SetupKey: key,
LogOutput: io.Discard,
LogLevel: log.WarnLevel.String(),
BlockInbound: true,
WireguardPort: &wgPort,
})
if err != nil {
n.clientsMux.Unlock()
return fmt.Errorf("create netbird client: %w", err)
}
// Create a transport using the client dialer. We do this instead of using
// the client's HTTPClient to avoid issues with request validation that do
// not work with reverse proxied requests.
entry = &clientEntry{
client: client,
domains: map[domain.Domain]domainInfo{d: {reverseProxyID: reverseProxyID}},
transport: &http.Transport{
DialContext: client.DialContext,
ForceAttemptHTTP2: true,
MaxIdleConns: 100,
IdleConnTimeout: 90 * time.Second,
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
},
createdAt: time.Now(),
started: false,
}
n.clients[accountID] = entry
n.clientsMux.Unlock()
n.logger.WithFields(log.Fields{
"account_id": accountID,
"domain": d,
}).Info("created new client for account")
// Attempt to start the client in the background, if this fails
// then it is not ideal, but it isn't the end of the world because
// we will try to start the client again before we use it.
go func() {
startCtx, cancel := context.WithTimeout(ctx, 3*time.Second)
startCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
err = client.Start(startCtx)
switch {
case errors.Is(err, context.DeadlineExceeded):
n.logger.Debug("netbird client timed out")
// This is not ideal, but we will try again later.
return
case err != nil:
n.logger.WithField("domain", domain).WithError(err).Error("Unable to start netbird client, will try again later.")
if err := client.Start(startCtx); err != nil {
if errors.Is(err, context.DeadlineExceeded) {
n.logger.WithFields(log.Fields{
"account_id": accountID,
}).Debug("netbird client start timed out, will retry on first request")
} else {
n.logger.WithFields(log.Fields{
"account_id": accountID,
}).WithError(err).Error("failed to start netbird client")
}
return
}
// Notify management that tunnel is now active
// Mark client as started and notify all registered domains
n.clientsMux.Lock()
entry, exists := n.clients[accountID]
if exists {
entry.started = true
}
// Copy domain info while holding lock
var domainsToNotify []struct {
domain domain.Domain
reverseProxyID string
}
if exists {
for dom, info := range entry.domains {
domainsToNotify = append(domainsToNotify, struct {
domain domain.Domain
reverseProxyID string
}{domain: dom, reverseProxyID: info.reverseProxyID})
}
}
n.clientsMux.Unlock()
// Notify all domains that they're connected
if n.statusNotifier != nil {
if err := n.statusNotifier.NotifyStatus(ctx, accountID, reverseProxyID, domain, true); err != nil {
n.logger.WithField("domain", domain).WithError(err).Warn("Failed to notify management about tunnel connection")
} else {
n.logger.WithField("domain", domain).Info("Successfully notified management about tunnel connection")
for _, domInfo := range domainsToNotify {
if err := n.statusNotifier.NotifyStatus(ctx, string(accountID), domInfo.reverseProxyID, string(domInfo.domain), true); err != nil {
n.logger.WithFields(log.Fields{
"account_id": accountID,
"domain": domInfo.domain,
}).WithError(err).Warn("failed to notify tunnel connection status")
} else {
n.logger.WithFields(log.Fields{
"account_id": accountID,
"domain": domInfo.domain,
}).Info("notified management about tunnel connection")
}
}
}
}()
n.clientsMux.Lock()
defer n.clientsMux.Unlock()
n.clients[domain] = client
return nil
}
func (n *NetBird) RemovePeer(ctx context.Context, domain, accountID, reverseProxyID string) error {
n.clientsMux.RLock()
client, exists := n.clients[domain]
n.clientsMux.RUnlock()
// RemovePeer unregisters a domain from an account. The client is only stopped
// when no domains are using it anymore.
func (n *NetBird) RemovePeer(ctx context.Context, accountID types.AccountID, d domain.Domain) error {
n.clientsMux.Lock()
entry, exists := n.clients[accountID]
if !exists {
// Mission failed successfully!
n.clientsMux.Unlock()
return nil
}
if err := client.Stop(ctx); err != nil {
return fmt.Errorf("stop netbird client: %w", err)
// Get domain info before deleting
domInfo, domainExists := entry.domains[d]
if !domainExists {
n.clientsMux.Unlock()
return nil
}
// Notify management that tunnel is disconnected
delete(entry.domains, d)
// If there are still domains using this client, keep it running
if len(entry.domains) > 0 {
n.clientsMux.Unlock()
n.logger.WithFields(log.Fields{
"account_id": accountID,
"domain": d,
"remaining_domains": len(entry.domains),
}).Debug("unregistered domain, client still in use")
// Notify this domain as disconnected
if n.statusNotifier != nil {
if err := n.statusNotifier.NotifyStatus(ctx, string(accountID), domInfo.reverseProxyID, string(d), false); err != nil {
n.logger.WithFields(log.Fields{
"account_id": accountID,
"domain": d,
}).WithError(err).Warn("failed to notify tunnel disconnection status")
}
}
return nil
}
// No more domains using this client, stop it
n.logger.WithFields(log.Fields{
"account_id": accountID,
}).Info("stopping client, no more domains")
client := entry.client
transport := entry.transport
delete(n.clients, accountID)
n.clientsMux.Unlock()
// Notify disconnection before stopping
if n.statusNotifier != nil {
if err := n.statusNotifier.NotifyStatus(ctx, accountID, reverseProxyID, domain, false); err != nil {
n.logger.WithField("domain", domain).WithError(err).Warn("Failed to notify management about tunnel disconnection")
} else {
n.logger.WithField("domain", domain).Info("Successfully notified management about tunnel disconnection")
if err := n.statusNotifier.NotifyStatus(ctx, string(accountID), domInfo.reverseProxyID, string(d), false); err != nil {
n.logger.WithFields(log.Fields{
"account_id": accountID,
"domain": d,
}).WithError(err).Warn("failed to notify tunnel disconnection status")
}
}
n.clientsMux.Lock()
defer n.clientsMux.Unlock()
delete(n.clients, domain)
transport.CloseIdleConnections()
if err := client.Stop(ctx); err != nil {
n.logger.WithFields(log.Fields{
"account_id": accountID,
}).WithError(err).Warn("failed to stop netbird client")
}
return nil
}
// RoundTrip implements http.RoundTripper. It looks up the client for the account
// specified in the request context and uses it to dial the backend.
func (n *NetBird) RoundTrip(req *http.Request) (*http.Response, error) {
host, _, err := net.SplitHostPort(req.Host)
if err != nil {
host = req.Host
accountID := AccountIDFromContext(req.Context())
if accountID == "" {
return nil, ErrNoAccountID
}
// Copy references while holding lock, then unlock early to avoid blocking
// other requests during the potentially slow RoundTrip.
n.clientsMux.RLock()
client, exists := n.clients[host]
// Immediately unlock after retrieval here rather than defer to avoid
// the call to client.Do blocking other clients being used whilst one
// is in use.
n.clientsMux.RUnlock()
entry, exists := n.clients[accountID]
if !exists {
return nil, fmt.Errorf("no peer connection found for host: %s", host)
n.clientsMux.RUnlock()
return nil, fmt.Errorf("no peer connection found for account: %s", accountID)
}
client := entry.client
transport := entry.transport
n.clientsMux.RUnlock()
// Attempt to start the client, if the client is already running then
// it will return an error that we ignore, if this hits a timeout then
// this request is unprocessable.
startCtx, cancel := context.WithTimeout(req.Context(), 3*time.Second)
startCtx, cancel := context.WithTimeout(req.Context(), 10*time.Second)
defer cancel()
err = client.Start(startCtx)
switch {
case errors.Is(err, embed.ErrClientAlreadyStarted):
break
case err != nil:
return nil, fmt.Errorf("start netbird client: %w", err)
if err := client.Start(startCtx); err != nil {
if !errors.Is(err, embed.ErrClientAlreadyStarted) {
return nil, fmt.Errorf("start netbird client: %w", err)
}
}
n.logger.WithFields(log.Fields{
"host": host,
"account_id": accountID,
"host": req.Host,
"url": req.URL.String(),
"requestURI": req.RequestURI,
"method": req.Method,
}).Debug("running roundtrip for peer connection")
// Create a new transport using the client dialer and perform the roundtrip.
// We do this instead of using the client HTTPClient to avoid issues around
// client request validation that do not work with the reverse proxied
// requests.
// Other values are simply copied from the http.DefaultTransport which the
// standard reverse proxy implementation would have used.
// TODO: tune this transport for our needs.
return (&http.Transport{
DialContext: client.DialContext,
MaxIdleConns: 100,
IdleConnTimeout: 90 * time.Second,
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
}).RoundTrip(req)
return transport.RoundTrip(req)
}
// StopAll stops all clients.
func (n *NetBird) StopAll(ctx context.Context) error {
n.clientsMux.Lock()
defer n.clientsMux.Unlock()
var merr *multierror.Error
for accountID, entry := range n.clients {
entry.transport.CloseIdleConnections()
if err := entry.client.Stop(ctx); err != nil {
n.logger.WithFields(log.Fields{
"account_id": accountID,
}).WithError(err).Warn("failed to stop netbird client during shutdown")
merr = multierror.Append(merr, err)
}
}
maps.Clear(n.clients)
return nberrors.FormatErrorOrNil(merr)
}
// HasClient returns true if there is a client for the given account.
func (n *NetBird) HasClient(accountID types.AccountID) bool {
n.clientsMux.RLock()
defer n.clientsMux.RUnlock()
_, exists := n.clients[accountID]
return exists
}
// DomainCount returns the number of domains registered for the given account.
// Returns 0 if the account has no client.
func (n *NetBird) DomainCount(accountID types.AccountID) int {
n.clientsMux.RLock()
defer n.clientsMux.RUnlock()
entry, exists := n.clients[accountID]
if !exists {
return 0
}
return len(entry.domains)
}
// ClientCount returns the total number of active clients.
func (n *NetBird) ClientCount() int {
n.clientsMux.RLock()
defer n.clientsMux.RUnlock()
return len(n.clients)
}
// GetClient returns the embed.Client for the given account ID.
func (n *NetBird) GetClient(accountID types.AccountID) (*embed.Client, bool) {
n.clientsMux.RLock()
defer n.clientsMux.RUnlock()
entry, exists := n.clients[accountID]
if !exists {
return nil, false
}
return entry.client, true
}
// ClientDebugInfo contains debug information about a client.
type ClientDebugInfo struct {
AccountID types.AccountID
DomainCount int
Domains domain.List
HasClient bool
CreatedAt time.Time
}
// ListClientsForDebug returns information about all clients for debug purposes.
func (n *NetBird) ListClientsForDebug() map[types.AccountID]ClientDebugInfo {
n.clientsMux.RLock()
defer n.clientsMux.RUnlock()
result := make(map[types.AccountID]ClientDebugInfo)
for accountID, entry := range n.clients {
domains := make(domain.List, 0, len(entry.domains))
for d := range entry.domains {
domains = append(domains, d)
}
result[accountID] = ClientDebugInfo{
AccountID: accountID,
DomainCount: len(entry.domains),
Domains: domains,
HasClient: entry.client != nil,
CreatedAt: entry.createdAt,
}
}
return result
}
// accountIDContextKey is the context key for storing the account ID.
type accountIDContextKey struct{}
// WithAccountID adds the account ID to the context.
func WithAccountID(ctx context.Context, accountID types.AccountID) context.Context {
return context.WithValue(ctx, accountIDContextKey{}, accountID)
}
// AccountIDFromContext retrieves the account ID from the context.
func AccountIDFromContext(ctx context.Context) types.AccountID {
v := ctx.Value(accountIDContextKey{})
if v == nil {
return ""
}
accountID, ok := v.(types.AccountID)
if !ok {
return ""
}
return accountID
}

View File

@@ -0,0 +1,247 @@
package roundtrip
import (
"context"
"net/http"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/proxy/internal/types"
"github.com/netbirdio/netbird/shared/management/domain"
)
// mockNetBird creates a NetBird instance for testing without actually connecting.
// It uses an invalid management URL to prevent real connections.
func mockNetBird() *NetBird {
return NewNetBird("http://invalid.test:9999", "test-proxy", nil, nil)
}
func TestNetBird_AddPeer_CreatesClientForNewAccount(t *testing.T) {
nb := mockNetBird()
accountID := types.AccountID("account-1")
// Initially no client exists.
assert.False(t, nb.HasClient(accountID), "should not have client before AddPeer")
assert.Equal(t, 0, nb.DomainCount(accountID), "domain count should be 0")
// Add first domain - this should create a new client.
// Note: This will fail to actually connect since we use an invalid URL,
// but the client entry should still be created.
err := nb.AddPeer(context.Background(), accountID, domain.Domain("domain1.test"), "setup-key-1", "proxy-1")
require.NoError(t, err)
assert.True(t, nb.HasClient(accountID), "should have client after AddPeer")
assert.Equal(t, 1, nb.DomainCount(accountID), "domain count should be 1")
}
func TestNetBird_AddPeer_ReuseClientForSameAccount(t *testing.T) {
nb := mockNetBird()
accountID := types.AccountID("account-1")
// Add first domain.
err := nb.AddPeer(context.Background(), accountID, domain.Domain("domain1.test"), "setup-key-1", "proxy-1")
require.NoError(t, err)
assert.Equal(t, 1, nb.DomainCount(accountID))
// Add second domain for the same account - should reuse existing client.
err = nb.AddPeer(context.Background(), accountID, domain.Domain("domain2.test"), "setup-key-1", "proxy-2")
require.NoError(t, err)
assert.Equal(t, 2, nb.DomainCount(accountID), "domain count should be 2 after adding second domain")
// Add third domain.
err = nb.AddPeer(context.Background(), accountID, domain.Domain("domain3.test"), "setup-key-1", "proxy-3")
require.NoError(t, err)
assert.Equal(t, 3, nb.DomainCount(accountID), "domain count should be 3 after adding third domain")
// Still only one client.
assert.True(t, nb.HasClient(accountID))
}
func TestNetBird_AddPeer_SeparateClientsForDifferentAccounts(t *testing.T) {
nb := mockNetBird()
account1 := types.AccountID("account-1")
account2 := types.AccountID("account-2")
// Add domain for account 1.
err := nb.AddPeer(context.Background(), account1, domain.Domain("domain1.test"), "setup-key-1", "proxy-1")
require.NoError(t, err)
// Add domain for account 2.
err = nb.AddPeer(context.Background(), account2, domain.Domain("domain2.test"), "setup-key-2", "proxy-2")
require.NoError(t, err)
// Both accounts should have their own clients.
assert.True(t, nb.HasClient(account1), "account1 should have client")
assert.True(t, nb.HasClient(account2), "account2 should have client")
assert.Equal(t, 1, nb.DomainCount(account1), "account1 domain count should be 1")
assert.Equal(t, 1, nb.DomainCount(account2), "account2 domain count should be 1")
}
func TestNetBird_RemovePeer_KeepsClientWhenDomainsRemain(t *testing.T) {
nb := mockNetBird()
accountID := types.AccountID("account-1")
// Add multiple domains.
err := nb.AddPeer(context.Background(), accountID, domain.Domain("domain1.test"), "setup-key-1", "proxy-1")
require.NoError(t, err)
err = nb.AddPeer(context.Background(), accountID, domain.Domain("domain2.test"), "setup-key-1", "proxy-2")
require.NoError(t, err)
err = nb.AddPeer(context.Background(), accountID, domain.Domain("domain3.test"), "setup-key-1", "proxy-3")
require.NoError(t, err)
assert.Equal(t, 3, nb.DomainCount(accountID))
// Remove one domain - client should remain.
err = nb.RemovePeer(context.Background(), accountID, "domain1.test")
require.NoError(t, err)
assert.True(t, nb.HasClient(accountID), "client should remain after removing one domain")
assert.Equal(t, 2, nb.DomainCount(accountID), "domain count should be 2")
// Remove another domain - client should still remain.
err = nb.RemovePeer(context.Background(), accountID, "domain2.test")
require.NoError(t, err)
assert.True(t, nb.HasClient(accountID), "client should remain after removing second domain")
assert.Equal(t, 1, nb.DomainCount(accountID), "domain count should be 1")
}
func TestNetBird_RemovePeer_RemovesClientWhenLastDomainRemoved(t *testing.T) {
nb := mockNetBird()
accountID := types.AccountID("account-1")
// Add single domain.
err := nb.AddPeer(context.Background(), accountID, domain.Domain("domain1.test"), "setup-key-1", "proxy-1")
require.NoError(t, err)
assert.True(t, nb.HasClient(accountID))
// Remove the only domain - client should be removed.
// Note: Stop() may fail since the client never actually connected,
// but the entry should still be removed from the map.
_ = nb.RemovePeer(context.Background(), accountID, "domain1.test")
// After removing all domains, client should be gone.
assert.False(t, nb.HasClient(accountID), "client should be removed after removing last domain")
assert.Equal(t, 0, nb.DomainCount(accountID), "domain count should be 0")
}
func TestNetBird_RemovePeer_NonExistentAccountIsNoop(t *testing.T) {
nb := mockNetBird()
accountID := types.AccountID("nonexistent-account")
// Removing from non-existent account should not error.
err := nb.RemovePeer(context.Background(), accountID, "domain1.test")
assert.NoError(t, err, "removing from non-existent account should not error")
}
func TestNetBird_RemovePeer_NonExistentDomainIsNoop(t *testing.T) {
nb := mockNetBird()
accountID := types.AccountID("account-1")
// Add one domain.
err := nb.AddPeer(context.Background(), accountID, domain.Domain("domain1.test"), "setup-key-1", "proxy-1")
require.NoError(t, err)
// Remove non-existent domain - should not affect existing domain.
err = nb.RemovePeer(context.Background(), accountID, domain.Domain("nonexistent.test"))
require.NoError(t, err)
// Original domain should still be registered.
assert.True(t, nb.HasClient(accountID))
assert.Equal(t, 1, nb.DomainCount(accountID), "original domain should remain")
}
func TestWithAccountID_AndAccountIDFromContext(t *testing.T) {
ctx := context.Background()
accountID := types.AccountID("test-account")
// Initially no account ID in context.
retrieved := AccountIDFromContext(ctx)
assert.True(t, retrieved == "", "should be empty when not set")
// Add account ID to context.
ctx = WithAccountID(ctx, accountID)
retrieved = AccountIDFromContext(ctx)
assert.Equal(t, accountID, retrieved, "should retrieve the same account ID")
}
func TestAccountIDFromContext_ReturnsEmptyForWrongType(t *testing.T) {
// Create context with wrong type for account ID key.
ctx := context.WithValue(context.Background(), accountIDContextKey{}, "wrong-type-string")
retrieved := AccountIDFromContext(ctx)
assert.True(t, retrieved == "", "should return empty for wrong type")
}
func TestNetBird_StopAll_StopsAllClients(t *testing.T) {
nb := mockNetBird()
account1 := types.AccountID("account-1")
account2 := types.AccountID("account-2")
account3 := types.AccountID("account-3")
// Add domains for multiple accounts.
err := nb.AddPeer(context.Background(), account1, domain.Domain("domain1.test"), "key-1", "proxy-1")
require.NoError(t, err)
err = nb.AddPeer(context.Background(), account2, domain.Domain("domain2.test"), "key-2", "proxy-2")
require.NoError(t, err)
err = nb.AddPeer(context.Background(), account3, domain.Domain("domain3.test"), "key-3", "proxy-3")
require.NoError(t, err)
assert.Equal(t, 3, nb.ClientCount(), "should have 3 clients")
// Stop all clients.
// Note: StopAll may return errors since clients never actually connected,
// but the clients should still be removed from the map.
_ = nb.StopAll(context.Background())
assert.Equal(t, 0, nb.ClientCount(), "should have 0 clients after StopAll")
assert.False(t, nb.HasClient(account1), "account1 should not have client")
assert.False(t, nb.HasClient(account2), "account2 should not have client")
assert.False(t, nb.HasClient(account3), "account3 should not have client")
}
func TestNetBird_ClientCount(t *testing.T) {
nb := mockNetBird()
assert.Equal(t, 0, nb.ClientCount(), "should start with 0 clients")
// Add clients for different accounts.
err := nb.AddPeer(context.Background(), types.AccountID("account-1"), domain.Domain("domain1.test"), "key-1", "proxy-1")
require.NoError(t, err)
assert.Equal(t, 1, nb.ClientCount())
err = nb.AddPeer(context.Background(), types.AccountID("account-2"), domain.Domain("domain2.test"), "key-2", "proxy-2")
require.NoError(t, err)
assert.Equal(t, 2, nb.ClientCount())
// Adding domain to existing account should not increase count.
err = nb.AddPeer(context.Background(), types.AccountID("account-1"), domain.Domain("domain1b.test"), "key-1", "proxy-1b")
require.NoError(t, err)
assert.Equal(t, 2, nb.ClientCount(), "adding domain to existing account should not increase client count")
}
func TestNetBird_RoundTrip_RequiresAccountIDInContext(t *testing.T) {
nb := mockNetBird()
// Create a request without account ID in context.
req, err := http.NewRequest("GET", "http://example.com/", nil)
require.NoError(t, err)
// RoundTrip should fail because no account ID in context.
_, err = nb.RoundTrip(req)
require.ErrorIs(t, err, ErrNoAccountID)
}
func TestNetBird_RoundTrip_RequiresExistingClient(t *testing.T) {
nb := mockNetBird()
accountID := types.AccountID("nonexistent-account")
// Create a request with account ID but no client exists.
req, err := http.NewRequest("GET", "http://example.com/", nil)
require.NoError(t, err)
req = req.WithContext(WithAccountID(req.Context(), accountID))
// RoundTrip should fail because no client for this account.
_, err = nb.RoundTrip(req)
assert.Error(t, err)
assert.Contains(t, err.Error(), "no peer connection found for account")
}

View File

@@ -0,0 +1,5 @@
// Package types defines common types used across the proxy package.
package types
// AccountID represents a unique identifier for a NetBird account.
type AccountID string

View File

@@ -31,8 +31,11 @@ import (
"github.com/netbirdio/netbird/proxy/internal/accesslog"
"github.com/netbirdio/netbird/proxy/internal/acme"
"github.com/netbirdio/netbird/proxy/internal/auth"
"github.com/netbirdio/netbird/proxy/internal/debug"
"github.com/netbirdio/netbird/proxy/internal/proxy"
"github.com/netbirdio/netbird/proxy/internal/roundtrip"
"github.com/netbirdio/netbird/proxy/internal/types"
"github.com/netbirdio/netbird/shared/management/domain"
"github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/util/embeddedroots"
)
@@ -45,6 +48,7 @@ type Server struct {
auth *auth.Middleware
http *http.Server
https *http.Server
debug *http.Server
// Mostly used for debugging on management.
startTime time.Time
@@ -62,6 +66,11 @@ type Server struct {
OIDCClientSecret string
OIDCEndpoint string
OIDCScopes []string
// DebugEndpointEnabled enables the debug HTTP endpoint.
DebugEndpointEnabled bool
// DebugEndpointAddress is the address for the debug HTTP endpoint (default: ":8444").
DebugEndpointAddress string
}
// NotifyStatus sends a status update to management about tunnel connectivity
@@ -148,7 +157,7 @@ func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) {
// Initialize the netbird client, this is required to build peer connections
// to proxy over.
s.netbird = roundtrip.NewNetBird(s.ManagementAddress, s.Logger, s)
s.netbird = roundtrip.NewNetBird(s.ManagementAddress, s.ID, s.Logger, s)
// When generating ACME certificates, start a challenge server.
tlsConfig := &tls.Config{}
@@ -204,6 +213,34 @@ func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) {
// Configure Access logs to management server.
accessLog := accesslog.NewLogger(s.mgmtClient, s.Logger)
if s.DebugEndpointEnabled {
debugAddr := debugEndpointAddr(s.DebugEndpointAddress)
debugHandler := debug.NewHandler(s.netbird, s.Logger)
s.debug = &http.Server{
Addr: debugAddr,
Handler: debugHandler,
}
go func() {
s.Logger.WithField("address", debugAddr).Info("starting debug endpoint")
if err := s.debug.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
s.Logger.Errorf("debug endpoint error: %v", err)
}
}()
defer func() {
if err := s.debug.Close(); err != nil {
s.Logger.Debugf("debug endpoint close: %v", err)
}
}()
}
defer func() {
stopCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
if err := s.netbird.StopAll(stopCtx); err != nil {
s.Logger.Warnf("failed to stop all netbird clients: %v", err)
}
}()
// Finally, start the reverse proxy.
s.https = &http.Server{
Addr: addr,
@@ -306,15 +343,15 @@ func (s *Server) handleMappingStream(ctx context.Context, mappingClient proto.Pr
}
func (s *Server) addMapping(ctx context.Context, mapping *proto.ProxyMapping) error {
domain := mapping.GetDomain()
accountID := mapping.GetAccountId()
d := domain.Domain(mapping.GetDomain())
accountID := types.AccountID(mapping.GetAccountId())
reverseProxyID := mapping.GetId()
if err := s.netbird.AddPeer(ctx, domain, mapping.GetSetupKey(), accountID, reverseProxyID); err != nil {
return fmt.Errorf("create peer for domain %q: %w", domain, err)
if err := s.netbird.AddPeer(ctx, accountID, d, mapping.GetSetupKey(), reverseProxyID); err != nil {
return fmt.Errorf("create peer for domain %q: %w", d, err)
}
if s.acme != nil {
s.acme.AddDomain(domain, accountID, reverseProxyID)
s.acme.AddDomain(string(d), string(accountID), reverseProxyID)
}
// Pass the mapping through to the update function to avoid duplicating the
@@ -360,10 +397,13 @@ func (s *Server) updateMapping(ctx context.Context, mapping *proto.ProxyMapping)
}
func (s *Server) removeMapping(ctx context.Context, mapping *proto.ProxyMapping) {
if err := s.netbird.RemovePeer(ctx, mapping.GetDomain(), mapping.GetAccountId(), mapping.GetId()); err != nil {
d := domain.Domain(mapping.GetDomain())
accountID := types.AccountID(mapping.GetAccountId())
if err := s.netbird.RemovePeer(ctx, accountID, d); err != nil {
s.Logger.WithFields(log.Fields{
"domain": mapping.GetDomain(),
"error": err,
"account_id": accountID,
"domain": d,
"error": err,
}).Error("Error removing NetBird peer connection for domain, continuing additional domain cleanup but peer connection may still exist")
}
if s.acme != nil {
@@ -392,8 +432,17 @@ func (s *Server) protoToMapping(mapping *proto.ProxyMapping) proxy.Mapping {
}
return proxy.Mapping{
ID: mapping.GetId(),
AccountID: mapping.AccountId,
AccountID: types.AccountID(mapping.GetAccountId()),
Host: mapping.GetDomain(),
Paths: paths,
}
}
// debugEndpointAddr returns the address for the debug endpoint.
// If addr is empty, it defaults to localhost:8444 for security.
func debugEndpointAddr(addr string) string {
if addr == "" {
return "localhost:8444"
}
return addr
}

48
proxy/server_test.go Normal file
View File

@@ -0,0 +1,48 @@
package proxy
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestDebugEndpointDisabledByDefault(t *testing.T) {
s := &Server{}
assert.False(t, s.DebugEndpointEnabled, "debug endpoint should be disabled by default")
}
func TestDebugEndpointAddr(t *testing.T) {
tests := []struct {
name string
input string
expected string
}{
{
name: "empty defaults to localhost",
input: "",
expected: "localhost:8444",
},
{
name: "explicit localhost preserved",
input: "localhost:9999",
expected: "localhost:9999",
},
{
name: "explicit address preserved",
input: "0.0.0.0:8444",
expected: "0.0.0.0:8444",
},
{
name: "127.0.0.1 preserved",
input: "127.0.0.1:8444",
expected: "127.0.0.1:8444",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
result := debugEndpointAddr(tc.input)
assert.Equal(t, tc.expected, result)
})
}
}