diff --git a/proxy/cmd/proxy/cmd/debug.go b/proxy/cmd/proxy/cmd/debug.go
new file mode 100644
index 000000000..86172d78b
--- /dev/null
+++ b/proxy/cmd/proxy/cmd/debug.go
@@ -0,0 +1,193 @@
+package cmd
+
+import (
+	"fmt"
+	"strconv"
+
+	"github.com/spf13/cobra"
+
+	"github.com/netbirdio/netbird/proxy/internal/debug"
+)
+
+// Flag values shared by the debug subcommands.
+var (
+	debugAddr  string
+	jsonOutput bool
+
+	// status filters
+	statusFilterByIPs            []string
+	statusFilterByNames          []string
+	statusFilterByStatus         string
+	statusFilterByConnectionType string
+)
+
+// debugCmd groups the subcommands that query the proxy's debug HTTP endpoint.
+var debugCmd = &cobra.Command{
+	Use:   "debug",
+	Short: "Debug commands for inspecting proxy state",
+	Long:  "Debug commands for inspecting the reverse proxy state via the debug HTTP endpoint.",
+}
+
+// debugHealthCmd prints the proxy health status.
+var debugHealthCmd = &cobra.Command{
+	Use:          "health",
+	Short:        "Show proxy health status",
+	RunE:         runDebugHealth,
+	SilenceUsage: true,
+}
+
+// debugClientsCmd lists all connected clients.
+var debugClientsCmd = &cobra.Command{
+	Use:          "clients",
+	Aliases:      []string{"list"},
+	Short:        "List all connected clients",
+	RunE:         runDebugClients,
+	SilenceUsage: true,
+}
+
+// debugStatusCmd shows the status of the client given by its account ID.
+var debugStatusCmd = &cobra.Command{
+	Use:          "status <account-id>",
+	Short:        "Show client status",
+	Args:         cobra.ExactArgs(1),
+	RunE:         runDebugStatus,
+	SilenceUsage: true,
+}
+
+// debugSyncCmd shows the sync response of the client given by its account ID.
+var debugSyncCmd = &cobra.Command{
+	Use:          "sync-response <account-id>",
+	Short:        "Show client sync response",
+	Args:         cobra.ExactArgs(1),
+	RunE:         runDebugSync,
+	SilenceUsage: true,
+}
+
+// pingTimeout is the raw --timeout value; it is forwarded to the server as is.
+var pingTimeout string
+
+// debugPingCmd performs a TCP ping through a client's network.
+var debugPingCmd = &cobra.Command{
+	Use:          "ping <account-id> <host> [port]",
+	Short:        "TCP ping through a client",
+	Long:         "Perform a TCP ping through a client's network to test connectivity.\nPort defaults to 80 if not specified.",
+	Args:         cobra.RangeArgs(2, 3),
+	RunE:         runDebugPing,
+	SilenceUsage: true,
+}
+
+// debugLogLevelCmd sets the log level of a client.
+var debugLogLevelCmd = &cobra.Command{
+	Use:          "loglevel <account-id> <level>",
+	Short:        "Set client log level",
+	Long:         "Set the log level for a client (trace, debug, info, warn, error).",
+	Args:         cobra.ExactArgs(2),
+	RunE:         runDebugLogLevel,
+	SilenceUsage: true,
+}
+
+// debugStartCmd starts a client.
+var debugStartCmd = &cobra.Command{
+	Use:          "start <account-id>",
+	Short:        "Start a client",
+	Args:         cobra.ExactArgs(1),
+	RunE:         runDebugStart,
+	SilenceUsage: true,
+}
+
+// debugStopCmd stops a client.
+var debugStopCmd = &cobra.Command{
+	Use:          "stop <account-id>",
+	Short:        "Stop a client",
+	Args:         cobra.ExactArgs(1),
+	RunE:         runDebugStop,
+	SilenceUsage: true,
+}
+
+// init registers the debug flags and subcommands and attaches them to rootCmd.
+func init() {
+	debugCmd.PersistentFlags().StringVar(&debugAddr, "addr", envStringOrDefault("NB_PROXY_DEBUG_ADDRESS", "localhost:8444"), "Debug endpoint address")
+	debugCmd.PersistentFlags().BoolVar(&jsonOutput, "json", false, "Output JSON instead of pretty format")
+
+	debugStatusCmd.Flags().StringSliceVar(&statusFilterByIPs, "filter-by-ips", nil, "Filter by peer IPs (comma-separated)")
+	debugStatusCmd.Flags().StringSliceVar(&statusFilterByNames, "filter-by-names", nil, "Filter by peer names (comma-separated)")
+	debugStatusCmd.Flags().StringVar(&statusFilterByStatus, "filter-by-status", "", "Filter by status (idle|connecting|connected)")
+	debugStatusCmd.Flags().StringVar(&statusFilterByConnectionType, "filter-by-connection-type", "", "Filter by connection type (P2P|Relayed)")
+
+	debugPingCmd.Flags().StringVar(&pingTimeout, "timeout", "", "Ping timeout (e.g., 10s)")
+
+	debugCmd.AddCommand(
+		debugHealthCmd,
+		debugClientsCmd,
+		debugStatusCmd,
+		debugSyncCmd,
+		debugPingCmd,
+		debugLogLevelCmd,
+		debugStartCmd,
+		debugStopCmd,
+	)
+	rootCmd.AddCommand(debugCmd)
+}
+
+// getDebugClient builds a debug client from the --addr and --json flags that
+// writes to the command's standard output.
+func getDebugClient(cmd *cobra.Command) *debug.Client {
+	return debug.NewClient(debugAddr, jsonOutput, cmd.OutOrStdout())
+}
+
+// runDebugHealth implements "debug health".
+func runDebugHealth(cmd *cobra.Command, _ []string) error {
+	return getDebugClient(cmd).Health(cmd.Context())
+}
+
+// runDebugClients implements "debug clients".
+func runDebugClients(cmd *cobra.Command, _ []string) error {
+	return getDebugClient(cmd).ListClients(cmd.Context())
+}
+
+// runDebugStatus implements "debug status", forwarding the filter flags.
+func runDebugStatus(cmd *cobra.Command, args []string) error {
+	return getDebugClient(cmd).ClientStatus(cmd.Context(), args[0], debug.StatusFilters{
+		IPs:            statusFilterByIPs,
+		Names:          statusFilterByNames,
+		Status:         statusFilterByStatus,
+		ConnectionType: statusFilterByConnectionType,
+	})
+}
+
+// runDebugSync implements "debug sync-response".
+func runDebugSync(cmd *cobra.Command, args []string) error {
+	return getDebugClient(cmd).ClientSyncResponse(cmd.Context(), args[0])
+}
+
+// runDebugPing implements "debug ping". The port defaults to 80 and must be a
+// valid TCP port; a bad value is rejected before the server is contacted.
+func runDebugPing(cmd *cobra.Command, args []string) error {
+	port := 80
+	if len(args) > 2 {
+		p, err := strconv.Atoi(args[2])
+		if err != nil {
+			return fmt.Errorf("invalid port: %w", err)
+		}
+		if p < 1 || p > 65535 {
+			return fmt.Errorf("invalid port: %d is outside 1-65535", p)
+		}
+		port = p
+	}
+	return getDebugClient(cmd).PingTCP(cmd.Context(), args[0], args[1], port, pingTimeout)
+}
+
+// runDebugLogLevel implements "debug loglevel".
+func runDebugLogLevel(cmd *cobra.Command, args []string) error {
+	return getDebugClient(cmd).SetLogLevel(cmd.Context(), args[0], args[1])
+}
+
+// runDebugStart implements "debug start".
+func runDebugStart(cmd *cobra.Command, args []string) error {
+	return getDebugClient(cmd).StartClient(cmd.Context(), args[0])
+}
+
+// runDebugStop implements "debug stop".
+func runDebugStop(cmd *cobra.Command, args []string) error {
+	return getDebugClient(cmd).StopClient(cmd.Context(), args[0])
+}
diff --git a/proxy/cmd/proxy/cmd/root.go b/proxy/cmd/proxy/cmd/root.go
new file mode 100644
index 000000000..3bff9a834
--- /dev/null
+++ b/proxy/cmd/proxy/cmd/root.go
@@ -0,0 +1,172 @@
+package cmd
+
+import (
+	"context"
+	"os"
+	"strconv"
+	"strings"
+
+	log "github.com/sirupsen/logrus"
+	"github.com/spf13/cobra"
+	"golang.org/x/crypto/acme"
+
+	"github.com/netbirdio/netbird/proxy"
+	"github.com/netbirdio/netbird/util"
+)
+
+// DefaultManagementURL is the default value of the --mgmt flag.
+const DefaultManagementURL = "https://api.netbird.io:443"
+
+// Build information; overwritten through SetVersionInfo.
+var (
+	Version   = "dev"
+	Commit    = "unknown"
+	BuildDate = "unknown"
+	GoVersion = "unknown"
+)
+
+// Flag values of the root (server) command.
+var (
+	debugLogs         bool
+	mgmtAddr          string
+	addr              string
+	proxyURL          string
+	certDir           string
+	acmeCerts         bool
+	acmeAddr          string
+	acmeDir           string
+	debugEndpoint     bool
+	debugEndpointAddr string
+	oidcClientID      string
+	oidcClientSecret  string
+	oidcEndpoint      string
+	oidcScopes        string
+)
+
+// rootCmd runs the reverse proxy server when invoked without a subcommand.
+var rootCmd = &cobra.Command{
+	Use:     "proxy",
+	Short:   "NetBird reverse proxy server",
+	Long:    "NetBird reverse proxy server for proxying traffic to NetBird networks.",
+	Version: Version,
+	RunE:    runServer,
+}
+
+// init registers the server flags; each default can be overridden through the
+// NB_PROXY_* environment variable named in its definition.
+func init() {
+	rootCmd.PersistentFlags().BoolVar(&debugLogs, "debug", envBoolOrDefault("NB_PROXY_DEBUG_LOGS", false), "Enable debug logs")
+	rootCmd.Flags().StringVar(&mgmtAddr, "mgmt", envStringOrDefault("NB_PROXY_MANAGEMENT_ADDRESS", DefaultManagementURL), "Management address to connect to")
+	rootCmd.Flags().StringVar(&addr, "addr", envStringOrDefault("NB_PROXY_ADDRESS", ":443"), "Reverse proxy address to listen on")
+	rootCmd.Flags().StringVar(&proxyURL, "url", envStringOrDefault("NB_PROXY_URL", ""), "The URL at which this proxy will be reached")
+	rootCmd.Flags().StringVar(&certDir, "cert-dir", envStringOrDefault("NB_PROXY_CERTIFICATE_DIRECTORY", "./certs"), "Directory to store certificates")
+	rootCmd.Flags().BoolVar(&acmeCerts, "acme-certs", envBoolOrDefault("NB_PROXY_ACME_CERTIFICATES", false), "Generate ACME certificates using HTTP-01 challenges")
+	rootCmd.Flags().StringVar(&acmeAddr, "acme-addr", envStringOrDefault("NB_PROXY_ACME_ADDRESS", ":80"), "HTTP address for ACME HTTP-01 challenges")
+	rootCmd.Flags().StringVar(&acmeDir, "acme-dir", envStringOrDefault("NB_PROXY_ACME_DIRECTORY", acme.LetsEncryptURL), "URL of ACME challenge directory")
+	rootCmd.Flags().BoolVar(&debugEndpoint, "debug-endpoint", envBoolOrDefault("NB_PROXY_DEBUG_ENDPOINT", false), "Enable debug HTTP endpoint")
+	rootCmd.Flags().StringVar(&debugEndpointAddr, "debug-endpoint-addr", envStringOrDefault("NB_PROXY_DEBUG_ENDPOINT_ADDRESS", "localhost:8444"), "Address for the debug HTTP endpoint")
+	rootCmd.Flags().StringVar(&oidcClientID, "oidc-id", envStringOrDefault("NB_PROXY_OIDC_CLIENT_ID", "netbird-proxy"), "The OAuth2 Client ID for OIDC User Authentication")
+	rootCmd.Flags().StringVar(&oidcClientSecret, "oidc-secret", envStringOrDefault("NB_PROXY_OIDC_CLIENT_SECRET", ""), "The OAuth2 Client Secret for OIDC User Authentication")
+	rootCmd.Flags().StringVar(&oidcEndpoint, "oidc-endpoint", envStringOrDefault("NB_PROXY_OIDC_ENDPOINT", ""), "The OIDC Endpoint for OIDC User Authentication")
+	rootCmd.Flags().StringVar(&oidcScopes, "oidc-scopes", envStringOrDefault("NB_PROXY_OIDC_SCOPES", "openid,profile,email"), "The OAuth2 scopes for OIDC User Authentication, comma separated")
+}
+
+// Execute runs the root command.
+func Execute() {
+	// cobra reports the error itself, so only the exit status is set here.
+	if err := rootCmd.Execute(); err != nil {
+		os.Exit(1)
+	}
+}
+
+// SetVersionInfo sets version information for the CLI.
+func SetVersionInfo(version, commit, buildDate, goVersion string) {
+	Version = version
+	Commit = commit
+	BuildDate = buildDate
+	GoVersion = goVersion
+	rootCmd.Version = version
+	rootCmd.SetVersionTemplate("Version: {{.Version}}, Commit: " + Commit + ", BuildDate: " + BuildDate + ", Go: " + GoVersion + "\n")
+}
+
+// runServer configures logging, builds the proxy server from the parsed flags
+// and serves until ListenAndServe returns. Its error is handed back to cobra
+// instead of terminating the process from inside the command.
+func runServer(cmd *cobra.Command, _ []string) error {
+	// The flags parsed fine at this point, so a failure below is a runtime
+	// error; printing the usage text would only bury it.
+	cmd.SilenceUsage = true
+
+	level := "error"
+	if debugLogs {
+		level = "debug"
+	}
+	logger := log.New()
+
+	// NOTE(review): the InitLogger result is discarded - confirm that it
+	// cannot fail for console output.
+	_ = util.InitLogger(logger, level, util.LogConsole)
+
+	log.Infof("configured log level: %s", level)
+
+	srv := proxy.Server{
+		Logger:                   logger,
+		Version:                  Version,
+		ManagementAddress:        mgmtAddr,
+		ProxyURL:                 proxyURL,
+		CertificateDirectory:     certDir,
+		GenerateACMECertificates: acmeCerts,
+		ACMEChallengeAddress:     acmeAddr,
+		ACMEDirectory:            acmeDir,
+		DebugEndpointEnabled:     debugEndpoint,
+		DebugEndpointAddress:     debugEndpointAddr,
+		OIDCClientId:             oidcClientID,
+		OIDCClientSecret:         oidcClientSecret,
+		OIDCEndpoint:             oidcEndpoint,
+		OIDCScopes:               splitScopes(oidcScopes),
+	}
+
+	// Use the command's context, like the debug subcommands do, so a caller of
+	// ExecuteContext can cancel the server; fall back when none was set.
+	ctx := cmd.Context()
+	if ctx == nil {
+		ctx = context.Background()
+	}
+	return srv.ListenAndServe(ctx, addr)
+}
+
+// splitScopes splits a comma-separated scope list, trimming the whitespace
+// around each entry and dropping entries that are empty.
+func splitScopes(raw string) []string {
+	fields := strings.Split(raw, ",")
+	scopes := make([]string, 0, len(fields))
+	for _, field := range fields {
+		if scope := strings.TrimSpace(field); scope != "" {
+			scopes = append(scopes, scope)
+		}
+	}
+	return scopes
+}
+
+// envBoolOrDefault returns the boolean value of the environment variable key,
+// or def when the variable is unset or cannot be parsed by strconv.ParseBool.
+func envBoolOrDefault(key string, def bool) bool {
+	v, exists := os.LookupEnv(key)
+	if !exists {
+		return def
+	}
+	parsed, err := strconv.ParseBool(v)
+	if err != nil {
+		return def
+	}
+	return parsed
+}
+
+// envStringOrDefault returns the environment variable key, or def when unset.
+func envStringOrDefault(key string, def string) string {
+	v, exists := os.LookupEnv(key)
+	if !exists {
+		return def
+	}
+	return v
+}
diff --git 
a/proxy/cmd/proxy/main.go b/proxy/cmd/proxy/main.go index 2eb0ac01d..14e540a2e 100644 --- a/proxy/cmd/proxy/main.go +++ b/proxy/cmd/proxy/main.go @@ -1,22 +1,11 @@ package main import ( - "context" - "flag" - "fmt" - "os" "runtime" - "strings" - "github.com/netbirdio/netbird/util" - log "github.com/sirupsen/logrus" - "golang.org/x/crypto/acme" - - "github.com/netbirdio/netbird/proxy" + "github.com/netbirdio/netbird/proxy/cmd/proxy/cmd" ) -const DefaultManagementURL = "https://api.netbird.io:443" - var ( // Version is the application version (set via ldflags during build) Version = "dev" @@ -31,78 +20,7 @@ var ( GoVersion = runtime.Version() ) -func envBoolOrDefault(key string, def bool) bool { - v, exists := os.LookupEnv(key) - if !exists { - return def - } - return v == strings.ToLower("true") -} - -func envStringOrDefault(key string, def string) string { - v, exists := os.LookupEnv(key) - if !exists { - return def - } - return v -} - func main() { - var ( - version, debug bool - mgmtAddr, addr, url, certDir string - acmeCerts bool - acmeAddr, acmeDir string - oidcId, oidcSecret, oidcEndpoint, oidcScopes string - ) - - flag.BoolVar(&version, "v", false, "Print version and exit") - flag.BoolVar(&debug, "debug", envBoolOrDefault("NB_PROXY_DEBUG_LOGS", false), "Enable debug logs") - flag.StringVar(&mgmtAddr, "mgmt", envStringOrDefault("NB_PROXY_MANAGEMENT_ADDRESS", DefaultManagementURL), "Management address to connect to.") - flag.StringVar(&addr, "addr", envStringOrDefault("NB_PROXY_ADDRESS", ":443"), "Reverse proxy address to listen on.") - flag.StringVar(&url, "url", envStringOrDefault("NB_PROXY_URL", "proxy.netbird.io"), "The URL at which this proxy will be reached, where CNAME records for proxied endpoints will be directed.") - flag.StringVar(&certDir, "cert-dir", envStringOrDefault("NB_PROXY_CERTIFICATE_DIRECTORY", "./certs"), "Directory to store ") - flag.BoolVar(&acmeCerts, "acme-certs", envBoolOrDefault("NB_PROXY_ACME_CERTIFICATES", false), "Generate ACME 
certificates using HTTP-01 challenges.") - flag.StringVar(&acmeAddr, "acme-addr", envStringOrDefault("NB_PROXY_ACME_ADDRESS", ":80"), "HTTP address to listen on, used for ACME HTTP-01 certificate generation.") - flag.StringVar(&acmeDir, "acme-dir", envStringOrDefault("NB_PROXY_ACME_DIRECTORY", acme.LetsEncryptURL), "URL of ACME challenge directory.") - flag.StringVar(&oidcId, "oidc-id", envStringOrDefault("NB_PROXY_OIDC_CLIENT_ID", "netbird-proxy"), "The OAuth2 Client ID for OIDC User Authentication") - flag.StringVar(&oidcSecret, "oidc-secret", envStringOrDefault("NB_PROXY_OIDC_CLIENT_SECRET", ""), "The OAuth2 Client Secret for OIDC User Authentication") - flag.StringVar(&oidcEndpoint, "oidc-endpoint", envStringOrDefault("NB_PROXY_OIDC_ENDPOINT", ""), "The OIDC Endpoint for OIDC User Authentication") - flag.StringVar(&oidcScopes, "oidc-scopes", envStringOrDefault("NB_PROXY_OIDC_SCOPES", "openid,profile,email"), "The OAuth2 scopes for OIDC User Authentication, comma separated") - flag.Parse() - - if version { - fmt.Printf("Version: %s, Commit: %s, BuildDate: %s, Go: %s", Version, Commit, BuildDate, GoVersion) - os.Exit(0) - } - - // Configure logrus. 
- level := "error" - if debug { - level = "debug" - } - logger := log.New() - - _ = util.InitLogger(logger, level, util.LogConsole) - - log.Infof("configured log level: %s", level) - - srv := proxy.Server{ - Logger: logger, - Version: Version, - ManagementAddress: mgmtAddr, - ProxyURL: url, - CertificateDirectory: certDir, - GenerateACMECertificates: acmeCerts, - ACMEChallengeAddress: acmeAddr, - ACMEDirectory: acmeDir, - OIDCClientId: oidcId, - OIDCClientSecret: oidcSecret, - OIDCEndpoint: oidcEndpoint, - OIDCScopes: strings.Split(oidcScopes, ","), - } - - if err := srv.ListenAndServe(context.TODO(), addr); err != nil { - log.Fatal(err) - } + cmd.SetVersionInfo(Version, Commit, BuildDate, GoVersion) + cmd.Execute() } diff --git a/proxy/internal/accesslog/middleware.go b/proxy/internal/accesslog/middleware.go index c298b7f79..a1843c120 100644 --- a/proxy/internal/accesslog/middleware.go +++ b/proxy/internal/accesslog/middleware.go @@ -42,7 +42,7 @@ func (l *Logger) Middleware(next http.Handler) http.Handler { entry := logEntry{ ID: xid.New().String(), ServiceId: capturedData.GetServiceId(), - AccountID: capturedData.GetAccountId(), + AccountID: string(capturedData.GetAccountId()), Host: host, Path: r.URL.Path, DurationMs: duration.Milliseconds(), diff --git a/proxy/internal/debug/client.go b/proxy/internal/debug/client.go new file mode 100644 index 000000000..6b78a9b8a --- /dev/null +++ b/proxy/internal/debug/client.go @@ -0,0 +1,307 @@ +// Package debug provides HTTP debug endpoints and CLI client for the proxy server. +package debug + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "strings" + "time" +) + +// StatusFilters contains filter options for status queries. +type StatusFilters struct { + IPs []string + Names []string + Status string + ConnectionType string +} + +// Client provides CLI access to debug endpoints. 
+type Client struct {
+	baseURL    string       // scheme and host of the debug endpoint, no trailing slash
+	jsonOutput bool         // print indented JSON instead of the pretty format
+	httpClient *http.Client // HTTP client used for every request
+	out        io.Writer    // destination of all printed output
+}
+
+// NewClient creates a new debug client that prints to out. "http://" is
+// assumed when baseURL has no scheme, and one trailing slash is removed.
+func NewClient(baseURL string, jsonOutput bool, out io.Writer) *Client {
+	if !strings.HasPrefix(baseURL, "http://") && !strings.HasPrefix(baseURL, "https://") {
+		baseURL = "http://" + baseURL
+	}
+	return &Client{
+		baseURL:    strings.TrimSuffix(baseURL, "/"),
+		jsonOutput: jsonOutput,
+		out:        out,
+		httpClient: &http.Client{Timeout: 30 * time.Second},
+	}
+}
+
+// Health fetches the health status.
+func (c *Client) Health(ctx context.Context) error {
+	return c.fetchAndPrint(ctx, "/debug/health", c.printHealth)
+}
+
+// printHealth prints the "status" and "uptime" fields of a health response.
+func (c *Client) printHealth(data map[string]any) {
+	_, _ = fmt.Fprintf(c.out, "Status: %v\n", data["status"])
+	_, _ = fmt.Fprintf(c.out, "Uptime: %v\n", data["uptime"])
+}
+
+// ListClients fetches the list of all clients.
+func (c *Client) ListClients(ctx context.Context) error {
+	return c.fetchAndPrint(ctx, "/debug/clients", c.printClients)
+}
+
+// printClients prints the uptime, the client count and one table row per client.
+func (c *Client) printClients(data map[string]any) {
+	_, _ = fmt.Fprintf(c.out, "Uptime: %v\n", data["uptime"])
+	_, _ = fmt.Fprintf(c.out, "Clients: %v\n\n", data["client_count"])
+
+	clients, ok := data["clients"].([]any)
+	if !ok || len(clients) == 0 {
+		_, _ = fmt.Fprintln(c.out, "No clients connected.")
+		return
+	}
+
+	_, _ = fmt.Fprintf(c.out, "%-38s %-12s %-40s %s\n", "ACCOUNT ID", "AGE", "DOMAINS", "HAS CLIENT")
+	_, _ = fmt.Fprintln(c.out, strings.Repeat("-", 110))
+	for _, item := range clients {
+		c.printClientRow(item)
+	}
+}
+
+// printClientRow prints one table row; an item that is not a JSON object is
+// skipped. The columns use the same widths as the header in printClients.
+func (c *Client) printClientRow(item any) {
+	client, ok := item.(map[string]any)
+	if !ok {
+		return
+	}
+
+	hasClient := "no"
+	if hc, ok := client["has_client"].(bool); ok && hc {
+		hasClient = "yes"
+	}
+
+	// DOMAINS is padded to the header's 40 columns so that the HAS CLIENT value
+	// lines up under its heading; longer domain lists simply overflow.
+	_, _ = fmt.Fprintf(c.out, "%-38s %-12v %-40s %s\n",
+		client["account_id"], client["age"], c.extractDomains(client), hasClient)
+}
+
+// extractDomains joins the "domains" array of a client entry with ", " and
+// returns "-" when the field is missing, empty or not an array.
+func (c *Client) extractDomains(client map[string]any) string {
+	d, ok := client["domains"].([]any)
+	if !ok || len(d) == 0 {
+		return "-"
+	}
+
+	parts := make([]string, len(d))
+	for i, domain := range d {
+		parts[i] = fmt.Sprint(domain)
+	}
+	return strings.Join(parts, ", ")
+}
+
+// ClientStatus fetches the status of a specific client. Only filters that are
+// set are sent to the server, as query parameters.
+func (c *Client) ClientStatus(ctx context.Context, accountID string, filters StatusFilters) error {
+	params := url.Values{}
+	if len(filters.IPs) > 0 {
+		params.Set("filter-by-ips", strings.Join(filters.IPs, ","))
+	}
+	if len(filters.Names) > 0 {
+		params.Set("filter-by-names", strings.Join(filters.Names, ","))
+	}
+	if filters.Status != "" {
+		params.Set("filter-by-status", filters.Status)
+	}
+	if filters.ConnectionType != "" {
+		params.Set("filter-by-connection-type", filters.ConnectionType)
+	}
+
+	path := "/debug/clients/" + url.PathEscape(accountID)
+	if len(params) > 0 {
+		path += "?" + params.Encode()
+	}
+	return c.fetchAndPrint(ctx, path, c.printClientStatus)
+}
+
+// printClientStatus prints the account ID followed by the status text.
+func (c *Client) printClientStatus(data map[string]any) {
+	_, _ = fmt.Fprintf(c.out, "Account: %v\n\n", data["account_id"])
+	if status, ok := data["status"].(string); ok {
+		_, _ = fmt.Fprint(c.out, status)
+	}
+}
+
+// ClientSyncResponse fetches the sync response of a specific client.
+func (c *Client) ClientSyncResponse(ctx context.Context, accountID string) error {
+	return c.fetchAndPrintJSON(ctx, "/debug/clients/"+url.PathEscape(accountID)+"/syncresponse")
+}
+
+// PingTCP performs a TCP ping through a client.
+// Host, port and the optional timeout are sent as query parameters.
+func (c *Client) PingTCP(ctx context.Context, accountID, host string, port int, timeout string) error {
+	query := url.Values{
+		"host": {host},
+		"port": {fmt.Sprintf("%d", port)},
+	}
+	if timeout != "" {
+		query.Set("timeout", timeout)
+	}
+
+	target := "/debug/clients/" + url.PathEscape(accountID) + "/pingtcp?" + query.Encode()
+	return c.fetchAndPrint(ctx, target, c.printPingResult)
+}
+
+// printPingResult prints the target and its latency on success, or the target
+// and the server-reported error on failure.
+func (c *Client) printPingResult(data map[string]any) {
+	if ok, _ := data["success"].(bool); !ok {
+		_, _ = fmt.Fprintf(c.out, "Failed: %v:%v\n", data["host"], data["port"])
+		c.printError(data)
+		return
+	}
+	_, _ = fmt.Fprintf(c.out, "Success: %v:%v\n", data["host"], data["port"])
+	_, _ = fmt.Fprintf(c.out, "Latency: %v\n", data["latency"])
+}
+
+// SetLogLevel sets the log level of a specific client.
+func (c *Client) SetLogLevel(ctx context.Context, accountID, level string) error {
+	query := url.Values{"level": {level}}
+	target := "/debug/clients/" + url.PathEscape(accountID) + "/loglevel?" + query.Encode()
+	return c.fetchAndPrint(ctx, target, c.printLogLevelResult)
+}
+
+// printLogLevelResult confirms the new level or prints the reported error.
+func (c *Client) printLogLevelResult(data map[string]any) {
+	if ok, _ := data["success"].(bool); !ok {
+		_, _ = fmt.Fprintln(c.out, "Failed to set log level")
+		c.printError(data)
+		return
+	}
+	_, _ = fmt.Fprintf(c.out, "Log level set to: %v\n", data["level"])
+}
+
+// StartClient starts a specific client.
+func (c *Client) StartClient(ctx context.Context, accountID string) error {
+	return c.fetchAndPrint(ctx, "/debug/clients/"+url.PathEscape(accountID)+"/start", c.printStartResult)
+}
+
+// printStartResult confirms the start or prints the reported error.
+func (c *Client) printStartResult(data map[string]any) {
+	if ok, _ := data["success"].(bool); !ok {
+		_, _ = fmt.Fprintln(c.out, "Failed to start client")
+		c.printError(data)
+		return
+	}
+	_, _ = fmt.Fprintln(c.out, "Client started")
+}
+
+// StopClient stops a specific client.
+func (c *Client) StopClient(ctx context.Context, accountID string) error {
+	return c.fetchAndPrint(ctx, "/debug/clients/"+url.PathEscape(accountID)+"/stop", c.printStopResult)
+}
+
+// printStopResult confirms the stop or prints the server-reported error.
+func (c *Client) printStopResult(data map[string]any) {
+	if success, _ := data["success"].(bool); success {
+		_, _ = fmt.Fprintln(c.out, "Client stopped")
+		return
+	}
+	_, _ = fmt.Fprintln(c.out, "Failed to stop client")
+	c.printError(data)
+}
+
+// printError prints the "error" field of a response when it is a string.
+func (c *Client) printError(data map[string]any) {
+	if errMsg, ok := data["error"].(string); ok {
+		_, _ = fmt.Fprintf(c.out, "Error: %s\n", errMsg)
+	}
+}
+
+// fetchAndPrint requests path and prints the response: as indented JSON when
+// JSON output is enabled, through printer otherwise. A body that is not a JSON
+// object is printed verbatim in both modes rather than as a literal "null".
+func (c *Client) fetchAndPrint(ctx context.Context, path string, printer func(map[string]any)) error {
+	data, raw, err := c.fetch(ctx, path)
+	if err != nil {
+		return err
+	}
+	if data == nil {
+		_, _ = fmt.Fprintln(c.out, string(raw))
+		return nil
+	}
+	if c.jsonOutput {
+		return c.writeJSON(data)
+	}
+	printer(data)
+	return nil
+}
+
+// fetchAndPrintJSON requests path and prints the response as indented JSON,
+// falling back to the raw body when it is not a JSON object.
+func (c *Client) fetchAndPrintJSON(ctx context.Context, path string) error {
+	data, raw, err := c.fetch(ctx, path)
+	if err != nil {
+		return err
+	}
+	if data == nil {
+		_, _ = fmt.Fprintln(c.out, string(raw))
+		return nil
+	}
+	return c.writeJSON(data)
+}
+
+// writeJSON writes data to the output as indented JSON.
+func (c *Client) writeJSON(data map[string]any) error {
+	enc := json.NewEncoder(c.out)
+	enc.SetIndent("", " ")
+	return enc.Encode(data)
+}
+
+// fetch issues a GET for path, adding format=json unless the path already
+// contains it. It returns the decoded JSON object and the raw body; the object
+// is nil for any other body. A status of 400 or above becomes an error.
+func (c *Client) fetch(ctx context.Context, path string) (map[string]any, []byte, error) {
+	fullURL := c.baseURL + path
+	if !strings.Contains(path, "format=json") {
+		if strings.Contains(path, "?") {
+			fullURL += "&format=json"
+		} else {
+			fullURL += "?format=json"
+		}
+	}
+	req, err := http.NewRequestWithContext(ctx, http.MethodGet, fullURL, nil)
+	if err != nil {
+		return nil, nil, fmt.Errorf("create request: %w", err)
+	}
+
+	resp, err := c.httpClient.Do(req)
+	if err != nil {
+		return nil, nil, fmt.Errorf("request failed: %w", err)
+	}
+	defer resp.Body.Close()
+
+	body, err := io.ReadAll(resp.Body)
+	if err != nil {
+		return nil, nil, fmt.Errorf("read response: %w", err)
+	}
+	if resp.StatusCode >= 400 {
+		return nil, nil, fmt.Errorf("server error (%d): %s", resp.StatusCode, strings.TrimSpace(string(body)))
+	}
+
+	var data map[string]any
+	if err := json.Unmarshal(body, &data); err != nil {
+		return nil, body, nil
+	}
+	return data, body, nil
+}
diff --git a/proxy/internal/debug/handler.go b/proxy/internal/debug/handler.go
new file mode 100644
index 000000000..43cfb9533
--- /dev/null
+++ b/proxy/internal/debug/handler.go
@@ -0,0 +1,589 @@
+// Package debug provides HTTP debug endpoints for the proxy server.
+package debug
+
+import (
+	"context"
+	"embed"
+	"encoding/json"
+	"fmt"
+	"html/template"
+	"net/http"
+	"strconv"
+	"strings"
+	"sync"
+	"time"
+
+	log "github.com/sirupsen/logrus"
+	"google.golang.org/protobuf/encoding/protojson"
+
+	nbembed "github.com/netbirdio/netbird/client/embed"
+	nbstatus "github.com/netbirdio/netbird/client/status"
+	"github.com/netbirdio/netbird/proxy/internal/roundtrip"
+	"github.com/netbirdio/netbird/proxy/internal/types"
+	"github.com/netbirdio/netbird/version"
+)
+
+//go:embed templates/*.html
+var templateFS embed.FS
+
+const defaultPingTimeout = 10 * time.Second
+
+// formatDuration formats a duration with 2 decimal places using appropriate units.
+func formatDuration(d time.Duration) string {
+	switch {
+	case d >= time.Hour:
+		return fmt.Sprintf("%.2fh", d.Hours())
+	case d >= time.Minute:
+		return fmt.Sprintf("%.2fm", d.Minutes())
+	case d >= time.Second:
+		return fmt.Sprintf("%.2fs", d.Seconds())
+	case d >= time.Millisecond:
+		return fmt.Sprintf("%.2fms", float64(d.Microseconds())/1000)
+	case d >= time.Microsecond:
+		return fmt.Sprintf("%.2fµs", float64(d.Nanoseconds())/1000)
+	default:
+		return fmt.Sprintf("%dns", d.Nanoseconds())
+	}
+}
+
+// ClientProvider provides access to NetBird clients.
+type ClientProvider interface {
+	GetClient(accountID types.AccountID) (*nbembed.Client, bool)
+	ListClientsForDebug() map[types.AccountID]roundtrip.ClientDebugInfo
+}
+
+// Handler provides HTTP debug endpoints.
+type Handler struct {
+	provider   ClientProvider
+	logger     *log.Logger
+	startTime  time.Time          // set at construction; uptime is measured from it
+	templates  *template.Template // parsed embedded templates; guarded by templateMu
+	templateMu sync.RWMutex
+}
+
+// NewHandler creates a new debug handler. A nil logger is replaced by the
+// logrus standard logger; a template parse failure is logged, not returned.
+func NewHandler(provider ClientProvider, logger *log.Logger) *Handler {
+	if logger == nil {
+		logger = log.StandardLogger()
+	}
+
+	handler := &Handler{provider: provider, logger: logger, startTime: time.Now()}
+	if loadErr := handler.loadTemplates(); loadErr != nil {
+		logger.Errorf("failed to load embedded templates: %v", loadErr)
+	}
+	return handler
+}
+
+// loadTemplates parses the embedded HTML templates and stores them on h.
+func (h *Handler) loadTemplates() error {
+	parsed, err := template.ParseFS(templateFS, "templates/*.html")
+	if err != nil {
+		return fmt.Errorf("parse embedded templates: %w", err)
+	}
+
+	h.templateMu.Lock()
+	defer h.templateMu.Unlock()
+	h.templates = parsed
+	return nil
+}
+
+// getTemplates returns the stored templates, read under the read lock.
+func (h *Handler) getTemplates() *template.Template {
+	h.templateMu.RLock()
+	tmpl := h.templates
+	h.templateMu.RUnlock()
+	return tmpl
+}
+
+// ServeHTTP handles debug requests.
+func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + path := r.URL.Path + wantJSON := r.URL.Query().Get("format") == "json" || strings.HasSuffix(path, "/json") + path = strings.TrimSuffix(path, "/json") + + switch path { + case "/debug", "/debug/": + h.handleIndex(w, r, wantJSON) + case "/debug/clients": + h.handleListClients(w, r, wantJSON) + case "/debug/health": + h.handleHealth(w, r, wantJSON) + default: + if h.handleClientRoutes(w, r, path, wantJSON) { + return + } + http.NotFound(w, r) + } +} + +func (h *Handler) handleClientRoutes(w http.ResponseWriter, r *http.Request, path string, wantJSON bool) bool { + if !strings.HasPrefix(path, "/debug/clients/") { + return false + } + + rest := strings.TrimPrefix(path, "/debug/clients/") + parts := strings.SplitN(rest, "/", 2) + accountID := types.AccountID(parts[0]) + + if len(parts) == 1 { + h.handleClientStatus(w, r, accountID, wantJSON) + return true + } + + switch parts[1] { + case "syncresponse": + h.handleClientSyncResponse(w, r, accountID, wantJSON) + case "tools": + h.handleClientTools(w, r, accountID) + case "pingtcp": + h.handlePingTCP(w, r, accountID) + case "loglevel": + h.handleLogLevel(w, r, accountID) + case "start": + h.handleClientStart(w, r, accountID) + case "stop": + h.handleClientStop(w, r, accountID) + default: + return false + } + return true +} + +type indexData struct { + Version string + Uptime string + ClientCount int + TotalDomains int + Clients []clientData +} + +type clientData struct { + AccountID string + Domains string + Age string + Status string +} + +func (h *Handler) handleIndex(w http.ResponseWriter, _ *http.Request, wantJSON bool) { + clients := h.provider.ListClientsForDebug() + + totalDomains := 0 + for _, info := range clients { + totalDomains += info.DomainCount + } + + if wantJSON { + clientsJSON := make([]map[string]interface{}, 0, len(clients)) + for _, info := range clients { + clientsJSON = append(clientsJSON, map[string]interface{}{ + "account_id": 
info.AccountID, + "domain_count": info.DomainCount, + "domains": info.Domains, + "has_client": info.HasClient, + "created_at": info.CreatedAt, + "age": time.Since(info.CreatedAt).Round(time.Second).String(), + }) + } + h.writeJSON(w, map[string]interface{}{ + "version": version.NetbirdVersion(), + "uptime": time.Since(h.startTime).Round(time.Second).String(), + "client_count": len(clients), + "total_domains": totalDomains, + "clients": clientsJSON, + }) + return + } + + data := indexData{ + Version: version.NetbirdVersion(), + Uptime: time.Since(h.startTime).Round(time.Second).String(), + ClientCount: len(clients), + TotalDomains: totalDomains, + Clients: make([]clientData, 0, len(clients)), + } + + for _, info := range clients { + domains := info.Domains.SafeString() + if domains == "" { + domains = "-" + } + status := "No client" + if info.HasClient { + status = "Active" + } + data.Clients = append(data.Clients, clientData{ + AccountID: string(info.AccountID), + Domains: domains, + Age: time.Since(info.CreatedAt).Round(time.Second).String(), + Status: status, + }) + } + + h.renderTemplate(w, "index", data) +} + +type clientsData struct { + Uptime string + Clients []clientData +} + +func (h *Handler) handleListClients(w http.ResponseWriter, _ *http.Request, wantJSON bool) { + clients := h.provider.ListClientsForDebug() + + if wantJSON { + clientsJSON := make([]map[string]interface{}, 0, len(clients)) + for _, info := range clients { + clientsJSON = append(clientsJSON, map[string]interface{}{ + "account_id": info.AccountID, + "domain_count": info.DomainCount, + "domains": info.Domains, + "has_client": info.HasClient, + "created_at": info.CreatedAt, + "age": time.Since(info.CreatedAt).Round(time.Second).String(), + }) + } + h.writeJSON(w, map[string]interface{}{ + "uptime": time.Since(h.startTime).Round(time.Second).String(), + "client_count": len(clients), + "clients": clientsJSON, + }) + return + } + + data := clientsData{ + Uptime: 
time.Since(h.startTime).Round(time.Second).String(), + Clients: make([]clientData, 0, len(clients)), + } + + for _, info := range clients { + domains := info.Domains.SafeString() + if domains == "" { + domains = "-" + } + status := "No client" + if info.HasClient { + status = "Active" + } + data.Clients = append(data.Clients, clientData{ + AccountID: string(info.AccountID), + Domains: domains, + Age: time.Since(info.CreatedAt).Round(time.Second).String(), + Status: status, + }) + } + + h.renderTemplate(w, "clients", data) +} + +type clientDetailData struct { + AccountID string + ActiveTab string + Content string +} + +func (h *Handler) handleClientStatus(w http.ResponseWriter, r *http.Request, accountID types.AccountID, wantJSON bool) { + client, ok := h.provider.GetClient(accountID) + if !ok { + http.Error(w, "Client not found: "+string(accountID), http.StatusNotFound) + return + } + + fullStatus, err := client.Status() + if err != nil { + http.Error(w, "Error getting status: "+err.Error(), http.StatusInternalServerError) + return + } + + // Parse filter parameters + query := r.URL.Query() + statusFilter := query.Get("filter-by-status") + connectionTypeFilter := query.Get("filter-by-connection-type") + + var prefixNamesFilter []string + var prefixNamesFilterMap map[string]struct{} + if names := query.Get("filter-by-names"); names != "" { + prefixNamesFilter = strings.Split(names, ",") + prefixNamesFilterMap = make(map[string]struct{}) + for _, name := range prefixNamesFilter { + prefixNamesFilterMap[strings.ToLower(strings.TrimSpace(name))] = struct{}{} + } + } + + var ipsFilterMap map[string]struct{} + if ips := query.Get("filter-by-ips"); ips != "" { + ipsFilterMap = make(map[string]struct{}) + for _, ip := range strings.Split(ips, ",") { + ipsFilterMap[strings.TrimSpace(ip)] = struct{}{} + } + } + + pbStatus := nbstatus.ToProtoFullStatus(fullStatus) + overview := nbstatus.ConvertToStatusOutputOverview( + pbStatus, + false, + version.NetbirdVersion(), + 
statusFilter, + prefixNamesFilter, + prefixNamesFilterMap, + ipsFilterMap, + connectionTypeFilter, + "", + ) + + if wantJSON { + h.writeJSON(w, map[string]interface{}{ + "account_id": accountID, + "status": overview.FullDetailSummary(), + }) + return + } + + data := clientDetailData{ + AccountID: string(accountID), + ActiveTab: "status", + Content: overview.FullDetailSummary(), + } + + h.renderTemplate(w, "clientDetail", data) +} + +func (h *Handler) handleClientSyncResponse(w http.ResponseWriter, _ *http.Request, accountID types.AccountID, wantJSON bool) { + client, ok := h.provider.GetClient(accountID) + if !ok { + http.Error(w, "Client not found: "+string(accountID), http.StatusNotFound) + return + } + + syncResp, err := client.GetLatestSyncResponse() + if err != nil { + http.Error(w, "Error getting sync response: "+err.Error(), http.StatusInternalServerError) + return + } + + if syncResp == nil { + http.Error(w, "No sync response available for client: "+string(accountID), http.StatusNotFound) + return + } + + opts := protojson.MarshalOptions{ + EmitUnpopulated: true, + UseProtoNames: true, + Indent: " ", + AllowPartial: true, + } + + jsonBytes, err := opts.Marshal(syncResp) + if err != nil { + http.Error(w, "Error marshaling sync response: "+err.Error(), http.StatusInternalServerError) + return + } + + if wantJSON { + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write(jsonBytes) + return + } + + data := clientDetailData{ + AccountID: string(accountID), + ActiveTab: "syncresponse", + Content: string(jsonBytes), + } + + h.renderTemplate(w, "clientDetail", data) +} + +type toolsData struct { + AccountID string +} + +func (h *Handler) handleClientTools(w http.ResponseWriter, _ *http.Request, accountID types.AccountID) { + _, ok := h.provider.GetClient(accountID) + if !ok { + http.Error(w, "Client not found: "+string(accountID), http.StatusNotFound) + return + } + + data := toolsData{ + AccountID: string(accountID), + } + + h.renderTemplate(w, 
"tools", data) +} + +func (h *Handler) handlePingTCP(w http.ResponseWriter, r *http.Request, accountID types.AccountID) { + client, ok := h.provider.GetClient(accountID) + if !ok { + h.writeJSON(w, map[string]interface{}{"error": "client not found"}) + return + } + + host := r.URL.Query().Get("host") + portStr := r.URL.Query().Get("port") + if host == "" || portStr == "" { + h.writeJSON(w, map[string]interface{}{"error": "host and port parameters required"}) + return + } + + port, err := strconv.Atoi(portStr) + if err != nil || port < 1 || port > 65535 { + h.writeJSON(w, map[string]interface{}{"error": "invalid port"}) + return + } + + timeout := defaultPingTimeout + if t := r.URL.Query().Get("timeout"); t != "" { + if d, err := time.ParseDuration(t); err == nil { + timeout = d + } + } + + ctx, cancel := context.WithTimeout(r.Context(), timeout) + defer cancel() + + address := fmt.Sprintf("%s:%d", host, port) + start := time.Now() + + conn, err := client.Dial(ctx, "tcp", address) + if err != nil { + h.writeJSON(w, map[string]interface{}{ + "success": false, + "host": host, + "port": port, + "error": err.Error(), + }) + return + } + if err := conn.Close(); err != nil { + h.logger.Debugf("close tcp ping connection: %v", err) + } + + latency := time.Since(start) + h.writeJSON(w, map[string]interface{}{ + "success": true, + "host": host, + "port": port, + "latency_ms": latency.Milliseconds(), + "latency": formatDuration(latency), + }) +} + +func (h *Handler) handleLogLevel(w http.ResponseWriter, r *http.Request, accountID types.AccountID) { + client, ok := h.provider.GetClient(accountID) + if !ok { + h.writeJSON(w, map[string]interface{}{"error": "client not found"}) + return + } + + level := r.URL.Query().Get("level") + if level == "" { + h.writeJSON(w, map[string]interface{}{"error": "level parameter required (trace, debug, info, warn, error)"}) + return + } + + if err := client.SetLogLevel(level); err != nil { + h.writeJSON(w, map[string]interface{}{ + "success": 
false, + "error": err.Error(), + }) + return + } + + h.writeJSON(w, map[string]interface{}{ + "success": true, + "level": level, + }) +} + +const clientActionTimeout = 30 * time.Second + +func (h *Handler) handleClientStart(w http.ResponseWriter, r *http.Request, accountID types.AccountID) { + client, ok := h.provider.GetClient(accountID) + if !ok { + h.writeJSON(w, map[string]interface{}{"error": "client not found"}) + return + } + + ctx, cancel := context.WithTimeout(r.Context(), clientActionTimeout) + defer cancel() + + if err := client.Start(ctx); err != nil { + h.writeJSON(w, map[string]interface{}{ + "success": false, + "error": err.Error(), + }) + return + } + + h.writeJSON(w, map[string]interface{}{ + "success": true, + "message": "client started", + }) +} + +func (h *Handler) handleClientStop(w http.ResponseWriter, r *http.Request, accountID types.AccountID) { + client, ok := h.provider.GetClient(accountID) + if !ok { + h.writeJSON(w, map[string]interface{}{"error": "client not found"}) + return + } + + ctx, cancel := context.WithTimeout(r.Context(), clientActionTimeout) + defer cancel() + + if err := client.Stop(ctx); err != nil { + h.writeJSON(w, map[string]interface{}{ + "success": false, + "error": err.Error(), + }) + return + } + + h.writeJSON(w, map[string]interface{}{ + "success": true, + "message": "client stopped", + }) +} + +type healthData struct { + Uptime string +} + +func (h *Handler) handleHealth(w http.ResponseWriter, _ *http.Request, wantJSON bool) { + if wantJSON { + h.writeJSON(w, map[string]interface{}{ + "status": "ok", + "uptime": time.Since(h.startTime).Round(10 * time.Millisecond).String(), + }) + return + } + + data := healthData{ + Uptime: time.Since(h.startTime).Round(time.Second).String(), + } + + h.renderTemplate(w, "health", data) +} + +func (h *Handler) renderTemplate(w http.ResponseWriter, name string, data interface{}) { + w.Header().Set("Content-Type", "text/html; charset=utf-8") + tmpl := h.getTemplates() + if tmpl == nil 
{ + http.Error(w, "Templates not loaded", http.StatusInternalServerError) + return + } + if err := tmpl.ExecuteTemplate(w, name, data); err != nil { + h.logger.Errorf("execute template %s: %v", name, err) + http.Error(w, "Template error", http.StatusInternalServerError) + } +} + +func (h *Handler) writeJSON(w http.ResponseWriter, v interface{}) { + w.Header().Set("Content-Type", "application/json") + enc := json.NewEncoder(w) + enc.SetIndent("", " ") + if err := enc.Encode(v); err != nil { + h.logger.Errorf("encode JSON response: %v", err) + } +} diff --git a/proxy/internal/debug/templates/base.html b/proxy/internal/debug/templates/base.html new file mode 100644 index 000000000..737bd5b85 --- /dev/null +++ b/proxy/internal/debug/templates/base.html @@ -0,0 +1,101 @@ +{{define "style"}} +body { + font-family: monospace; + margin: 20px; + background: #1a1a1a; + color: #eee; +} +a { + color: #6cf; +} +h1, h2, h3 { + color: #fff; +} +.info { + color: #aaa; +} +table { + border-collapse: collapse; + margin: 10px 0; +} +th, td { + border: 1px solid #444; + padding: 8px; + text-align: left; +} +th { + background: #333; +} +.nav { + margin-bottom: 20px; +} +.nav a { + margin-right: 15px; + padding: 8px 16px; + background: #333; + text-decoration: none; + border-radius: 4px; +} +.nav a.active { + background: #6cf; + color: #000; +} +pre { + background: #222; + padding: 15px; + border-radius: 4px; + overflow-x: auto; + white-space: pre-wrap; +} +input, select, textarea { + background: #333; + color: #eee; + border: 1px solid #555; + padding: 8px; + border-radius: 4px; + font-family: monospace; +} +input:focus, select:focus, textarea:focus { + outline: none; + border-color: #6cf; +} +button { + background: #6cf; + color: #000; + border: none; + padding: 8px 16px; + border-radius: 4px; + cursor: pointer; + font-family: monospace; +} +button:hover { + background: #5be; +} +button:disabled { + background: #555; + color: #888; + cursor: not-allowed; +} +.form-group { + 
margin-bottom: 15px; +} +.form-group label { + display: block; + margin-bottom: 5px; + color: #aaa; +} +.form-row { + display: flex; + gap: 10px; + align-items: flex-end; +} +.result { + margin-top: 20px; +} +.success { + color: #5f5; +} +.error { + color: #f55; +} +{{end}} diff --git a/proxy/internal/debug/templates/client_detail.html b/proxy/internal/debug/templates/client_detail.html new file mode 100644 index 000000000..c58e26f6c --- /dev/null +++ b/proxy/internal/debug/templates/client_detail.html @@ -0,0 +1,19 @@ +{{define "clientDetail"}} + + + + Client {{.AccountID}} + + + +

Client: {{.AccountID}}

+ +
{{.Content}}
+ + +{{end}} diff --git a/proxy/internal/debug/templates/clients.html b/proxy/internal/debug/templates/clients.html new file mode 100644 index 000000000..68f286272 --- /dev/null +++ b/proxy/internal/debug/templates/clients.html @@ -0,0 +1,33 @@ +{{define "clients"}} + + + + Clients + + + +

All Clients

+

Uptime: {{.Uptime}} | ← Back

+ {{if .Clients}} + + + + + + + + {{range .Clients}} + + + + + + + {{end}} +
Account IDDomainsAgeStatus
{{.AccountID}}{{.Domains}}{{.Age}}{{.Status}}
+ {{else}} +

No clients connected

+ {{end}} + + +{{end}} diff --git a/proxy/internal/debug/templates/health.html b/proxy/internal/debug/templates/health.html new file mode 100644 index 000000000..f584f8357 --- /dev/null +++ b/proxy/internal/debug/templates/health.html @@ -0,0 +1,14 @@ +{{define "health"}} + + + + Health + + + +

OK

+

Uptime: {{.Uptime}}

+

← Back

+ + +{{end}} diff --git a/proxy/internal/debug/templates/index.html b/proxy/internal/debug/templates/index.html new file mode 100644 index 000000000..ac01e12e9 --- /dev/null +++ b/proxy/internal/debug/templates/index.html @@ -0,0 +1,40 @@ +{{define "index"}} + + + + NetBird Proxy Debug + + + +

NetBird Proxy Debug

+

Version: {{.Version}} | Uptime: {{.Uptime}}

+

Clients ({{.ClientCount}}) | Domains ({{.TotalDomains}})

+ {{if .Clients}} + + + + + + + + {{range .Clients}} + + + + + + + {{end}} +
Account IDDomainsAgeStatus
{{.AccountID}}{{.Domains}}{{.Age}}{{.Status}}
+ {{else}} +

No clients connected

+ {{end}} +

Endpoints

+ +

Add ?format=json or /json suffix for JSON output

+ + +{{end}} diff --git a/proxy/internal/debug/templates/tools.html b/proxy/internal/debug/templates/tools.html new file mode 100644 index 000000000..091b3e0a1 --- /dev/null +++ b/proxy/internal/debug/templates/tools.html @@ -0,0 +1,142 @@ +{{define "tools"}} + + + + Client {{.AccountID}} - Tools + + + +

Client: {{.AccountID}}

+ + +

Client Control

+
+
+ + +
+
+ + +
+
+
+ +

Log Level

+
+
+ + +
+
+ + +
+
+
+ +

TCP Ping

+
+
+ + +
+
+ + +
+
+ + +
+
+
+ + + + +{{end}} diff --git a/proxy/internal/proxy/context.go b/proxy/internal/proxy/context.go index b437a9610..36da03d30 100644 --- a/proxy/internal/proxy/context.go +++ b/proxy/internal/proxy/context.go @@ -3,6 +3,8 @@ package proxy import ( "context" "sync" + + "github.com/netbirdio/netbird/proxy/internal/types" ) type requestContextKey string @@ -18,7 +20,7 @@ const ( type CapturedData struct { mu sync.RWMutex ServiceId string - AccountId string + AccountId types.AccountID } // SetServiceId safely sets the service ID @@ -36,14 +38,14 @@ func (c *CapturedData) GetServiceId() string { } // SetAccountId safely sets the account ID -func (c *CapturedData) SetAccountId(accountId string) { +func (c *CapturedData) SetAccountId(accountId types.AccountID) { c.mu.Lock() defer c.mu.Unlock() c.AccountId = accountId } // GetAccountId safely gets the account ID -func (c *CapturedData) GetAccountId() string { +func (c *CapturedData) GetAccountId() types.AccountID { c.mu.RLock() defer c.mu.RUnlock() return c.AccountId @@ -76,13 +78,13 @@ func ServiceIdFromContext(ctx context.Context) string { } return serviceId } -func withAccountId(ctx context.Context, accountId string) context.Context { +func withAccountId(ctx context.Context, accountId types.AccountID) context.Context { return context.WithValue(ctx, accountIdKey, accountId) } -func AccountIdFromContext(ctx context.Context) string { +func AccountIdFromContext(ctx context.Context) types.AccountID { v := ctx.Value(accountIdKey) - accountId, ok := v.(string) + accountId, ok := v.(types.AccountID) if !ok { return "" } diff --git a/proxy/internal/proxy/reverseproxy.go b/proxy/internal/proxy/reverseproxy.go index e90850961..feb792f50 100644 --- a/proxy/internal/proxy/reverseproxy.go +++ b/proxy/internal/proxy/reverseproxy.go @@ -4,6 +4,8 @@ import ( "net/http" "net/http/httputil" "sync" + + "github.com/netbirdio/netbird/proxy/internal/roundtrip" ) type ReverseProxy struct { @@ -36,8 +38,10 @@ func (p *ReverseProxy) ServeHTTP(w 
http.ResponseWriter, r *http.Request) { // Set the serviceId in the context for later retrieval. ctx := withServiceId(r.Context(), serviceId) - // Set the accountId in the context for later retrieval. + // Set the accountId in the context for later retrieval (for middleware). ctx = withAccountId(ctx, accountID) + // Set the accountId in the context for the roundtripper to use. + ctx = roundtrip.WithAccountID(ctx, accountID) // Also populate captured data if it exists (allows middleware to read after handler completes). // This solves the problem of passing data UP the middleware chain: we put a mutable struct diff --git a/proxy/internal/proxy/servicemapping.go b/proxy/internal/proxy/servicemapping.go index be9c0ed29..dff03af2e 100644 --- a/proxy/internal/proxy/servicemapping.go +++ b/proxy/internal/proxy/servicemapping.go @@ -6,16 +6,18 @@ import ( "net/url" "sort" "strings" + + "github.com/netbirdio/netbird/proxy/internal/types" ) type Mapping struct { ID string - AccountID string + AccountID types.AccountID Host string Paths map[string]*url.URL } -func (p *ReverseProxy) findTargetForRequest(req *http.Request) (*url.URL, string, string, bool) { +func (p *ReverseProxy) findTargetForRequest(req *http.Request) (*url.URL, string, types.AccountID, bool) { p.mappingsMux.RLock() if p.mappings == nil { p.mappingsMux.RUnlock() @@ -27,10 +29,12 @@ func (p *ReverseProxy) findTargetForRequest(req *http.Request) (*url.URL, string } defer p.mappingsMux.RUnlock() - host, _, err := net.SplitHostPort(req.Host) - if err != nil { - host = req.Host + // Strip port from host if present (e.g., "external.test:8443" -> "external.test") + host := req.Host + if h, _, err := net.SplitHostPort(host); err == nil { + host = h } + m, exists := p.mappings[host] if !exists { return nil, "", "", false diff --git a/proxy/internal/roundtrip/netbird.go b/proxy/internal/roundtrip/netbird.go index 5e591f75d..b869dba88 100644 --- a/proxy/internal/roundtrip/netbird.go +++ 
b/proxy/internal/roundtrip/netbird.go @@ -4,18 +4,39 @@ import ( "context" "errors" "fmt" - "io" - "net" "net/http" "sync" "time" + "github.com/hashicorp/go-multierror" log "github.com/sirupsen/logrus" + "golang.org/x/exp/maps" + nberrors "github.com/netbirdio/netbird/client/errors" "github.com/netbirdio/netbird/client/embed" + "github.com/netbirdio/netbird/proxy/internal/types" + "github.com/netbirdio/netbird/shared/management/domain" + "github.com/netbirdio/netbird/util" ) -const deviceNamePrefix = "ingress-" +const deviceNamePrefix = "ingress-proxy-" + +// ErrNoAccountID is returned when a request context is missing the account ID. +var ErrNoAccountID = errors.New("no account ID in request context") + +// domainInfo holds metadata about a registered domain. +type domainInfo struct { + reverseProxyID string +} + +// clientEntry holds an embedded NetBird client and tracks which domains use it. +type clientEntry struct { + client *embed.Client + transport *http.Transport + domains map[domain.Domain]domainInfo + createdAt time.Time + started bool +} type statusNotifier interface { NotifyStatus(ctx context.Context, accountID, reverseProxyID, domain string, connected bool) error @@ -23,147 +44,389 @@ type statusNotifier interface { // NetBird provides an http.RoundTripper implementation // backed by underlying NetBird connections. +// Clients are keyed by AccountID, allowing multiple domains to share the same connection. type NetBird struct { mgmtAddr string + proxyID string logger *log.Logger - clientsMux sync.RWMutex - clients map[string]*embed.Client - + clientsMux sync.RWMutex + clients map[types.AccountID]*clientEntry + initLogOnce sync.Once statusNotifier statusNotifier } -func NewNetBird(mgmtAddr string, logger *log.Logger, notifier statusNotifier) *NetBird { +// NewNetBird creates a new NetBird transport. 
+func NewNetBird(mgmtAddr, proxyID string, logger *log.Logger, notifier statusNotifier) *NetBird { if logger == nil { logger = log.StandardLogger() } return &NetBird{ mgmtAddr: mgmtAddr, + proxyID: proxyID, logger: logger, - clients: make(map[string]*embed.Client), + clients: make(map[types.AccountID]*clientEntry), statusNotifier: notifier, } } -func (n *NetBird) AddPeer(ctx context.Context, domain, key, accountID, reverseProxyID string) error { +// AddPeer registers a domain for an account. If the account doesn't have a client yet, +// one is created using the provided setup key. Multiple domains can share the same client. +func (n *NetBird) AddPeer(ctx context.Context, accountID types.AccountID, d domain.Domain, key, reverseProxyID string) error { + n.clientsMux.Lock() + + entry, exists := n.clients[accountID] + if exists { + // Client already exists for this account, just register the domain + entry.domains[d] = domainInfo{reverseProxyID: reverseProxyID} + started := entry.started + n.clientsMux.Unlock() + + n.logger.WithFields(log.Fields{ + "account_id": accountID, + "domain": d, + }).Debug("registered domain with existing client") + + // If client is already started, notify this domain as connected immediately + if started && n.statusNotifier != nil { + if err := n.statusNotifier.NotifyStatus(ctx, string(accountID), reverseProxyID, string(d), true); err != nil { + n.logger.WithFields(log.Fields{ + "account_id": accountID, + "domain": d, + }).WithError(err).Warn("failed to notify status for existing client") + } + } + return nil + } + + n.initLogOnce.Do(func() { + if err := util.InitLog(log.WarnLevel.String(), util.LogConsole); err != nil { + n.logger.WithField("account_id", accountID).Warnf("failed to initialize embedded client logging: %v", err) + } + }) + + wgPort := 0 client, err := embed.New(embed.Options{ - DeviceName: deviceNamePrefix + domain, + DeviceName: deviceNamePrefix + n.proxyID, ManagementURL: n.mgmtAddr, SetupKey: key, - LogOutput: io.Discard, 
+ LogLevel: log.WarnLevel.String(), BlockInbound: true, + WireguardPort: &wgPort, }) if err != nil { + n.clientsMux.Unlock() return fmt.Errorf("create netbird client: %w", err) } + // Create a transport using the client dialer. We do this instead of using + // the client's HTTPClient to avoid issues with request validation that do + // not work with reverse proxied requests. + entry = &clientEntry{ + client: client, + domains: map[domain.Domain]domainInfo{d: {reverseProxyID: reverseProxyID}}, + transport: &http.Transport{ + DialContext: client.DialContext, + ForceAttemptHTTP2: true, + MaxIdleConns: 100, + IdleConnTimeout: 90 * time.Second, + TLSHandshakeTimeout: 10 * time.Second, + ExpectContinueTimeout: 1 * time.Second, + }, + createdAt: time.Now(), + started: false, + } + n.clients[accountID] = entry + n.clientsMux.Unlock() + + n.logger.WithFields(log.Fields{ + "account_id": accountID, + "domain": d, + }).Info("created new client for account") + // Attempt to start the client in the background, if this fails // then it is not ideal, but it isn't the end of the world because // we will try to start the client again before we use it. go func() { - startCtx, cancel := context.WithTimeout(ctx, 3*time.Second) + startCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second) defer cancel() - err = client.Start(startCtx) - switch { - case errors.Is(err, context.DeadlineExceeded): - n.logger.Debug("netbird client timed out") - // This is not ideal, but we will try again later. 
- return - case err != nil: - n.logger.WithField("domain", domain).WithError(err).Error("Unable to start netbird client, will try again later.") + + if err := client.Start(startCtx); err != nil { + if errors.Is(err, context.DeadlineExceeded) { + n.logger.WithFields(log.Fields{ + "account_id": accountID, + }).Debug("netbird client start timed out, will retry on first request") + } else { + n.logger.WithFields(log.Fields{ + "account_id": accountID, + }).WithError(err).Error("failed to start netbird client") + } return } - // Notify management that tunnel is now active + // Mark client as started and notify all registered domains + n.clientsMux.Lock() + entry, exists := n.clients[accountID] + if exists { + entry.started = true + } + // Copy domain info while holding lock + var domainsToNotify []struct { + domain domain.Domain + reverseProxyID string + } + if exists { + for dom, info := range entry.domains { + domainsToNotify = append(domainsToNotify, struct { + domain domain.Domain + reverseProxyID string + }{domain: dom, reverseProxyID: info.reverseProxyID}) + } + } + n.clientsMux.Unlock() + + // Notify all domains that they're connected if n.statusNotifier != nil { - if err := n.statusNotifier.NotifyStatus(ctx, accountID, reverseProxyID, domain, true); err != nil { - n.logger.WithField("domain", domain).WithError(err).Warn("Failed to notify management about tunnel connection") - } else { - n.logger.WithField("domain", domain).Info("Successfully notified management about tunnel connection") + for _, domInfo := range domainsToNotify { + if err := n.statusNotifier.NotifyStatus(ctx, string(accountID), domInfo.reverseProxyID, string(domInfo.domain), true); err != nil { + n.logger.WithFields(log.Fields{ + "account_id": accountID, + "domain": domInfo.domain, + }).WithError(err).Warn("failed to notify tunnel connection status") + } else { + n.logger.WithFields(log.Fields{ + "account_id": accountID, + "domain": domInfo.domain, + }).Info("notified management about tunnel 
connection") + } } } }() - n.clientsMux.Lock() - defer n.clientsMux.Unlock() - n.clients[domain] = client return nil } -func (n *NetBird) RemovePeer(ctx context.Context, domain, accountID, reverseProxyID string) error { - n.clientsMux.RLock() - client, exists := n.clients[domain] - n.clientsMux.RUnlock() +// RemovePeer unregisters a domain from an account. The client is only stopped +// when no domains are using it anymore. +func (n *NetBird) RemovePeer(ctx context.Context, accountID types.AccountID, d domain.Domain) error { + n.clientsMux.Lock() + + entry, exists := n.clients[accountID] if !exists { - // Mission failed successfully! + n.clientsMux.Unlock() return nil } - if err := client.Stop(ctx); err != nil { - return fmt.Errorf("stop netbird client: %w", err) + + // Get domain info before deleting + domInfo, domainExists := entry.domains[d] + if !domainExists { + n.clientsMux.Unlock() + return nil } - // Notify management that tunnel is disconnected + delete(entry.domains, d) + + // If there are still domains using this client, keep it running + if len(entry.domains) > 0 { + n.clientsMux.Unlock() + + n.logger.WithFields(log.Fields{ + "account_id": accountID, + "domain": d, + "remaining_domains": len(entry.domains), + }).Debug("unregistered domain, client still in use") + + // Notify this domain as disconnected + if n.statusNotifier != nil { + if err := n.statusNotifier.NotifyStatus(ctx, string(accountID), domInfo.reverseProxyID, string(d), false); err != nil { + n.logger.WithFields(log.Fields{ + "account_id": accountID, + "domain": d, + }).WithError(err).Warn("failed to notify tunnel disconnection status") + } + } + return nil + } + + // No more domains using this client, stop it + n.logger.WithFields(log.Fields{ + "account_id": accountID, + }).Info("stopping client, no more domains") + + client := entry.client + transport := entry.transport + delete(n.clients, accountID) + n.clientsMux.Unlock() + + // Notify disconnection before stopping if n.statusNotifier != 
nil { - if err := n.statusNotifier.NotifyStatus(ctx, accountID, reverseProxyID, domain, false); err != nil { - n.logger.WithField("domain", domain).WithError(err).Warn("Failed to notify management about tunnel disconnection") - } else { - n.logger.WithField("domain", domain).Info("Successfully notified management about tunnel disconnection") + if err := n.statusNotifier.NotifyStatus(ctx, string(accountID), domInfo.reverseProxyID, string(d), false); err != nil { + n.logger.WithFields(log.Fields{ + "account_id": accountID, + "domain": d, + }).WithError(err).Warn("failed to notify tunnel disconnection status") } } - n.clientsMux.Lock() - defer n.clientsMux.Unlock() - delete(n.clients, domain) + transport.CloseIdleConnections() + + if err := client.Stop(ctx); err != nil { + n.logger.WithFields(log.Fields{ + "account_id": accountID, + }).WithError(err).Warn("failed to stop netbird client") + } + return nil } +// RoundTrip implements http.RoundTripper. It looks up the client for the account +// specified in the request context and uses it to dial the backend. func (n *NetBird) RoundTrip(req *http.Request) (*http.Response, error) { - host, _, err := net.SplitHostPort(req.Host) - if err != nil { - host = req.Host + accountID := AccountIDFromContext(req.Context()) + if accountID == "" { + return nil, ErrNoAccountID } + + // Copy references while holding lock, then unlock early to avoid blocking + // other requests during the potentially slow RoundTrip. n.clientsMux.RLock() - client, exists := n.clients[host] - // Immediately unlock after retrieval here rather than defer to avoid - // the call to client.Do blocking other clients being used whilst one - // is in use. 
- n.clientsMux.RUnlock() + entry, exists := n.clients[accountID] if !exists { - return nil, fmt.Errorf("no peer connection found for host: %s", host) + n.clientsMux.RUnlock() + return nil, fmt.Errorf("no peer connection found for account: %s", accountID) } + client := entry.client + transport := entry.transport + n.clientsMux.RUnlock() // Attempt to start the client, if the client is already running then // it will return an error that we ignore, if this hits a timeout then // this request is unprocessable. - startCtx, cancel := context.WithTimeout(req.Context(), 3*time.Second) + startCtx, cancel := context.WithTimeout(req.Context(), 10*time.Second) defer cancel() - err = client.Start(startCtx) - switch { - case errors.Is(err, embed.ErrClientAlreadyStarted): - break - case err != nil: - return nil, fmt.Errorf("start netbird client: %w", err) + if err := client.Start(startCtx); err != nil { + if !errors.Is(err, embed.ErrClientAlreadyStarted) { + return nil, fmt.Errorf("start netbird client: %w", err) + } } n.logger.WithFields(log.Fields{ - "host": host, + "account_id": accountID, + "host": req.Host, "url": req.URL.String(), "requestURI": req.RequestURI, "method": req.Method, }).Debug("running roundtrip for peer connection") - // Create a new transport using the client dialer and perform the roundtrip. - // We do this instead of using the client HTTPClient to avoid issues around - // client request validation that do not work with the reverse proxied - // requests. - // Other values are simply copied from the http.DefaultTransport which the - // standard reverse proxy implementation would have used. - // TODO: tune this transport for our needs. - return (&http.Transport{ - DialContext: client.DialContext, - MaxIdleConns: 100, - IdleConnTimeout: 90 * time.Second, - TLSHandshakeTimeout: 10 * time.Second, - ExpectContinueTimeout: 1 * time.Second, - }).RoundTrip(req) + return transport.RoundTrip(req) +} + +// StopAll stops all clients. 
+func (n *NetBird) StopAll(ctx context.Context) error { + n.clientsMux.Lock() + defer n.clientsMux.Unlock() + + var merr *multierror.Error + for accountID, entry := range n.clients { + entry.transport.CloseIdleConnections() + if err := entry.client.Stop(ctx); err != nil { + n.logger.WithFields(log.Fields{ + "account_id": accountID, + }).WithError(err).Warn("failed to stop netbird client during shutdown") + merr = multierror.Append(merr, err) + } + } + maps.Clear(n.clients) + + return nberrors.FormatErrorOrNil(merr) +} + +// HasClient returns true if there is a client for the given account. +func (n *NetBird) HasClient(accountID types.AccountID) bool { + n.clientsMux.RLock() + defer n.clientsMux.RUnlock() + _, exists := n.clients[accountID] + return exists +} + +// DomainCount returns the number of domains registered for the given account. +// Returns 0 if the account has no client. +func (n *NetBird) DomainCount(accountID types.AccountID) int { + n.clientsMux.RLock() + defer n.clientsMux.RUnlock() + entry, exists := n.clients[accountID] + if !exists { + return 0 + } + return len(entry.domains) +} + +// ClientCount returns the total number of active clients. +func (n *NetBird) ClientCount() int { + n.clientsMux.RLock() + defer n.clientsMux.RUnlock() + return len(n.clients) +} + +// GetClient returns the embed.Client for the given account ID. +func (n *NetBird) GetClient(accountID types.AccountID) (*embed.Client, bool) { + n.clientsMux.RLock() + defer n.clientsMux.RUnlock() + entry, exists := n.clients[accountID] + if !exists { + return nil, false + } + return entry.client, true +} + +// ClientDebugInfo contains debug information about a client. +type ClientDebugInfo struct { + AccountID types.AccountID + DomainCount int + Domains domain.List + HasClient bool + CreatedAt time.Time +} + +// ListClientsForDebug returns information about all clients for debug purposes. 
+func (n *NetBird) ListClientsForDebug() map[types.AccountID]ClientDebugInfo { + n.clientsMux.RLock() + defer n.clientsMux.RUnlock() + + result := make(map[types.AccountID]ClientDebugInfo) + for accountID, entry := range n.clients { + domains := make(domain.List, 0, len(entry.domains)) + for d := range entry.domains { + domains = append(domains, d) + } + result[accountID] = ClientDebugInfo{ + AccountID: accountID, + DomainCount: len(entry.domains), + Domains: domains, + HasClient: entry.client != nil, + CreatedAt: entry.createdAt, + } + } + return result +} + +// accountIDContextKey is the context key for storing the account ID. +type accountIDContextKey struct{} + +// WithAccountID adds the account ID to the context. +func WithAccountID(ctx context.Context, accountID types.AccountID) context.Context { + return context.WithValue(ctx, accountIDContextKey{}, accountID) +} + +// AccountIDFromContext retrieves the account ID from the context. +func AccountIDFromContext(ctx context.Context) types.AccountID { + v := ctx.Value(accountIDContextKey{}) + if v == nil { + return "" + } + accountID, ok := v.(types.AccountID) + if !ok { + return "" + } + return accountID } diff --git a/proxy/internal/roundtrip/netbird_test.go b/proxy/internal/roundtrip/netbird_test.go new file mode 100644 index 000000000..b04600e4a --- /dev/null +++ b/proxy/internal/roundtrip/netbird_test.go @@ -0,0 +1,247 @@ +package roundtrip + +import ( + "context" + "net/http" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/netbirdio/netbird/proxy/internal/types" + "github.com/netbirdio/netbird/shared/management/domain" +) + +// mockNetBird creates a NetBird instance for testing without actually connecting. +// It uses an invalid management URL to prevent real connections. 
+func mockNetBird() *NetBird { + return NewNetBird("http://invalid.test:9999", "test-proxy", nil, nil) +} + +func TestNetBird_AddPeer_CreatesClientForNewAccount(t *testing.T) { + nb := mockNetBird() + accountID := types.AccountID("account-1") + + // Initially no client exists. + assert.False(t, nb.HasClient(accountID), "should not have client before AddPeer") + assert.Equal(t, 0, nb.DomainCount(accountID), "domain count should be 0") + + // Add first domain - this should create a new client. + // Note: This will fail to actually connect since we use an invalid URL, + // but the client entry should still be created. + err := nb.AddPeer(context.Background(), accountID, domain.Domain("domain1.test"), "setup-key-1", "proxy-1") + require.NoError(t, err) + + assert.True(t, nb.HasClient(accountID), "should have client after AddPeer") + assert.Equal(t, 1, nb.DomainCount(accountID), "domain count should be 1") +} + +func TestNetBird_AddPeer_ReuseClientForSameAccount(t *testing.T) { + nb := mockNetBird() + accountID := types.AccountID("account-1") + + // Add first domain. + err := nb.AddPeer(context.Background(), accountID, domain.Domain("domain1.test"), "setup-key-1", "proxy-1") + require.NoError(t, err) + assert.Equal(t, 1, nb.DomainCount(accountID)) + + // Add second domain for the same account - should reuse existing client. + err = nb.AddPeer(context.Background(), accountID, domain.Domain("domain2.test"), "setup-key-1", "proxy-2") + require.NoError(t, err) + assert.Equal(t, 2, nb.DomainCount(accountID), "domain count should be 2 after adding second domain") + + // Add third domain. + err = nb.AddPeer(context.Background(), accountID, domain.Domain("domain3.test"), "setup-key-1", "proxy-3") + require.NoError(t, err) + assert.Equal(t, 3, nb.DomainCount(accountID), "domain count should be 3 after adding third domain") + + // Still only one client. 
+ assert.True(t, nb.HasClient(accountID)) +} + +func TestNetBird_AddPeer_SeparateClientsForDifferentAccounts(t *testing.T) { + nb := mockNetBird() + account1 := types.AccountID("account-1") + account2 := types.AccountID("account-2") + + // Add domain for account 1. + err := nb.AddPeer(context.Background(), account1, domain.Domain("domain1.test"), "setup-key-1", "proxy-1") + require.NoError(t, err) + + // Add domain for account 2. + err = nb.AddPeer(context.Background(), account2, domain.Domain("domain2.test"), "setup-key-2", "proxy-2") + require.NoError(t, err) + + // Both accounts should have their own clients. + assert.True(t, nb.HasClient(account1), "account1 should have client") + assert.True(t, nb.HasClient(account2), "account2 should have client") + assert.Equal(t, 1, nb.DomainCount(account1), "account1 domain count should be 1") + assert.Equal(t, 1, nb.DomainCount(account2), "account2 domain count should be 1") +} + +func TestNetBird_RemovePeer_KeepsClientWhenDomainsRemain(t *testing.T) { + nb := mockNetBird() + accountID := types.AccountID("account-1") + + // Add multiple domains. + err := nb.AddPeer(context.Background(), accountID, domain.Domain("domain1.test"), "setup-key-1", "proxy-1") + require.NoError(t, err) + err = nb.AddPeer(context.Background(), accountID, domain.Domain("domain2.test"), "setup-key-1", "proxy-2") + require.NoError(t, err) + err = nb.AddPeer(context.Background(), accountID, domain.Domain("domain3.test"), "setup-key-1", "proxy-3") + require.NoError(t, err) + assert.Equal(t, 3, nb.DomainCount(accountID)) + + // Remove one domain - client should remain. + err = nb.RemovePeer(context.Background(), accountID, "domain1.test") + require.NoError(t, err) + assert.True(t, nb.HasClient(accountID), "client should remain after removing one domain") + assert.Equal(t, 2, nb.DomainCount(accountID), "domain count should be 2") + + // Remove another domain - client should still remain. 
+ err = nb.RemovePeer(context.Background(), accountID, "domain2.test") + require.NoError(t, err) + assert.True(t, nb.HasClient(accountID), "client should remain after removing second domain") + assert.Equal(t, 1, nb.DomainCount(accountID), "domain count should be 1") +} + +func TestNetBird_RemovePeer_RemovesClientWhenLastDomainRemoved(t *testing.T) { + nb := mockNetBird() + accountID := types.AccountID("account-1") + + // Add single domain. + err := nb.AddPeer(context.Background(), accountID, domain.Domain("domain1.test"), "setup-key-1", "proxy-1") + require.NoError(t, err) + assert.True(t, nb.HasClient(accountID)) + + // Remove the only domain - client should be removed. + // Note: Stop() may fail since the client never actually connected, + // but the entry should still be removed from the map. + _ = nb.RemovePeer(context.Background(), accountID, "domain1.test") + + // After removing all domains, client should be gone. + assert.False(t, nb.HasClient(accountID), "client should be removed after removing last domain") + assert.Equal(t, 0, nb.DomainCount(accountID), "domain count should be 0") +} + +func TestNetBird_RemovePeer_NonExistentAccountIsNoop(t *testing.T) { + nb := mockNetBird() + accountID := types.AccountID("nonexistent-account") + + // Removing from non-existent account should not error. + err := nb.RemovePeer(context.Background(), accountID, "domain1.test") + assert.NoError(t, err, "removing from non-existent account should not error") +} + +func TestNetBird_RemovePeer_NonExistentDomainIsNoop(t *testing.T) { + nb := mockNetBird() + accountID := types.AccountID("account-1") + + // Add one domain. + err := nb.AddPeer(context.Background(), accountID, domain.Domain("domain1.test"), "setup-key-1", "proxy-1") + require.NoError(t, err) + + // Remove non-existent domain - should not affect existing domain. 
+ err = nb.RemovePeer(context.Background(), accountID, domain.Domain("nonexistent.test")) + require.NoError(t, err) + + // Original domain should still be registered. + assert.True(t, nb.HasClient(accountID)) + assert.Equal(t, 1, nb.DomainCount(accountID), "original domain should remain") +} + +func TestWithAccountID_AndAccountIDFromContext(t *testing.T) { + ctx := context.Background() + accountID := types.AccountID("test-account") + + // Initially no account ID in context. + retrieved := AccountIDFromContext(ctx) + assert.True(t, retrieved == "", "should be empty when not set") + + // Add account ID to context. + ctx = WithAccountID(ctx, accountID) + retrieved = AccountIDFromContext(ctx) + assert.Equal(t, accountID, retrieved, "should retrieve the same account ID") +} + +func TestAccountIDFromContext_ReturnsEmptyForWrongType(t *testing.T) { + // Create context with wrong type for account ID key. + ctx := context.WithValue(context.Background(), accountIDContextKey{}, "wrong-type-string") + + retrieved := AccountIDFromContext(ctx) + assert.True(t, retrieved == "", "should return empty for wrong type") +} + +func TestNetBird_StopAll_StopsAllClients(t *testing.T) { + nb := mockNetBird() + account1 := types.AccountID("account-1") + account2 := types.AccountID("account-2") + account3 := types.AccountID("account-3") + + // Add domains for multiple accounts. + err := nb.AddPeer(context.Background(), account1, domain.Domain("domain1.test"), "key-1", "proxy-1") + require.NoError(t, err) + err = nb.AddPeer(context.Background(), account2, domain.Domain("domain2.test"), "key-2", "proxy-2") + require.NoError(t, err) + err = nb.AddPeer(context.Background(), account3, domain.Domain("domain3.test"), "key-3", "proxy-3") + require.NoError(t, err) + + assert.Equal(t, 3, nb.ClientCount(), "should have 3 clients") + + // Stop all clients. + // Note: StopAll may return errors since clients never actually connected, + // but the clients should still be removed from the map. 
+ _ = nb.StopAll(context.Background()) + + assert.Equal(t, 0, nb.ClientCount(), "should have 0 clients after StopAll") + assert.False(t, nb.HasClient(account1), "account1 should not have client") + assert.False(t, nb.HasClient(account2), "account2 should not have client") + assert.False(t, nb.HasClient(account3), "account3 should not have client") +} + +func TestNetBird_ClientCount(t *testing.T) { + nb := mockNetBird() + + assert.Equal(t, 0, nb.ClientCount(), "should start with 0 clients") + + // Add clients for different accounts. + err := nb.AddPeer(context.Background(), types.AccountID("account-1"), domain.Domain("domain1.test"), "key-1", "proxy-1") + require.NoError(t, err) + assert.Equal(t, 1, nb.ClientCount()) + + err = nb.AddPeer(context.Background(), types.AccountID("account-2"), domain.Domain("domain2.test"), "key-2", "proxy-2") + require.NoError(t, err) + assert.Equal(t, 2, nb.ClientCount()) + + // Adding domain to existing account should not increase count. + err = nb.AddPeer(context.Background(), types.AccountID("account-1"), domain.Domain("domain1b.test"), "key-1", "proxy-1b") + require.NoError(t, err) + assert.Equal(t, 2, nb.ClientCount(), "adding domain to existing account should not increase client count") +} + +func TestNetBird_RoundTrip_RequiresAccountIDInContext(t *testing.T) { + nb := mockNetBird() + + // Create a request without account ID in context. + req, err := http.NewRequest("GET", "http://example.com/", nil) + require.NoError(t, err) + + // RoundTrip should fail because no account ID in context. + _, err = nb.RoundTrip(req) + require.ErrorIs(t, err, ErrNoAccountID) +} + +func TestNetBird_RoundTrip_RequiresExistingClient(t *testing.T) { + nb := mockNetBird() + accountID := types.AccountID("nonexistent-account") + + // Create a request with account ID but no client exists. 
+ req, err := http.NewRequest("GET", "http://example.com/", nil) + require.NoError(t, err) + req = req.WithContext(WithAccountID(req.Context(), accountID)) + + // RoundTrip should fail because no client for this account. + _, err = nb.RoundTrip(req) + assert.Error(t, err) + assert.Contains(t, err.Error(), "no peer connection found for account") +} diff --git a/proxy/internal/types/types.go b/proxy/internal/types/types.go new file mode 100644 index 000000000..41acfef40 --- /dev/null +++ b/proxy/internal/types/types.go @@ -0,0 +1,5 @@ +// Package types defines common types used across the proxy package. +package types + +// AccountID represents a unique identifier for a NetBird account. +type AccountID string diff --git a/proxy/server.go b/proxy/server.go index 5bfdbe629..76c86f9e2 100644 --- a/proxy/server.go +++ b/proxy/server.go @@ -31,8 +31,11 @@ import ( "github.com/netbirdio/netbird/proxy/internal/accesslog" "github.com/netbirdio/netbird/proxy/internal/acme" "github.com/netbirdio/netbird/proxy/internal/auth" + "github.com/netbirdio/netbird/proxy/internal/debug" "github.com/netbirdio/netbird/proxy/internal/proxy" "github.com/netbirdio/netbird/proxy/internal/roundtrip" + "github.com/netbirdio/netbird/proxy/internal/types" + "github.com/netbirdio/netbird/shared/management/domain" "github.com/netbirdio/netbird/shared/management/proto" "github.com/netbirdio/netbird/util/embeddedroots" ) @@ -45,6 +48,7 @@ type Server struct { auth *auth.Middleware http *http.Server https *http.Server + debug *http.Server // Mostly used for debugging on management. startTime time.Time @@ -62,6 +66,11 @@ type Server struct { OIDCClientSecret string OIDCEndpoint string OIDCScopes []string + + // DebugEndpointEnabled enables the debug HTTP endpoint. + DebugEndpointEnabled bool + // DebugEndpointAddress is the address for the debug HTTP endpoint (default: ":8444"). 
+ DebugEndpointAddress string } // NotifyStatus sends a status update to management about tunnel connectivity @@ -148,7 +157,7 @@ func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) { // Initialize the netbird client, this is required to build peer connections // to proxy over. - s.netbird = roundtrip.NewNetBird(s.ManagementAddress, s.Logger, s) + s.netbird = roundtrip.NewNetBird(s.ManagementAddress, s.ID, s.Logger, s) // When generating ACME certificates, start a challenge server. tlsConfig := &tls.Config{} @@ -204,6 +213,34 @@ func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) { // Configure Access logs to management server. accessLog := accesslog.NewLogger(s.mgmtClient, s.Logger) + if s.DebugEndpointEnabled { + debugAddr := debugEndpointAddr(s.DebugEndpointAddress) + debugHandler := debug.NewHandler(s.netbird, s.Logger) + s.debug = &http.Server{ + Addr: debugAddr, + Handler: debugHandler, + } + go func() { + s.Logger.WithField("address", debugAddr).Info("starting debug endpoint") + if err := s.debug.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) { + s.Logger.Errorf("debug endpoint error: %v", err) + } + }() + defer func() { + if err := s.debug.Close(); err != nil { + s.Logger.Debugf("debug endpoint close: %v", err) + } + }() + } + + defer func() { + stopCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + if err := s.netbird.StopAll(stopCtx); err != nil { + s.Logger.Warnf("failed to stop all netbird clients: %v", err) + } + }() + // Finally, start the reverse proxy. 
s.https = &http.Server{ Addr: addr, @@ -306,15 +343,15 @@ func (s *Server) handleMappingStream(ctx context.Context, mappingClient proto.Pr } func (s *Server) addMapping(ctx context.Context, mapping *proto.ProxyMapping) error { - domain := mapping.GetDomain() - accountID := mapping.GetAccountId() + d := domain.Domain(mapping.GetDomain()) + accountID := types.AccountID(mapping.GetAccountId()) reverseProxyID := mapping.GetId() - if err := s.netbird.AddPeer(ctx, domain, mapping.GetSetupKey(), accountID, reverseProxyID); err != nil { - return fmt.Errorf("create peer for domain %q: %w", domain, err) + if err := s.netbird.AddPeer(ctx, accountID, d, mapping.GetSetupKey(), reverseProxyID); err != nil { + return fmt.Errorf("create peer for domain %q: %w", d, err) } if s.acme != nil { - s.acme.AddDomain(domain, accountID, reverseProxyID) + s.acme.AddDomain(string(d), string(accountID), reverseProxyID) } // Pass the mapping through to the update function to avoid duplicating the @@ -360,10 +397,13 @@ func (s *Server) updateMapping(ctx context.Context, mapping *proto.ProxyMapping) } func (s *Server) removeMapping(ctx context.Context, mapping *proto.ProxyMapping) { - if err := s.netbird.RemovePeer(ctx, mapping.GetDomain(), mapping.GetAccountId(), mapping.GetId()); err != nil { + d := domain.Domain(mapping.GetDomain()) + accountID := types.AccountID(mapping.GetAccountId()) + if err := s.netbird.RemovePeer(ctx, accountID, d); err != nil { s.Logger.WithFields(log.Fields{ - "domain": mapping.GetDomain(), - "error": err, + "account_id": accountID, + "domain": d, + "error": err, }).Error("Error removing NetBird peer connection for domain, continuing additional domain cleanup but peer connection may still exist") } if s.acme != nil { @@ -392,8 +432,17 @@ func (s *Server) protoToMapping(mapping *proto.ProxyMapping) proxy.Mapping { } return proxy.Mapping{ ID: mapping.GetId(), - AccountID: mapping.AccountId, + AccountID: types.AccountID(mapping.GetAccountId()), Host: mapping.GetDomain(), 
// debugEndpointAddr resolves the listen address of the debug endpoint.
// A non-empty addr is returned unchanged; an empty addr falls back to
// "localhost:8444" so that, for security, the endpoint is not exposed on
// every interface unless explicitly requested.
func debugEndpointAddr(addr string) string {
	const fallback = "localhost:8444"
	if addr != "" {
		return addr
	}
	return fallback
}