mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-16 07:16:38 +00:00
Use a 1:1 mapping of netbird client to netbird account
- Add debug endpoint for monitoring netbird clients - Add types package with AccountID type - Refactor netbird roundtrip to key clients by AccountID - Multiple domains can share the same client per account - Add status notifier for tunnel connection updates - Add OIDC flags to CLI - Add tests for netbird client management
This commit is contained in:
166
proxy/cmd/proxy/cmd/debug.go
Normal file
166
proxy/cmd/proxy/cmd/debug.go
Normal file
@@ -0,0 +1,166 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
"github.com/netbirdio/netbird/proxy/internal/debug"
|
||||
)
|
||||
|
||||
var (
|
||||
debugAddr string
|
||||
jsonOutput bool
|
||||
|
||||
// status filters
|
||||
statusFilterByIPs []string
|
||||
statusFilterByNames []string
|
||||
statusFilterByStatus string
|
||||
statusFilterByConnectionType string
|
||||
)
|
||||
|
||||
var debugCmd = &cobra.Command{
|
||||
Use: "debug",
|
||||
Short: "Debug commands for inspecting proxy state",
|
||||
Long: "Debug commands for inspecting the reverse proxy state via the debug HTTP endpoint.",
|
||||
}
|
||||
|
||||
var debugHealthCmd = &cobra.Command{
|
||||
Use: "health",
|
||||
Short: "Show proxy health status",
|
||||
RunE: runDebugHealth,
|
||||
SilenceUsage: true,
|
||||
}
|
||||
|
||||
var debugClientsCmd = &cobra.Command{
|
||||
Use: "clients",
|
||||
Aliases: []string{"list"},
|
||||
Short: "List all connected clients",
|
||||
RunE: runDebugClients,
|
||||
SilenceUsage: true,
|
||||
}
|
||||
|
||||
var debugStatusCmd = &cobra.Command{
|
||||
Use: "status <account-id>",
|
||||
Short: "Show client status",
|
||||
Args: cobra.ExactArgs(1),
|
||||
RunE: runDebugStatus,
|
||||
SilenceUsage: true,
|
||||
}
|
||||
|
||||
var debugSyncCmd = &cobra.Command{
|
||||
Use: "sync-response <account-id>",
|
||||
Short: "Show client sync response",
|
||||
Args: cobra.ExactArgs(1),
|
||||
RunE: runDebugSync,
|
||||
SilenceUsage: true,
|
||||
}
|
||||
|
||||
var pingTimeout string
|
||||
|
||||
var debugPingCmd = &cobra.Command{
|
||||
Use: "ping <account-id> <host> [port]",
|
||||
Short: "TCP ping through a client",
|
||||
Long: "Perform a TCP ping through a client's network to test connectivity.\nPort defaults to 80 if not specified.",
|
||||
Args: cobra.RangeArgs(2, 3),
|
||||
RunE: runDebugPing,
|
||||
SilenceUsage: true,
|
||||
}
|
||||
|
||||
var debugLogLevelCmd = &cobra.Command{
|
||||
Use: "loglevel <account-id> <level>",
|
||||
Short: "Set client log level",
|
||||
Long: "Set the log level for a client (trace, debug, info, warn, error).",
|
||||
Args: cobra.ExactArgs(2),
|
||||
RunE: runDebugLogLevel,
|
||||
SilenceUsage: true,
|
||||
}
|
||||
|
||||
var debugStartCmd = &cobra.Command{
|
||||
Use: "start <account-id>",
|
||||
Short: "Start a client",
|
||||
Args: cobra.ExactArgs(1),
|
||||
RunE: runDebugStart,
|
||||
SilenceUsage: true,
|
||||
}
|
||||
|
||||
var debugStopCmd = &cobra.Command{
|
||||
Use: "stop <account-id>",
|
||||
Short: "Stop a client",
|
||||
Args: cobra.ExactArgs(1),
|
||||
RunE: runDebugStop,
|
||||
SilenceUsage: true,
|
||||
}
|
||||
|
||||
func init() {
|
||||
debugCmd.PersistentFlags().StringVar(&debugAddr, "addr", envStringOrDefault("NB_PROXY_DEBUG_ADDRESS", "localhost:8444"), "Debug endpoint address")
|
||||
debugCmd.PersistentFlags().BoolVar(&jsonOutput, "json", false, "Output JSON instead of pretty format")
|
||||
|
||||
debugStatusCmd.Flags().StringSliceVar(&statusFilterByIPs, "filter-by-ips", nil, "Filter by peer IPs (comma-separated)")
|
||||
debugStatusCmd.Flags().StringSliceVar(&statusFilterByNames, "filter-by-names", nil, "Filter by peer names (comma-separated)")
|
||||
debugStatusCmd.Flags().StringVar(&statusFilterByStatus, "filter-by-status", "", "Filter by status (idle|connecting|connected)")
|
||||
debugStatusCmd.Flags().StringVar(&statusFilterByConnectionType, "filter-by-connection-type", "", "Filter by connection type (P2P|Relayed)")
|
||||
|
||||
debugPingCmd.Flags().StringVar(&pingTimeout, "timeout", "", "Ping timeout (e.g., 10s)")
|
||||
|
||||
debugCmd.AddCommand(debugHealthCmd)
|
||||
debugCmd.AddCommand(debugClientsCmd)
|
||||
debugCmd.AddCommand(debugStatusCmd)
|
||||
debugCmd.AddCommand(debugSyncCmd)
|
||||
debugCmd.AddCommand(debugPingCmd)
|
||||
debugCmd.AddCommand(debugLogLevelCmd)
|
||||
debugCmd.AddCommand(debugStartCmd)
|
||||
debugCmd.AddCommand(debugStopCmd)
|
||||
|
||||
rootCmd.AddCommand(debugCmd)
|
||||
}
|
||||
|
||||
func getDebugClient(cmd *cobra.Command) *debug.Client {
|
||||
return debug.NewClient(debugAddr, jsonOutput, cmd.OutOrStdout())
|
||||
}
|
||||
|
||||
func runDebugHealth(cmd *cobra.Command, _ []string) error {
|
||||
return getDebugClient(cmd).Health(cmd.Context())
|
||||
}
|
||||
|
||||
func runDebugClients(cmd *cobra.Command, _ []string) error {
|
||||
return getDebugClient(cmd).ListClients(cmd.Context())
|
||||
}
|
||||
|
||||
func runDebugStatus(cmd *cobra.Command, args []string) error {
|
||||
return getDebugClient(cmd).ClientStatus(cmd.Context(), args[0], debug.StatusFilters{
|
||||
IPs: statusFilterByIPs,
|
||||
Names: statusFilterByNames,
|
||||
Status: statusFilterByStatus,
|
||||
ConnectionType: statusFilterByConnectionType,
|
||||
})
|
||||
}
|
||||
|
||||
func runDebugSync(cmd *cobra.Command, args []string) error {
|
||||
return getDebugClient(cmd).ClientSyncResponse(cmd.Context(), args[0])
|
||||
}
|
||||
|
||||
func runDebugPing(cmd *cobra.Command, args []string) error {
|
||||
port := 80
|
||||
if len(args) > 2 {
|
||||
p, err := strconv.Atoi(args[2])
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid port: %w", err)
|
||||
}
|
||||
port = p
|
||||
}
|
||||
return getDebugClient(cmd).PingTCP(cmd.Context(), args[0], args[1], port, pingTimeout)
|
||||
}
|
||||
|
||||
func runDebugLogLevel(cmd *cobra.Command, args []string) error {
|
||||
return getDebugClient(cmd).SetLogLevel(cmd.Context(), args[0], args[1])
|
||||
}
|
||||
|
||||
func runDebugStart(cmd *cobra.Command, args []string) error {
|
||||
return getDebugClient(cmd).StartClient(cmd.Context(), args[0])
|
||||
}
|
||||
|
||||
func runDebugStop(cmd *cobra.Command, args []string) error {
|
||||
return getDebugClient(cmd).StopClient(cmd.Context(), args[0])
|
||||
}
|
||||
137
proxy/cmd/proxy/cmd/root.go
Normal file
137
proxy/cmd/proxy/cmd/root.go
Normal file
@@ -0,0 +1,137 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/spf13/cobra"
|
||||
"golang.org/x/crypto/acme"
|
||||
|
||||
"github.com/netbirdio/netbird/proxy"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
const DefaultManagementURL = "https://api.netbird.io:443"
|
||||
|
||||
var (
|
||||
Version = "dev"
|
||||
Commit = "unknown"
|
||||
BuildDate = "unknown"
|
||||
GoVersion = "unknown"
|
||||
)
|
||||
|
||||
var (
|
||||
debugLogs bool
|
||||
mgmtAddr string
|
||||
addr string
|
||||
proxyURL string
|
||||
certDir string
|
||||
acmeCerts bool
|
||||
acmeAddr string
|
||||
acmeDir string
|
||||
debugEndpoint bool
|
||||
debugEndpointAddr string
|
||||
oidcClientID string
|
||||
oidcClientSecret string
|
||||
oidcEndpoint string
|
||||
oidcScopes string
|
||||
)
|
||||
|
||||
var rootCmd = &cobra.Command{
|
||||
Use: "proxy",
|
||||
Short: "NetBird reverse proxy server",
|
||||
Long: "NetBird reverse proxy server for proxying traffic to NetBird networks.",
|
||||
Version: Version,
|
||||
RunE: runServer,
|
||||
}
|
||||
|
||||
func init() {
|
||||
rootCmd.PersistentFlags().BoolVar(&debugLogs, "debug", envBoolOrDefault("NB_PROXY_DEBUG_LOGS", false), "Enable debug logs")
|
||||
rootCmd.Flags().StringVar(&mgmtAddr, "mgmt", envStringOrDefault("NB_PROXY_MANAGEMENT_ADDRESS", DefaultManagementURL), "Management address to connect to")
|
||||
rootCmd.Flags().StringVar(&addr, "addr", envStringOrDefault("NB_PROXY_ADDRESS", ":443"), "Reverse proxy address to listen on")
|
||||
rootCmd.Flags().StringVar(&proxyURL, "url", envStringOrDefault("NB_PROXY_URL", ""), "The URL at which this proxy will be reached")
|
||||
rootCmd.Flags().StringVar(&certDir, "cert-dir", envStringOrDefault("NB_PROXY_CERTIFICATE_DIRECTORY", "./certs"), "Directory to store certificates")
|
||||
rootCmd.Flags().BoolVar(&acmeCerts, "acme-certs", envBoolOrDefault("NB_PROXY_ACME_CERTIFICATES", false), "Generate ACME certificates using HTTP-01 challenges")
|
||||
rootCmd.Flags().StringVar(&acmeAddr, "acme-addr", envStringOrDefault("NB_PROXY_ACME_ADDRESS", ":80"), "HTTP address for ACME HTTP-01 challenges")
|
||||
rootCmd.Flags().StringVar(&acmeDir, "acme-dir", envStringOrDefault("NB_PROXY_ACME_DIRECTORY", acme.LetsEncryptURL), "URL of ACME challenge directory")
|
||||
rootCmd.Flags().BoolVar(&debugEndpoint, "debug-endpoint", envBoolOrDefault("NB_PROXY_DEBUG_ENDPOINT", false), "Enable debug HTTP endpoint")
|
||||
rootCmd.Flags().StringVar(&debugEndpointAddr, "debug-endpoint-addr", envStringOrDefault("NB_PROXY_DEBUG_ENDPOINT_ADDRESS", "localhost:8444"), "Address for the debug HTTP endpoint")
|
||||
rootCmd.Flags().StringVar(&oidcClientID, "oidc-id", envStringOrDefault("NB_PROXY_OIDC_CLIENT_ID", "netbird-proxy"), "The OAuth2 Client ID for OIDC User Authentication")
|
||||
rootCmd.Flags().StringVar(&oidcClientSecret, "oidc-secret", envStringOrDefault("NB_PROXY_OIDC_CLIENT_SECRET", ""), "The OAuth2 Client Secret for OIDC User Authentication")
|
||||
rootCmd.Flags().StringVar(&oidcEndpoint, "oidc-endpoint", envStringOrDefault("NB_PROXY_OIDC_ENDPOINT", ""), "The OIDC Endpoint for OIDC User Authentication")
|
||||
rootCmd.Flags().StringVar(&oidcScopes, "oidc-scopes", envStringOrDefault("NB_PROXY_OIDC_SCOPES", "openid,profile,email"), "The OAuth2 scopes for OIDC User Authentication, comma separated")
|
||||
}
|
||||
|
||||
// Execute runs the root command.
|
||||
func Execute() {
|
||||
if err := rootCmd.Execute(); err != nil {
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
|
||||
// SetVersionInfo sets version information for the CLI.
|
||||
func SetVersionInfo(version, commit, buildDate, goVersion string) {
|
||||
Version = version
|
||||
Commit = commit
|
||||
BuildDate = buildDate
|
||||
GoVersion = goVersion
|
||||
rootCmd.Version = version
|
||||
rootCmd.SetVersionTemplate("Version: {{.Version}}, Commit: " + Commit + ", BuildDate: " + BuildDate + ", Go: " + GoVersion + "\n")
|
||||
}
|
||||
|
||||
func runServer(cmd *cobra.Command, args []string) error {
|
||||
level := "error"
|
||||
if debugLogs {
|
||||
level = "debug"
|
||||
}
|
||||
logger := log.New()
|
||||
|
||||
_ = util.InitLogger(logger, level, util.LogConsole)
|
||||
|
||||
log.Infof("configured log level: %s", level)
|
||||
|
||||
srv := proxy.Server{
|
||||
Logger: logger,
|
||||
Version: Version,
|
||||
ManagementAddress: mgmtAddr,
|
||||
ProxyURL: proxyURL,
|
||||
CertificateDirectory: certDir,
|
||||
GenerateACMECertificates: acmeCerts,
|
||||
ACMEChallengeAddress: acmeAddr,
|
||||
ACMEDirectory: acmeDir,
|
||||
DebugEndpointEnabled: debugEndpoint,
|
||||
DebugEndpointAddress: debugEndpointAddr,
|
||||
OIDCClientId: oidcClientID,
|
||||
OIDCClientSecret: oidcClientSecret,
|
||||
OIDCEndpoint: oidcEndpoint,
|
||||
OIDCScopes: strings.Split(oidcScopes, ","),
|
||||
}
|
||||
|
||||
if err := srv.ListenAndServe(context.TODO(), addr); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func envBoolOrDefault(key string, def bool) bool {
|
||||
v, exists := os.LookupEnv(key)
|
||||
if !exists {
|
||||
return def
|
||||
}
|
||||
parsed, err := strconv.ParseBool(v)
|
||||
if err != nil {
|
||||
return def
|
||||
}
|
||||
return parsed
|
||||
}
|
||||
|
||||
func envStringOrDefault(key string, def string) string {
|
||||
v, exists := os.LookupEnv(key)
|
||||
if !exists {
|
||||
return def
|
||||
}
|
||||
return v
|
||||
}
|
||||
@@ -1,22 +1,11 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"flag"
|
||||
"fmt"
|
||||
"os"
|
||||
"runtime"
|
||||
"strings"
|
||||
|
||||
"github.com/netbirdio/netbird/util"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/crypto/acme"
|
||||
|
||||
"github.com/netbirdio/netbird/proxy"
|
||||
"github.com/netbirdio/netbird/proxy/cmd/proxy/cmd"
|
||||
)
|
||||
|
||||
const DefaultManagementURL = "https://api.netbird.io:443"
|
||||
|
||||
var (
|
||||
// Version is the application version (set via ldflags during build)
|
||||
Version = "dev"
|
||||
@@ -31,78 +20,7 @@ var (
|
||||
GoVersion = runtime.Version()
|
||||
)
|
||||
|
||||
func envBoolOrDefault(key string, def bool) bool {
|
||||
v, exists := os.LookupEnv(key)
|
||||
if !exists {
|
||||
return def
|
||||
}
|
||||
return v == strings.ToLower("true")
|
||||
}
|
||||
|
||||
func envStringOrDefault(key string, def string) string {
|
||||
v, exists := os.LookupEnv(key)
|
||||
if !exists {
|
||||
return def
|
||||
}
|
||||
return v
|
||||
}
|
||||
|
||||
func main() {
|
||||
var (
|
||||
version, debug bool
|
||||
mgmtAddr, addr, url, certDir string
|
||||
acmeCerts bool
|
||||
acmeAddr, acmeDir string
|
||||
oidcId, oidcSecret, oidcEndpoint, oidcScopes string
|
||||
)
|
||||
|
||||
flag.BoolVar(&version, "v", false, "Print version and exit")
|
||||
flag.BoolVar(&debug, "debug", envBoolOrDefault("NB_PROXY_DEBUG_LOGS", false), "Enable debug logs")
|
||||
flag.StringVar(&mgmtAddr, "mgmt", envStringOrDefault("NB_PROXY_MANAGEMENT_ADDRESS", DefaultManagementURL), "Management address to connect to.")
|
||||
flag.StringVar(&addr, "addr", envStringOrDefault("NB_PROXY_ADDRESS", ":443"), "Reverse proxy address to listen on.")
|
||||
flag.StringVar(&url, "url", envStringOrDefault("NB_PROXY_URL", "proxy.netbird.io"), "The URL at which this proxy will be reached, where CNAME records for proxied endpoints will be directed.")
|
||||
flag.StringVar(&certDir, "cert-dir", envStringOrDefault("NB_PROXY_CERTIFICATE_DIRECTORY", "./certs"), "Directory to store ")
|
||||
flag.BoolVar(&acmeCerts, "acme-certs", envBoolOrDefault("NB_PROXY_ACME_CERTIFICATES", false), "Generate ACME certificates using HTTP-01 challenges.")
|
||||
flag.StringVar(&acmeAddr, "acme-addr", envStringOrDefault("NB_PROXY_ACME_ADDRESS", ":80"), "HTTP address to listen on, used for ACME HTTP-01 certificate generation.")
|
||||
flag.StringVar(&acmeDir, "acme-dir", envStringOrDefault("NB_PROXY_ACME_DIRECTORY", acme.LetsEncryptURL), "URL of ACME challenge directory.")
|
||||
flag.StringVar(&oidcId, "oidc-id", envStringOrDefault("NB_PROXY_OIDC_CLIENT_ID", "netbird-proxy"), "The OAuth2 Client ID for OIDC User Authentication")
|
||||
flag.StringVar(&oidcSecret, "oidc-secret", envStringOrDefault("NB_PROXY_OIDC_CLIENT_SECRET", ""), "The OAuth2 Client Secret for OIDC User Authentication")
|
||||
flag.StringVar(&oidcEndpoint, "oidc-endpoint", envStringOrDefault("NB_PROXY_OIDC_ENDPOINT", ""), "The OIDC Endpoint for OIDC User Authentication")
|
||||
flag.StringVar(&oidcScopes, "oidc-scopes", envStringOrDefault("NB_PROXY_OIDC_SCOPES", "openid,profile,email"), "The OAuth2 scopes for OIDC User Authentication, comma separated")
|
||||
flag.Parse()
|
||||
|
||||
if version {
|
||||
fmt.Printf("Version: %s, Commit: %s, BuildDate: %s, Go: %s", Version, Commit, BuildDate, GoVersion)
|
||||
os.Exit(0)
|
||||
}
|
||||
|
||||
// Configure logrus.
|
||||
level := "error"
|
||||
if debug {
|
||||
level = "debug"
|
||||
}
|
||||
logger := log.New()
|
||||
|
||||
_ = util.InitLogger(logger, level, util.LogConsole)
|
||||
|
||||
log.Infof("configured log level: %s", level)
|
||||
|
||||
srv := proxy.Server{
|
||||
Logger: logger,
|
||||
Version: Version,
|
||||
ManagementAddress: mgmtAddr,
|
||||
ProxyURL: url,
|
||||
CertificateDirectory: certDir,
|
||||
GenerateACMECertificates: acmeCerts,
|
||||
ACMEChallengeAddress: acmeAddr,
|
||||
ACMEDirectory: acmeDir,
|
||||
OIDCClientId: oidcId,
|
||||
OIDCClientSecret: oidcSecret,
|
||||
OIDCEndpoint: oidcEndpoint,
|
||||
OIDCScopes: strings.Split(oidcScopes, ","),
|
||||
}
|
||||
|
||||
if err := srv.ListenAndServe(context.TODO(), addr); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
cmd.SetVersionInfo(Version, Commit, BuildDate, GoVersion)
|
||||
cmd.Execute()
|
||||
}
|
||||
|
||||
@@ -42,7 +42,7 @@ func (l *Logger) Middleware(next http.Handler) http.Handler {
|
||||
entry := logEntry{
|
||||
ID: xid.New().String(),
|
||||
ServiceId: capturedData.GetServiceId(),
|
||||
AccountID: capturedData.GetAccountId(),
|
||||
AccountID: string(capturedData.GetAccountId()),
|
||||
Host: host,
|
||||
Path: r.URL.Path,
|
||||
DurationMs: duration.Milliseconds(),
|
||||
|
||||
307
proxy/internal/debug/client.go
Normal file
307
proxy/internal/debug/client.go
Normal file
@@ -0,0 +1,307 @@
|
||||
// Package debug provides HTTP debug endpoints and CLI client for the proxy server.
|
||||
package debug
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// StatusFilters contains filter options for status queries.
|
||||
type StatusFilters struct {
|
||||
IPs []string
|
||||
Names []string
|
||||
Status string
|
||||
ConnectionType string
|
||||
}
|
||||
|
||||
// Client provides CLI access to debug endpoints.
|
||||
type Client struct {
|
||||
baseURL string
|
||||
jsonOutput bool
|
||||
httpClient *http.Client
|
||||
out io.Writer
|
||||
}
|
||||
|
||||
// NewClient creates a new debug client.
|
||||
func NewClient(baseURL string, jsonOutput bool, out io.Writer) *Client {
|
||||
if !strings.HasPrefix(baseURL, "http://") && !strings.HasPrefix(baseURL, "https://") {
|
||||
baseURL = "http://" + baseURL
|
||||
}
|
||||
baseURL = strings.TrimSuffix(baseURL, "/")
|
||||
|
||||
return &Client{
|
||||
baseURL: baseURL,
|
||||
jsonOutput: jsonOutput,
|
||||
out: out,
|
||||
httpClient: &http.Client{
|
||||
Timeout: 30 * time.Second,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Health fetches the health status.
|
||||
func (c *Client) Health(ctx context.Context) error {
|
||||
return c.fetchAndPrint(ctx, "/debug/health", c.printHealth)
|
||||
}
|
||||
|
||||
func (c *Client) printHealth(data map[string]any) {
|
||||
_, _ = fmt.Fprintf(c.out, "Status: %v\n", data["status"])
|
||||
_, _ = fmt.Fprintf(c.out, "Uptime: %v\n", data["uptime"])
|
||||
}
|
||||
|
||||
// ListClients fetches the list of all clients.
|
||||
func (c *Client) ListClients(ctx context.Context) error {
|
||||
return c.fetchAndPrint(ctx, "/debug/clients", c.printClients)
|
||||
}
|
||||
|
||||
func (c *Client) printClients(data map[string]any) {
|
||||
_, _ = fmt.Fprintf(c.out, "Uptime: %v\n", data["uptime"])
|
||||
_, _ = fmt.Fprintf(c.out, "Clients: %v\n\n", data["client_count"])
|
||||
|
||||
clients, ok := data["clients"].([]any)
|
||||
if !ok || len(clients) == 0 {
|
||||
_, _ = fmt.Fprintln(c.out, "No clients connected.")
|
||||
return
|
||||
}
|
||||
|
||||
_, _ = fmt.Fprintf(c.out, "%-38s %-12s %-40s %s\n", "ACCOUNT ID", "AGE", "DOMAINS", "HAS CLIENT")
|
||||
_, _ = fmt.Fprintln(c.out, strings.Repeat("-", 110))
|
||||
|
||||
for _, item := range clients {
|
||||
c.printClientRow(item)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) printClientRow(item any) {
|
||||
client, ok := item.(map[string]any)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
domains := c.extractDomains(client)
|
||||
hasClient := "no"
|
||||
if hc, ok := client["has_client"].(bool); ok && hc {
|
||||
hasClient = "yes"
|
||||
}
|
||||
|
||||
_, _ = fmt.Fprintf(c.out, "%-38s %-12v %s %s\n",
|
||||
client["account_id"],
|
||||
client["age"],
|
||||
domains,
|
||||
hasClient,
|
||||
)
|
||||
}
|
||||
|
||||
func (c *Client) extractDomains(client map[string]any) string {
|
||||
d, ok := client["domains"].([]any)
|
||||
if !ok || len(d) == 0 {
|
||||
return "-"
|
||||
}
|
||||
|
||||
parts := make([]string, len(d))
|
||||
for i, domain := range d {
|
||||
parts[i] = fmt.Sprint(domain)
|
||||
}
|
||||
return strings.Join(parts, ", ")
|
||||
}
|
||||
|
||||
// ClientStatus fetches the status of a specific client.
|
||||
func (c *Client) ClientStatus(ctx context.Context, accountID string, filters StatusFilters) error {
|
||||
params := url.Values{}
|
||||
if len(filters.IPs) > 0 {
|
||||
params.Set("filter-by-ips", strings.Join(filters.IPs, ","))
|
||||
}
|
||||
if len(filters.Names) > 0 {
|
||||
params.Set("filter-by-names", strings.Join(filters.Names, ","))
|
||||
}
|
||||
if filters.Status != "" {
|
||||
params.Set("filter-by-status", filters.Status)
|
||||
}
|
||||
if filters.ConnectionType != "" {
|
||||
params.Set("filter-by-connection-type", filters.ConnectionType)
|
||||
}
|
||||
|
||||
path := "/debug/clients/" + url.PathEscape(accountID)
|
||||
if len(params) > 0 {
|
||||
path += "?" + params.Encode()
|
||||
}
|
||||
return c.fetchAndPrint(ctx, path, c.printClientStatus)
|
||||
}
|
||||
|
||||
func (c *Client) printClientStatus(data map[string]any) {
|
||||
_, _ = fmt.Fprintf(c.out, "Account: %v\n\n", data["account_id"])
|
||||
if status, ok := data["status"].(string); ok {
|
||||
_, _ = fmt.Fprint(c.out, status)
|
||||
}
|
||||
}
|
||||
|
||||
// ClientSyncResponse fetches the sync response of a specific client.
|
||||
func (c *Client) ClientSyncResponse(ctx context.Context, accountID string) error {
|
||||
path := "/debug/clients/" + url.PathEscape(accountID) + "/syncresponse"
|
||||
return c.fetchAndPrintJSON(ctx, path)
|
||||
}
|
||||
|
||||
// PingTCP performs a TCP ping through a client.
|
||||
func (c *Client) PingTCP(ctx context.Context, accountID, host string, port int, timeout string) error {
|
||||
params := url.Values{}
|
||||
params.Set("host", host)
|
||||
params.Set("port", fmt.Sprintf("%d", port))
|
||||
if timeout != "" {
|
||||
params.Set("timeout", timeout)
|
||||
}
|
||||
|
||||
path := fmt.Sprintf("/debug/clients/%s/pingtcp?%s", url.PathEscape(accountID), params.Encode())
|
||||
return c.fetchAndPrint(ctx, path, c.printPingResult)
|
||||
}
|
||||
|
||||
func (c *Client) printPingResult(data map[string]any) {
|
||||
success, _ := data["success"].(bool)
|
||||
if success {
|
||||
_, _ = fmt.Fprintf(c.out, "Success: %v:%v\n", data["host"], data["port"])
|
||||
_, _ = fmt.Fprintf(c.out, "Latency: %v\n", data["latency"])
|
||||
} else {
|
||||
_, _ = fmt.Fprintf(c.out, "Failed: %v:%v\n", data["host"], data["port"])
|
||||
c.printError(data)
|
||||
}
|
||||
}
|
||||
|
||||
// SetLogLevel sets the log level of a specific client.
|
||||
func (c *Client) SetLogLevel(ctx context.Context, accountID, level string) error {
|
||||
params := url.Values{}
|
||||
params.Set("level", level)
|
||||
|
||||
path := fmt.Sprintf("/debug/clients/%s/loglevel?%s", url.PathEscape(accountID), params.Encode())
|
||||
return c.fetchAndPrint(ctx, path, c.printLogLevelResult)
|
||||
}
|
||||
|
||||
func (c *Client) printLogLevelResult(data map[string]any) {
|
||||
success, _ := data["success"].(bool)
|
||||
if success {
|
||||
_, _ = fmt.Fprintf(c.out, "Log level set to: %v\n", data["level"])
|
||||
} else {
|
||||
_, _ = fmt.Fprintln(c.out, "Failed to set log level")
|
||||
c.printError(data)
|
||||
}
|
||||
}
|
||||
|
||||
// StartClient starts a specific client.
|
||||
func (c *Client) StartClient(ctx context.Context, accountID string) error {
|
||||
path := "/debug/clients/" + url.PathEscape(accountID) + "/start"
|
||||
return c.fetchAndPrint(ctx, path, c.printStartResult)
|
||||
}
|
||||
|
||||
func (c *Client) printStartResult(data map[string]any) {
|
||||
success, _ := data["success"].(bool)
|
||||
if success {
|
||||
_, _ = fmt.Fprintln(c.out, "Client started")
|
||||
} else {
|
||||
_, _ = fmt.Fprintln(c.out, "Failed to start client")
|
||||
c.printError(data)
|
||||
}
|
||||
}
|
||||
|
||||
// StopClient stops a specific client.
|
||||
func (c *Client) StopClient(ctx context.Context, accountID string) error {
|
||||
path := "/debug/clients/" + url.PathEscape(accountID) + "/stop"
|
||||
return c.fetchAndPrint(ctx, path, c.printStopResult)
|
||||
}
|
||||
|
||||
func (c *Client) printStopResult(data map[string]any) {
|
||||
success, _ := data["success"].(bool)
|
||||
if success {
|
||||
_, _ = fmt.Fprintln(c.out, "Client stopped")
|
||||
} else {
|
||||
_, _ = fmt.Fprintln(c.out, "Failed to stop client")
|
||||
c.printError(data)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) printError(data map[string]any) {
|
||||
if errMsg, ok := data["error"].(string); ok {
|
||||
_, _ = fmt.Fprintf(c.out, "Error: %s\n", errMsg)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) fetchAndPrint(ctx context.Context, path string, printer func(map[string]any)) error {
|
||||
data, raw, err := c.fetch(ctx, path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if c.jsonOutput {
|
||||
return c.writeJSON(data)
|
||||
}
|
||||
|
||||
if data != nil {
|
||||
printer(data)
|
||||
return nil
|
||||
}
|
||||
|
||||
_, _ = fmt.Fprintln(c.out, string(raw))
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Client) fetchAndPrintJSON(ctx context.Context, path string) error {
|
||||
data, raw, err := c.fetch(ctx, path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if data != nil {
|
||||
return c.writeJSON(data)
|
||||
}
|
||||
|
||||
_, _ = fmt.Fprintln(c.out, string(raw))
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Client) writeJSON(data map[string]any) error {
|
||||
enc := json.NewEncoder(c.out)
|
||||
enc.SetIndent("", " ")
|
||||
return enc.Encode(data)
|
||||
}
|
||||
|
||||
func (c *Client) fetch(ctx context.Context, path string) (map[string]any, []byte, error) {
|
||||
fullURL := c.baseURL + path
|
||||
if !strings.Contains(path, "format=json") {
|
||||
if strings.Contains(path, "?") {
|
||||
fullURL += "&format=json"
|
||||
} else {
|
||||
fullURL += "?format=json"
|
||||
}
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, fullURL, nil)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("create request: %w", err)
|
||||
}
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("request failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("read response: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode >= 400 {
|
||||
return nil, nil, fmt.Errorf("server error (%d): %s", resp.StatusCode, strings.TrimSpace(string(body)))
|
||||
}
|
||||
|
||||
var data map[string]any
|
||||
if err := json.Unmarshal(body, &data); err != nil {
|
||||
return nil, body, nil
|
||||
}
|
||||
|
||||
return data, body, nil
|
||||
}
|
||||
|
||||
589
proxy/internal/debug/handler.go
Normal file
589
proxy/internal/debug/handler.go
Normal file
@@ -0,0 +1,589 @@
|
||||
// Package debug provides HTTP debug endpoints for the proxy server.
|
||||
package debug
|
||||
|
||||
import (
|
||||
"context"
|
||||
"embed"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"html/template"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"google.golang.org/protobuf/encoding/protojson"
|
||||
|
||||
nbembed "github.com/netbirdio/netbird/client/embed"
|
||||
nbstatus "github.com/netbirdio/netbird/client/status"
|
||||
"github.com/netbirdio/netbird/proxy/internal/roundtrip"
|
||||
"github.com/netbirdio/netbird/proxy/internal/types"
|
||||
"github.com/netbirdio/netbird/version"
|
||||
)
|
||||
|
||||
//go:embed templates/*.html
|
||||
var templateFS embed.FS
|
||||
|
||||
const defaultPingTimeout = 10 * time.Second
|
||||
|
||||
// formatDuration formats a duration with 2 decimal places using appropriate units.
|
||||
func formatDuration(d time.Duration) string {
|
||||
switch {
|
||||
case d >= time.Hour:
|
||||
return fmt.Sprintf("%.2fh", d.Hours())
|
||||
case d >= time.Minute:
|
||||
return fmt.Sprintf("%.2fm", d.Minutes())
|
||||
case d >= time.Second:
|
||||
return fmt.Sprintf("%.2fs", d.Seconds())
|
||||
case d >= time.Millisecond:
|
||||
return fmt.Sprintf("%.2fms", float64(d.Microseconds())/1000)
|
||||
case d >= time.Microsecond:
|
||||
return fmt.Sprintf("%.2fµs", float64(d.Nanoseconds())/1000)
|
||||
default:
|
||||
return fmt.Sprintf("%dns", d.Nanoseconds())
|
||||
}
|
||||
}
|
||||
|
||||
// ClientProvider provides access to NetBird clients.
|
||||
type ClientProvider interface {
|
||||
GetClient(accountID types.AccountID) (*nbembed.Client, bool)
|
||||
ListClientsForDebug() map[types.AccountID]roundtrip.ClientDebugInfo
|
||||
}
|
||||
|
||||
// Handler provides HTTP debug endpoints.
|
||||
type Handler struct {
|
||||
provider ClientProvider
|
||||
logger *log.Logger
|
||||
startTime time.Time
|
||||
templates *template.Template
|
||||
templateMu sync.RWMutex
|
||||
}
|
||||
|
||||
// NewHandler creates a new debug handler.
|
||||
func NewHandler(provider ClientProvider, logger *log.Logger) *Handler {
|
||||
if logger == nil {
|
||||
logger = log.StandardLogger()
|
||||
}
|
||||
h := &Handler{
|
||||
provider: provider,
|
||||
logger: logger,
|
||||
startTime: time.Now(),
|
||||
}
|
||||
if err := h.loadTemplates(); err != nil {
|
||||
logger.Errorf("failed to load embedded templates: %v", err)
|
||||
}
|
||||
return h
|
||||
}
|
||||
|
||||
func (h *Handler) loadTemplates() error {
|
||||
tmpl, err := template.ParseFS(templateFS, "templates/*.html")
|
||||
if err != nil {
|
||||
return fmt.Errorf("parse embedded templates: %w", err)
|
||||
}
|
||||
|
||||
h.templateMu.Lock()
|
||||
h.templates = tmpl
|
||||
h.templateMu.Unlock()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *Handler) getTemplates() *template.Template {
|
||||
h.templateMu.RLock()
|
||||
defer h.templateMu.RUnlock()
|
||||
return h.templates
|
||||
}
|
||||
|
||||
// ServeHTTP handles debug requests.
|
||||
func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
path := r.URL.Path
|
||||
wantJSON := r.URL.Query().Get("format") == "json" || strings.HasSuffix(path, "/json")
|
||||
path = strings.TrimSuffix(path, "/json")
|
||||
|
||||
switch path {
|
||||
case "/debug", "/debug/":
|
||||
h.handleIndex(w, r, wantJSON)
|
||||
case "/debug/clients":
|
||||
h.handleListClients(w, r, wantJSON)
|
||||
case "/debug/health":
|
||||
h.handleHealth(w, r, wantJSON)
|
||||
default:
|
||||
if h.handleClientRoutes(w, r, path, wantJSON) {
|
||||
return
|
||||
}
|
||||
http.NotFound(w, r)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Handler) handleClientRoutes(w http.ResponseWriter, r *http.Request, path string, wantJSON bool) bool {
|
||||
if !strings.HasPrefix(path, "/debug/clients/") {
|
||||
return false
|
||||
}
|
||||
|
||||
rest := strings.TrimPrefix(path, "/debug/clients/")
|
||||
parts := strings.SplitN(rest, "/", 2)
|
||||
accountID := types.AccountID(parts[0])
|
||||
|
||||
if len(parts) == 1 {
|
||||
h.handleClientStatus(w, r, accountID, wantJSON)
|
||||
return true
|
||||
}
|
||||
|
||||
switch parts[1] {
|
||||
case "syncresponse":
|
||||
h.handleClientSyncResponse(w, r, accountID, wantJSON)
|
||||
case "tools":
|
||||
h.handleClientTools(w, r, accountID)
|
||||
case "pingtcp":
|
||||
h.handlePingTCP(w, r, accountID)
|
||||
case "loglevel":
|
||||
h.handleLogLevel(w, r, accountID)
|
||||
case "start":
|
||||
h.handleClientStart(w, r, accountID)
|
||||
case "stop":
|
||||
h.handleClientStop(w, r, accountID)
|
||||
default:
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
type indexData struct {
|
||||
Version string
|
||||
Uptime string
|
||||
ClientCount int
|
||||
TotalDomains int
|
||||
Clients []clientData
|
||||
}
|
||||
|
||||
type clientData struct {
|
||||
AccountID string
|
||||
Domains string
|
||||
Age string
|
||||
Status string
|
||||
}
|
||||
|
||||
func (h *Handler) handleIndex(w http.ResponseWriter, _ *http.Request, wantJSON bool) {
|
||||
clients := h.provider.ListClientsForDebug()
|
||||
|
||||
totalDomains := 0
|
||||
for _, info := range clients {
|
||||
totalDomains += info.DomainCount
|
||||
}
|
||||
|
||||
if wantJSON {
|
||||
clientsJSON := make([]map[string]interface{}, 0, len(clients))
|
||||
for _, info := range clients {
|
||||
clientsJSON = append(clientsJSON, map[string]interface{}{
|
||||
"account_id": info.AccountID,
|
||||
"domain_count": info.DomainCount,
|
||||
"domains": info.Domains,
|
||||
"has_client": info.HasClient,
|
||||
"created_at": info.CreatedAt,
|
||||
"age": time.Since(info.CreatedAt).Round(time.Second).String(),
|
||||
})
|
||||
}
|
||||
h.writeJSON(w, map[string]interface{}{
|
||||
"version": version.NetbirdVersion(),
|
||||
"uptime": time.Since(h.startTime).Round(time.Second).String(),
|
||||
"client_count": len(clients),
|
||||
"total_domains": totalDomains,
|
||||
"clients": clientsJSON,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
data := indexData{
|
||||
Version: version.NetbirdVersion(),
|
||||
Uptime: time.Since(h.startTime).Round(time.Second).String(),
|
||||
ClientCount: len(clients),
|
||||
TotalDomains: totalDomains,
|
||||
Clients: make([]clientData, 0, len(clients)),
|
||||
}
|
||||
|
||||
for _, info := range clients {
|
||||
domains := info.Domains.SafeString()
|
||||
if domains == "" {
|
||||
domains = "-"
|
||||
}
|
||||
status := "No client"
|
||||
if info.HasClient {
|
||||
status = "Active"
|
||||
}
|
||||
data.Clients = append(data.Clients, clientData{
|
||||
AccountID: string(info.AccountID),
|
||||
Domains: domains,
|
||||
Age: time.Since(info.CreatedAt).Round(time.Second).String(),
|
||||
Status: status,
|
||||
})
|
||||
}
|
||||
|
||||
h.renderTemplate(w, "index", data)
|
||||
}
|
||||
|
||||
type clientsData struct {
|
||||
Uptime string
|
||||
Clients []clientData
|
||||
}
|
||||
|
||||
func (h *Handler) handleListClients(w http.ResponseWriter, _ *http.Request, wantJSON bool) {
|
||||
clients := h.provider.ListClientsForDebug()
|
||||
|
||||
if wantJSON {
|
||||
clientsJSON := make([]map[string]interface{}, 0, len(clients))
|
||||
for _, info := range clients {
|
||||
clientsJSON = append(clientsJSON, map[string]interface{}{
|
||||
"account_id": info.AccountID,
|
||||
"domain_count": info.DomainCount,
|
||||
"domains": info.Domains,
|
||||
"has_client": info.HasClient,
|
||||
"created_at": info.CreatedAt,
|
||||
"age": time.Since(info.CreatedAt).Round(time.Second).String(),
|
||||
})
|
||||
}
|
||||
h.writeJSON(w, map[string]interface{}{
|
||||
"uptime": time.Since(h.startTime).Round(time.Second).String(),
|
||||
"client_count": len(clients),
|
||||
"clients": clientsJSON,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
data := clientsData{
|
||||
Uptime: time.Since(h.startTime).Round(time.Second).String(),
|
||||
Clients: make([]clientData, 0, len(clients)),
|
||||
}
|
||||
|
||||
for _, info := range clients {
|
||||
domains := info.Domains.SafeString()
|
||||
if domains == "" {
|
||||
domains = "-"
|
||||
}
|
||||
status := "No client"
|
||||
if info.HasClient {
|
||||
status = "Active"
|
||||
}
|
||||
data.Clients = append(data.Clients, clientData{
|
||||
AccountID: string(info.AccountID),
|
||||
Domains: domains,
|
||||
Age: time.Since(info.CreatedAt).Round(time.Second).String(),
|
||||
Status: status,
|
||||
})
|
||||
}
|
||||
|
||||
h.renderTemplate(w, "clients", data)
|
||||
}
|
||||
|
||||
type clientDetailData struct {
|
||||
AccountID string
|
||||
ActiveTab string
|
||||
Content string
|
||||
}
|
||||
|
||||
func (h *Handler) handleClientStatus(w http.ResponseWriter, r *http.Request, accountID types.AccountID, wantJSON bool) {
|
||||
client, ok := h.provider.GetClient(accountID)
|
||||
if !ok {
|
||||
http.Error(w, "Client not found: "+string(accountID), http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
fullStatus, err := client.Status()
|
||||
if err != nil {
|
||||
http.Error(w, "Error getting status: "+err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Parse filter parameters
|
||||
query := r.URL.Query()
|
||||
statusFilter := query.Get("filter-by-status")
|
||||
connectionTypeFilter := query.Get("filter-by-connection-type")
|
||||
|
||||
var prefixNamesFilter []string
|
||||
var prefixNamesFilterMap map[string]struct{}
|
||||
if names := query.Get("filter-by-names"); names != "" {
|
||||
prefixNamesFilter = strings.Split(names, ",")
|
||||
prefixNamesFilterMap = make(map[string]struct{})
|
||||
for _, name := range prefixNamesFilter {
|
||||
prefixNamesFilterMap[strings.ToLower(strings.TrimSpace(name))] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
var ipsFilterMap map[string]struct{}
|
||||
if ips := query.Get("filter-by-ips"); ips != "" {
|
||||
ipsFilterMap = make(map[string]struct{})
|
||||
for _, ip := range strings.Split(ips, ",") {
|
||||
ipsFilterMap[strings.TrimSpace(ip)] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
pbStatus := nbstatus.ToProtoFullStatus(fullStatus)
|
||||
overview := nbstatus.ConvertToStatusOutputOverview(
|
||||
pbStatus,
|
||||
false,
|
||||
version.NetbirdVersion(),
|
||||
statusFilter,
|
||||
prefixNamesFilter,
|
||||
prefixNamesFilterMap,
|
||||
ipsFilterMap,
|
||||
connectionTypeFilter,
|
||||
"",
|
||||
)
|
||||
|
||||
if wantJSON {
|
||||
h.writeJSON(w, map[string]interface{}{
|
||||
"account_id": accountID,
|
||||
"status": overview.FullDetailSummary(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
data := clientDetailData{
|
||||
AccountID: string(accountID),
|
||||
ActiveTab: "status",
|
||||
Content: overview.FullDetailSummary(),
|
||||
}
|
||||
|
||||
h.renderTemplate(w, "clientDetail", data)
|
||||
}
|
||||
|
||||
func (h *Handler) handleClientSyncResponse(w http.ResponseWriter, _ *http.Request, accountID types.AccountID, wantJSON bool) {
|
||||
client, ok := h.provider.GetClient(accountID)
|
||||
if !ok {
|
||||
http.Error(w, "Client not found: "+string(accountID), http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
syncResp, err := client.GetLatestSyncResponse()
|
||||
if err != nil {
|
||||
http.Error(w, "Error getting sync response: "+err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
if syncResp == nil {
|
||||
http.Error(w, "No sync response available for client: "+string(accountID), http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
opts := protojson.MarshalOptions{
|
||||
EmitUnpopulated: true,
|
||||
UseProtoNames: true,
|
||||
Indent: " ",
|
||||
AllowPartial: true,
|
||||
}
|
||||
|
||||
jsonBytes, err := opts.Marshal(syncResp)
|
||||
if err != nil {
|
||||
http.Error(w, "Error marshaling sync response: "+err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
if wantJSON {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write(jsonBytes)
|
||||
return
|
||||
}
|
||||
|
||||
data := clientDetailData{
|
||||
AccountID: string(accountID),
|
||||
ActiveTab: "syncresponse",
|
||||
Content: string(jsonBytes),
|
||||
}
|
||||
|
||||
h.renderTemplate(w, "clientDetail", data)
|
||||
}
|
||||
|
||||
type toolsData struct {
|
||||
AccountID string
|
||||
}
|
||||
|
||||
func (h *Handler) handleClientTools(w http.ResponseWriter, _ *http.Request, accountID types.AccountID) {
|
||||
_, ok := h.provider.GetClient(accountID)
|
||||
if !ok {
|
||||
http.Error(w, "Client not found: "+string(accountID), http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
data := toolsData{
|
||||
AccountID: string(accountID),
|
||||
}
|
||||
|
||||
h.renderTemplate(w, "tools", data)
|
||||
}
|
||||
|
||||
func (h *Handler) handlePingTCP(w http.ResponseWriter, r *http.Request, accountID types.AccountID) {
|
||||
client, ok := h.provider.GetClient(accountID)
|
||||
if !ok {
|
||||
h.writeJSON(w, map[string]interface{}{"error": "client not found"})
|
||||
return
|
||||
}
|
||||
|
||||
host := r.URL.Query().Get("host")
|
||||
portStr := r.URL.Query().Get("port")
|
||||
if host == "" || portStr == "" {
|
||||
h.writeJSON(w, map[string]interface{}{"error": "host and port parameters required"})
|
||||
return
|
||||
}
|
||||
|
||||
port, err := strconv.Atoi(portStr)
|
||||
if err != nil || port < 1 || port > 65535 {
|
||||
h.writeJSON(w, map[string]interface{}{"error": "invalid port"})
|
||||
return
|
||||
}
|
||||
|
||||
timeout := defaultPingTimeout
|
||||
if t := r.URL.Query().Get("timeout"); t != "" {
|
||||
if d, err := time.ParseDuration(t); err == nil {
|
||||
timeout = d
|
||||
}
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(r.Context(), timeout)
|
||||
defer cancel()
|
||||
|
||||
address := fmt.Sprintf("%s:%d", host, port)
|
||||
start := time.Now()
|
||||
|
||||
conn, err := client.Dial(ctx, "tcp", address)
|
||||
if err != nil {
|
||||
h.writeJSON(w, map[string]interface{}{
|
||||
"success": false,
|
||||
"host": host,
|
||||
"port": port,
|
||||
"error": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
if err := conn.Close(); err != nil {
|
||||
h.logger.Debugf("close tcp ping connection: %v", err)
|
||||
}
|
||||
|
||||
latency := time.Since(start)
|
||||
h.writeJSON(w, map[string]interface{}{
|
||||
"success": true,
|
||||
"host": host,
|
||||
"port": port,
|
||||
"latency_ms": latency.Milliseconds(),
|
||||
"latency": formatDuration(latency),
|
||||
})
|
||||
}
|
||||
|
||||
func (h *Handler) handleLogLevel(w http.ResponseWriter, r *http.Request, accountID types.AccountID) {
|
||||
client, ok := h.provider.GetClient(accountID)
|
||||
if !ok {
|
||||
h.writeJSON(w, map[string]interface{}{"error": "client not found"})
|
||||
return
|
||||
}
|
||||
|
||||
level := r.URL.Query().Get("level")
|
||||
if level == "" {
|
||||
h.writeJSON(w, map[string]interface{}{"error": "level parameter required (trace, debug, info, warn, error)"})
|
||||
return
|
||||
}
|
||||
|
||||
if err := client.SetLogLevel(level); err != nil {
|
||||
h.writeJSON(w, map[string]interface{}{
|
||||
"success": false,
|
||||
"error": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
h.writeJSON(w, map[string]interface{}{
|
||||
"success": true,
|
||||
"level": level,
|
||||
})
|
||||
}
|
||||
|
||||
const clientActionTimeout = 30 * time.Second
|
||||
|
||||
func (h *Handler) handleClientStart(w http.ResponseWriter, r *http.Request, accountID types.AccountID) {
|
||||
client, ok := h.provider.GetClient(accountID)
|
||||
if !ok {
|
||||
h.writeJSON(w, map[string]interface{}{"error": "client not found"})
|
||||
return
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(r.Context(), clientActionTimeout)
|
||||
defer cancel()
|
||||
|
||||
if err := client.Start(ctx); err != nil {
|
||||
h.writeJSON(w, map[string]interface{}{
|
||||
"success": false,
|
||||
"error": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
h.writeJSON(w, map[string]interface{}{
|
||||
"success": true,
|
||||
"message": "client started",
|
||||
})
|
||||
}
|
||||
|
||||
func (h *Handler) handleClientStop(w http.ResponseWriter, r *http.Request, accountID types.AccountID) {
|
||||
client, ok := h.provider.GetClient(accountID)
|
||||
if !ok {
|
||||
h.writeJSON(w, map[string]interface{}{"error": "client not found"})
|
||||
return
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(r.Context(), clientActionTimeout)
|
||||
defer cancel()
|
||||
|
||||
if err := client.Stop(ctx); err != nil {
|
||||
h.writeJSON(w, map[string]interface{}{
|
||||
"success": false,
|
||||
"error": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
h.writeJSON(w, map[string]interface{}{
|
||||
"success": true,
|
||||
"message": "client stopped",
|
||||
})
|
||||
}
|
||||
|
||||
type healthData struct {
|
||||
Uptime string
|
||||
}
|
||||
|
||||
func (h *Handler) handleHealth(w http.ResponseWriter, _ *http.Request, wantJSON bool) {
|
||||
if wantJSON {
|
||||
h.writeJSON(w, map[string]interface{}{
|
||||
"status": "ok",
|
||||
"uptime": time.Since(h.startTime).Round(10 * time.Millisecond).String(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
data := healthData{
|
||||
Uptime: time.Since(h.startTime).Round(time.Second).String(),
|
||||
}
|
||||
|
||||
h.renderTemplate(w, "health", data)
|
||||
}
|
||||
|
||||
func (h *Handler) renderTemplate(w http.ResponseWriter, name string, data interface{}) {
|
||||
w.Header().Set("Content-Type", "text/html; charset=utf-8")
|
||||
tmpl := h.getTemplates()
|
||||
if tmpl == nil {
|
||||
http.Error(w, "Templates not loaded", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
if err := tmpl.ExecuteTemplate(w, name, data); err != nil {
|
||||
h.logger.Errorf("execute template %s: %v", name, err)
|
||||
http.Error(w, "Template error", http.StatusInternalServerError)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Handler) writeJSON(w http.ResponseWriter, v interface{}) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
enc := json.NewEncoder(w)
|
||||
enc.SetIndent("", " ")
|
||||
if err := enc.Encode(v); err != nil {
|
||||
h.logger.Errorf("encode JSON response: %v", err)
|
||||
}
|
||||
}
|
||||
101
proxy/internal/debug/templates/base.html
Normal file
101
proxy/internal/debug/templates/base.html
Normal file
@@ -0,0 +1,101 @@
|
||||
{{define "style"}}
|
||||
body {
|
||||
font-family: monospace;
|
||||
margin: 20px;
|
||||
background: #1a1a1a;
|
||||
color: #eee;
|
||||
}
|
||||
a {
|
||||
color: #6cf;
|
||||
}
|
||||
h1, h2, h3 {
|
||||
color: #fff;
|
||||
}
|
||||
.info {
|
||||
color: #aaa;
|
||||
}
|
||||
table {
|
||||
border-collapse: collapse;
|
||||
margin: 10px 0;
|
||||
}
|
||||
th, td {
|
||||
border: 1px solid #444;
|
||||
padding: 8px;
|
||||
text-align: left;
|
||||
}
|
||||
th {
|
||||
background: #333;
|
||||
}
|
||||
.nav {
|
||||
margin-bottom: 20px;
|
||||
}
|
||||
.nav a {
|
||||
margin-right: 15px;
|
||||
padding: 8px 16px;
|
||||
background: #333;
|
||||
text-decoration: none;
|
||||
border-radius: 4px;
|
||||
}
|
||||
.nav a.active {
|
||||
background: #6cf;
|
||||
color: #000;
|
||||
}
|
||||
pre {
|
||||
background: #222;
|
||||
padding: 15px;
|
||||
border-radius: 4px;
|
||||
overflow-x: auto;
|
||||
white-space: pre-wrap;
|
||||
}
|
||||
input, select, textarea {
|
||||
background: #333;
|
||||
color: #eee;
|
||||
border: 1px solid #555;
|
||||
padding: 8px;
|
||||
border-radius: 4px;
|
||||
font-family: monospace;
|
||||
}
|
||||
input:focus, select:focus, textarea:focus {
|
||||
outline: none;
|
||||
border-color: #6cf;
|
||||
}
|
||||
button {
|
||||
background: #6cf;
|
||||
color: #000;
|
||||
border: none;
|
||||
padding: 8px 16px;
|
||||
border-radius: 4px;
|
||||
cursor: pointer;
|
||||
font-family: monospace;
|
||||
}
|
||||
button:hover {
|
||||
background: #5be;
|
||||
}
|
||||
button:disabled {
|
||||
background: #555;
|
||||
color: #888;
|
||||
cursor: not-allowed;
|
||||
}
|
||||
.form-group {
|
||||
margin-bottom: 15px;
|
||||
}
|
||||
.form-group label {
|
||||
display: block;
|
||||
margin-bottom: 5px;
|
||||
color: #aaa;
|
||||
}
|
||||
.form-row {
|
||||
display: flex;
|
||||
gap: 10px;
|
||||
align-items: flex-end;
|
||||
}
|
||||
.result {
|
||||
margin-top: 20px;
|
||||
}
|
||||
.success {
|
||||
color: #5f5;
|
||||
}
|
||||
.error {
|
||||
color: #f55;
|
||||
}
|
||||
{{end}}
|
||||
19
proxy/internal/debug/templates/client_detail.html
Normal file
19
proxy/internal/debug/templates/client_detail.html
Normal file
@@ -0,0 +1,19 @@
|
||||
{{define "clientDetail"}}
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<title>Client {{.AccountID}}</title>
|
||||
<style>{{template "style"}}</style>
|
||||
</head>
|
||||
<body>
|
||||
<h1>Client: {{.AccountID}}</h1>
|
||||
<div class="nav">
|
||||
<a href="/debug">← Back</a>
|
||||
<a href="/debug/clients/{{.AccountID}}/tools"{{if eq .ActiveTab "tools"}} class="active"{{end}}>Tools</a>
|
||||
<a href="/debug/clients/{{.AccountID}}"{{if eq .ActiveTab "status"}} class="active"{{end}}>Status</a>
|
||||
<a href="/debug/clients/{{.AccountID}}/syncresponse"{{if eq .ActiveTab "syncresponse"}} class="active"{{end}}>Sync Response</a>
|
||||
</div>
|
||||
<pre>{{.Content}}</pre>
|
||||
</body>
|
||||
</html>
|
||||
{{end}}
|
||||
33
proxy/internal/debug/templates/clients.html
Normal file
33
proxy/internal/debug/templates/clients.html
Normal file
@@ -0,0 +1,33 @@
|
||||
{{define "clients"}}
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<title>Clients</title>
|
||||
<style>{{template "style"}}</style>
|
||||
</head>
|
||||
<body>
|
||||
<h1>All Clients</h1>
|
||||
<p class="info">Uptime: {{.Uptime}} | <a href="/debug">← Back</a></p>
|
||||
{{if .Clients}}
|
||||
<table>
|
||||
<tr>
|
||||
<th>Account ID</th>
|
||||
<th>Domains</th>
|
||||
<th>Age</th>
|
||||
<th>Status</th>
|
||||
</tr>
|
||||
{{range .Clients}}
|
||||
<tr>
|
||||
<td><a href="/debug/clients/{{.AccountID}}/tools">{{.AccountID}}</a></td>
|
||||
<td>{{.Domains}}</td>
|
||||
<td>{{.Age}}</td>
|
||||
<td>{{.Status}}</td>
|
||||
</tr>
|
||||
{{end}}
|
||||
</table>
|
||||
{{else}}
|
||||
<p>No clients connected</p>
|
||||
{{end}}
|
||||
</body>
|
||||
</html>
|
||||
{{end}}
|
||||
14
proxy/internal/debug/templates/health.html
Normal file
14
proxy/internal/debug/templates/health.html
Normal file
@@ -0,0 +1,14 @@
|
||||
{{define "health"}}
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<title>Health</title>
|
||||
<style>{{template "style"}}</style>
|
||||
</head>
|
||||
<body>
|
||||
<h1>OK</h1>
|
||||
<p>Uptime: {{.Uptime}}</p>
|
||||
<p><a href="/debug">← Back</a></p>
|
||||
</body>
|
||||
</html>
|
||||
{{end}}
|
||||
40
proxy/internal/debug/templates/index.html
Normal file
40
proxy/internal/debug/templates/index.html
Normal file
@@ -0,0 +1,40 @@
|
||||
{{define "index"}}
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<title>NetBird Proxy Debug</title>
|
||||
<style>{{template "style"}}</style>
|
||||
</head>
|
||||
<body>
|
||||
<h1>NetBird Proxy Debug</h1>
|
||||
<p class="info">Version: {{.Version}} | Uptime: {{.Uptime}}</p>
|
||||
<h2>Clients ({{.ClientCount}}) | Domains ({{.TotalDomains}})</h2>
|
||||
{{if .Clients}}
|
||||
<table>
|
||||
<tr>
|
||||
<th>Account ID</th>
|
||||
<th>Domains</th>
|
||||
<th>Age</th>
|
||||
<th>Status</th>
|
||||
</tr>
|
||||
{{range .Clients}}
|
||||
<tr>
|
||||
<td><a href="/debug/clients/{{.AccountID}}/tools">{{.AccountID}}</a></td>
|
||||
<td>{{.Domains}}</td>
|
||||
<td>{{.Age}}</td>
|
||||
<td>{{.Status}}</td>
|
||||
</tr>
|
||||
{{end}}
|
||||
</table>
|
||||
{{else}}
|
||||
<p>No clients connected</p>
|
||||
{{end}}
|
||||
<h2>Endpoints</h2>
|
||||
<ul>
|
||||
<li><a href="/debug/clients">/debug/clients</a> - all clients detail</li>
|
||||
<li><a href="/debug/health">/debug/health</a> - health check</li>
|
||||
</ul>
|
||||
<p class="info">Add ?format=json or /json suffix for JSON output</p>
|
||||
</body>
|
||||
</html>
|
||||
{{end}}
|
||||
142
proxy/internal/debug/templates/tools.html
Normal file
142
proxy/internal/debug/templates/tools.html
Normal file
@@ -0,0 +1,142 @@
|
||||
{{define "tools"}}
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<title>Client {{.AccountID}} - Tools</title>
|
||||
<style>{{template "style"}}</style>
|
||||
</head>
|
||||
<body>
|
||||
<h1>Client: {{.AccountID}}</h1>
|
||||
<div class="nav">
|
||||
<a href="/debug">← Back</a>
|
||||
<a href="/debug/clients/{{.AccountID}}/tools" class="active">Tools</a>
|
||||
<a href="/debug/clients/{{.AccountID}}">Status</a>
|
||||
<a href="/debug/clients/{{.AccountID}}/syncresponse">Sync Response</a>
|
||||
</div>
|
||||
|
||||
<h2>Client Control</h2>
|
||||
<div class="form-row">
|
||||
<div class="form-group">
|
||||
<label> </label>
|
||||
<button onclick="startClient()">Start</button>
|
||||
</div>
|
||||
<div class="form-group">
|
||||
<label> </label>
|
||||
<button onclick="stopClient()">Stop</button>
|
||||
</div>
|
||||
</div>
|
||||
<div id="client-result" class="result"></div>
|
||||
|
||||
<h2>Log Level</h2>
|
||||
<div class="form-row">
|
||||
<div class="form-group">
|
||||
<label>Level</label>
|
||||
<select id="log-level" style="width: 120px;">
|
||||
<option value="trace">trace</option>
|
||||
<option value="debug">debug</option>
|
||||
<option value="info">info</option>
|
||||
<option value="warn" selected>warn</option>
|
||||
<option value="error">error</option>
|
||||
</select>
|
||||
</div>
|
||||
<div class="form-group">
|
||||
<label> </label>
|
||||
<button onclick="setLogLevel()">Set Level</button>
|
||||
</div>
|
||||
</div>
|
||||
<div id="log-result" class="result"></div>
|
||||
|
||||
<h2>TCP Ping</h2>
|
||||
<div class="form-row">
|
||||
<div class="form-group">
|
||||
<label>Host</label>
|
||||
<input type="text" id="tcp-host" placeholder="100.0.0.1 or hostname.netbird.cloud" style="width: 300px;">
|
||||
</div>
|
||||
<div class="form-group">
|
||||
<label>Port</label>
|
||||
<input type="number" id="tcp-port" placeholder="80" style="width: 80px;">
|
||||
</div>
|
||||
<div class="form-group">
|
||||
<label> </label>
|
||||
<button onclick="doTcpPing()">Connect</button>
|
||||
</div>
|
||||
</div>
|
||||
<div id="tcp-result" class="result"></div>
|
||||
|
||||
<script>
|
||||
const accountID = "{{.AccountID}}";
|
||||
|
||||
async function startClient() {
|
||||
const resultDiv = document.getElementById('client-result');
|
||||
resultDiv.innerHTML = '<span class="info">Starting client...</span>';
|
||||
try {
|
||||
const resp = await fetch('/debug/clients/' + accountID + '/start');
|
||||
const data = await resp.json();
|
||||
if (data.success) {
|
||||
resultDiv.innerHTML = '<span class="success">✓ ' + data.message + '</span>';
|
||||
} else {
|
||||
resultDiv.innerHTML = '<span class="error">✗ ' + data.error + '</span>';
|
||||
}
|
||||
} catch (e) {
|
||||
resultDiv.innerHTML = '<span class="error">Error: ' + e.message + '</span>';
|
||||
}
|
||||
}
|
||||
|
||||
async function stopClient() {
|
||||
const resultDiv = document.getElementById('client-result');
|
||||
resultDiv.innerHTML = '<span class="info">Stopping client...</span>';
|
||||
try {
|
||||
const resp = await fetch('/debug/clients/' + accountID + '/stop');
|
||||
const data = await resp.json();
|
||||
if (data.success) {
|
||||
resultDiv.innerHTML = '<span class="success">✓ ' + data.message + '</span>';
|
||||
} else {
|
||||
resultDiv.innerHTML = '<span class="error">✗ ' + data.error + '</span>';
|
||||
}
|
||||
} catch (e) {
|
||||
resultDiv.innerHTML = '<span class="error">Error: ' + e.message + '</span>';
|
||||
}
|
||||
}
|
||||
|
||||
async function setLogLevel() {
|
||||
const level = document.getElementById('log-level').value;
|
||||
const resultDiv = document.getElementById('log-result');
|
||||
resultDiv.innerHTML = '<span class="info">Setting log level...</span>';
|
||||
try {
|
||||
const resp = await fetch('/debug/clients/' + accountID + '/loglevel?level=' + level);
|
||||
const data = await resp.json();
|
||||
if (data.success) {
|
||||
resultDiv.innerHTML = '<span class="success">✓ Log level set to: ' + data.level + '</span>';
|
||||
} else {
|
||||
resultDiv.innerHTML = '<span class="error">✗ ' + data.error + '</span>';
|
||||
}
|
||||
} catch (e) {
|
||||
resultDiv.innerHTML = '<span class="error">Error: ' + e.message + '</span>';
|
||||
}
|
||||
}
|
||||
|
||||
async function doTcpPing() {
|
||||
const host = document.getElementById('tcp-host').value;
|
||||
const port = document.getElementById('tcp-port').value;
|
||||
if (!host || !port) {
|
||||
alert('Host and port required');
|
||||
return;
|
||||
}
|
||||
const resultDiv = document.getElementById('tcp-result');
|
||||
resultDiv.innerHTML = '<span class="info">Connecting...</span>';
|
||||
try {
|
||||
const resp = await fetch('/debug/clients/' + accountID + '/pingtcp?host=' + encodeURIComponent(host) + '&port=' + port);
|
||||
const data = await resp.json();
|
||||
if (data.success) {
|
||||
resultDiv.innerHTML = '<span class="success">✓ ' + data.host + ':' + data.port + ' connected in ' + data.latency + '</span>';
|
||||
} else {
|
||||
resultDiv.innerHTML = '<span class="error">✗ ' + data.host + ':' + data.port + ': ' + data.error + '</span>';
|
||||
}
|
||||
} catch (e) {
|
||||
resultDiv.innerHTML = '<span class="error">Error: ' + e.message + '</span>';
|
||||
}
|
||||
}
|
||||
</script>
|
||||
</body>
|
||||
</html>
|
||||
{{end}}
|
||||
@@ -3,6 +3,8 @@ package proxy
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
|
||||
"github.com/netbirdio/netbird/proxy/internal/types"
|
||||
)
|
||||
|
||||
type requestContextKey string
|
||||
@@ -18,7 +20,7 @@ const (
|
||||
type CapturedData struct {
|
||||
mu sync.RWMutex
|
||||
ServiceId string
|
||||
AccountId string
|
||||
AccountId types.AccountID
|
||||
}
|
||||
|
||||
// SetServiceId safely sets the service ID
|
||||
@@ -36,14 +38,14 @@ func (c *CapturedData) GetServiceId() string {
|
||||
}
|
||||
|
||||
// SetAccountId safely sets the account ID
|
||||
func (c *CapturedData) SetAccountId(accountId string) {
|
||||
func (c *CapturedData) SetAccountId(accountId types.AccountID) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
c.AccountId = accountId
|
||||
}
|
||||
|
||||
// GetAccountId safely gets the account ID
|
||||
func (c *CapturedData) GetAccountId() string {
|
||||
func (c *CapturedData) GetAccountId() types.AccountID {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
return c.AccountId
|
||||
@@ -76,13 +78,13 @@ func ServiceIdFromContext(ctx context.Context) string {
|
||||
}
|
||||
return serviceId
|
||||
}
|
||||
func withAccountId(ctx context.Context, accountId string) context.Context {
|
||||
func withAccountId(ctx context.Context, accountId types.AccountID) context.Context {
|
||||
return context.WithValue(ctx, accountIdKey, accountId)
|
||||
}
|
||||
|
||||
func AccountIdFromContext(ctx context.Context) string {
|
||||
func AccountIdFromContext(ctx context.Context) types.AccountID {
|
||||
v := ctx.Value(accountIdKey)
|
||||
accountId, ok := v.(string)
|
||||
accountId, ok := v.(types.AccountID)
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
|
||||
@@ -4,6 +4,8 @@ import (
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
"sync"
|
||||
|
||||
"github.com/netbirdio/netbird/proxy/internal/roundtrip"
|
||||
)
|
||||
|
||||
type ReverseProxy struct {
|
||||
@@ -36,8 +38,10 @@ func (p *ReverseProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
// Set the serviceId in the context for later retrieval.
|
||||
ctx := withServiceId(r.Context(), serviceId)
|
||||
// Set the accountId in the context for later retrieval.
|
||||
// Set the accountId in the context for later retrieval (for middleware).
|
||||
ctx = withAccountId(ctx, accountID)
|
||||
// Set the accountId in the context for the roundtripper to use.
|
||||
ctx = roundtrip.WithAccountID(ctx, accountID)
|
||||
|
||||
// Also populate captured data if it exists (allows middleware to read after handler completes).
|
||||
// This solves the problem of passing data UP the middleware chain: we put a mutable struct
|
||||
|
||||
@@ -6,16 +6,18 @@ import (
|
||||
"net/url"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"github.com/netbirdio/netbird/proxy/internal/types"
|
||||
)
|
||||
|
||||
type Mapping struct {
|
||||
ID string
|
||||
AccountID string
|
||||
AccountID types.AccountID
|
||||
Host string
|
||||
Paths map[string]*url.URL
|
||||
}
|
||||
|
||||
func (p *ReverseProxy) findTargetForRequest(req *http.Request) (*url.URL, string, string, bool) {
|
||||
func (p *ReverseProxy) findTargetForRequest(req *http.Request) (*url.URL, string, types.AccountID, bool) {
|
||||
p.mappingsMux.RLock()
|
||||
if p.mappings == nil {
|
||||
p.mappingsMux.RUnlock()
|
||||
@@ -27,10 +29,12 @@ func (p *ReverseProxy) findTargetForRequest(req *http.Request) (*url.URL, string
|
||||
}
|
||||
defer p.mappingsMux.RUnlock()
|
||||
|
||||
host, _, err := net.SplitHostPort(req.Host)
|
||||
if err != nil {
|
||||
host = req.Host
|
||||
// Strip port from host if present (e.g., "external.test:8443" -> "external.test")
|
||||
host := req.Host
|
||||
if h, _, err := net.SplitHostPort(host); err == nil {
|
||||
host = h
|
||||
}
|
||||
|
||||
m, exists := p.mappings[host]
|
||||
if !exists {
|
||||
return nil, "", "", false
|
||||
|
||||
@@ -4,18 +4,39 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/go-multierror"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/exp/maps"
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
"github.com/netbirdio/netbird/client/embed"
|
||||
"github.com/netbirdio/netbird/proxy/internal/types"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
const deviceNamePrefix = "ingress-"
|
||||
const deviceNamePrefix = "ingress-proxy-"
|
||||
|
||||
// ErrNoAccountID is returned when a request context is missing the account ID.
|
||||
var ErrNoAccountID = errors.New("no account ID in request context")
|
||||
|
||||
// domainInfo holds metadata about a registered domain.
|
||||
type domainInfo struct {
|
||||
reverseProxyID string
|
||||
}
|
||||
|
||||
// clientEntry holds an embedded NetBird client and tracks which domains use it.
|
||||
type clientEntry struct {
|
||||
client *embed.Client
|
||||
transport *http.Transport
|
||||
domains map[domain.Domain]domainInfo
|
||||
createdAt time.Time
|
||||
started bool
|
||||
}
|
||||
|
||||
type statusNotifier interface {
|
||||
NotifyStatus(ctx context.Context, accountID, reverseProxyID, domain string, connected bool) error
|
||||
@@ -23,147 +44,389 @@ type statusNotifier interface {
|
||||
|
||||
// NetBird provides an http.RoundTripper implementation
|
||||
// backed by underlying NetBird connections.
|
||||
// Clients are keyed by AccountID, allowing multiple domains to share the same connection.
|
||||
type NetBird struct {
|
||||
mgmtAddr string
|
||||
proxyID string
|
||||
logger *log.Logger
|
||||
|
||||
clientsMux sync.RWMutex
|
||||
clients map[string]*embed.Client
|
||||
|
||||
clientsMux sync.RWMutex
|
||||
clients map[types.AccountID]*clientEntry
|
||||
initLogOnce sync.Once
|
||||
statusNotifier statusNotifier
|
||||
}
|
||||
|
||||
func NewNetBird(mgmtAddr string, logger *log.Logger, notifier statusNotifier) *NetBird {
|
||||
// NewNetBird creates a new NetBird transport.
|
||||
func NewNetBird(mgmtAddr, proxyID string, logger *log.Logger, notifier statusNotifier) *NetBird {
|
||||
if logger == nil {
|
||||
logger = log.StandardLogger()
|
||||
}
|
||||
return &NetBird{
|
||||
mgmtAddr: mgmtAddr,
|
||||
proxyID: proxyID,
|
||||
logger: logger,
|
||||
clients: make(map[string]*embed.Client),
|
||||
clients: make(map[types.AccountID]*clientEntry),
|
||||
statusNotifier: notifier,
|
||||
}
|
||||
}
|
||||
|
||||
func (n *NetBird) AddPeer(ctx context.Context, domain, key, accountID, reverseProxyID string) error {
|
||||
// AddPeer registers a domain for an account. If the account doesn't have a client yet,
|
||||
// one is created using the provided setup key. Multiple domains can share the same client.
|
||||
func (n *NetBird) AddPeer(ctx context.Context, accountID types.AccountID, d domain.Domain, key, reverseProxyID string) error {
|
||||
n.clientsMux.Lock()
|
||||
|
||||
entry, exists := n.clients[accountID]
|
||||
if exists {
|
||||
// Client already exists for this account, just register the domain
|
||||
entry.domains[d] = domainInfo{reverseProxyID: reverseProxyID}
|
||||
started := entry.started
|
||||
n.clientsMux.Unlock()
|
||||
|
||||
n.logger.WithFields(log.Fields{
|
||||
"account_id": accountID,
|
||||
"domain": d,
|
||||
}).Debug("registered domain with existing client")
|
||||
|
||||
// If client is already started, notify this domain as connected immediately
|
||||
if started && n.statusNotifier != nil {
|
||||
if err := n.statusNotifier.NotifyStatus(ctx, string(accountID), reverseProxyID, string(d), true); err != nil {
|
||||
n.logger.WithFields(log.Fields{
|
||||
"account_id": accountID,
|
||||
"domain": d,
|
||||
}).WithError(err).Warn("failed to notify status for existing client")
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
n.initLogOnce.Do(func() {
|
||||
if err := util.InitLog(log.WarnLevel.String(), util.LogConsole); err != nil {
|
||||
n.logger.WithField("account_id", accountID).Warnf("failed to initialize embedded client logging: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
wgPort := 0
|
||||
client, err := embed.New(embed.Options{
|
||||
DeviceName: deviceNamePrefix + domain,
|
||||
DeviceName: deviceNamePrefix + n.proxyID,
|
||||
ManagementURL: n.mgmtAddr,
|
||||
SetupKey: key,
|
||||
LogOutput: io.Discard,
|
||||
LogLevel: log.WarnLevel.String(),
|
||||
BlockInbound: true,
|
||||
WireguardPort: &wgPort,
|
||||
})
|
||||
if err != nil {
|
||||
n.clientsMux.Unlock()
|
||||
return fmt.Errorf("create netbird client: %w", err)
|
||||
}
|
||||
|
||||
// Create a transport using the client dialer. We do this instead of using
|
||||
// the client's HTTPClient to avoid issues with request validation that do
|
||||
// not work with reverse proxied requests.
|
||||
entry = &clientEntry{
|
||||
client: client,
|
||||
domains: map[domain.Domain]domainInfo{d: {reverseProxyID: reverseProxyID}},
|
||||
transport: &http.Transport{
|
||||
DialContext: client.DialContext,
|
||||
ForceAttemptHTTP2: true,
|
||||
MaxIdleConns: 100,
|
||||
IdleConnTimeout: 90 * time.Second,
|
||||
TLSHandshakeTimeout: 10 * time.Second,
|
||||
ExpectContinueTimeout: 1 * time.Second,
|
||||
},
|
||||
createdAt: time.Now(),
|
||||
started: false,
|
||||
}
|
||||
n.clients[accountID] = entry
|
||||
n.clientsMux.Unlock()
|
||||
|
||||
n.logger.WithFields(log.Fields{
|
||||
"account_id": accountID,
|
||||
"domain": d,
|
||||
}).Info("created new client for account")
|
||||
|
||||
// Attempt to start the client in the background, if this fails
|
||||
// then it is not ideal, but it isn't the end of the world because
|
||||
// we will try to start the client again before we use it.
|
||||
go func() {
|
||||
startCtx, cancel := context.WithTimeout(ctx, 3*time.Second)
|
||||
startCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
err = client.Start(startCtx)
|
||||
switch {
|
||||
case errors.Is(err, context.DeadlineExceeded):
|
||||
n.logger.Debug("netbird client timed out")
|
||||
// This is not ideal, but we will try again later.
|
||||
return
|
||||
case err != nil:
|
||||
n.logger.WithField("domain", domain).WithError(err).Error("Unable to start netbird client, will try again later.")
|
||||
|
||||
if err := client.Start(startCtx); err != nil {
|
||||
if errors.Is(err, context.DeadlineExceeded) {
|
||||
n.logger.WithFields(log.Fields{
|
||||
"account_id": accountID,
|
||||
}).Debug("netbird client start timed out, will retry on first request")
|
||||
} else {
|
||||
n.logger.WithFields(log.Fields{
|
||||
"account_id": accountID,
|
||||
}).WithError(err).Error("failed to start netbird client")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Notify management that tunnel is now active
|
||||
// Mark client as started and notify all registered domains
|
||||
n.clientsMux.Lock()
|
||||
entry, exists := n.clients[accountID]
|
||||
if exists {
|
||||
entry.started = true
|
||||
}
|
||||
// Copy domain info while holding lock
|
||||
var domainsToNotify []struct {
|
||||
domain domain.Domain
|
||||
reverseProxyID string
|
||||
}
|
||||
if exists {
|
||||
for dom, info := range entry.domains {
|
||||
domainsToNotify = append(domainsToNotify, struct {
|
||||
domain domain.Domain
|
||||
reverseProxyID string
|
||||
}{domain: dom, reverseProxyID: info.reverseProxyID})
|
||||
}
|
||||
}
|
||||
n.clientsMux.Unlock()
|
||||
|
||||
// Notify all domains that they're connected
|
||||
if n.statusNotifier != nil {
|
||||
if err := n.statusNotifier.NotifyStatus(ctx, accountID, reverseProxyID, domain, true); err != nil {
|
||||
n.logger.WithField("domain", domain).WithError(err).Warn("Failed to notify management about tunnel connection")
|
||||
} else {
|
||||
n.logger.WithField("domain", domain).Info("Successfully notified management about tunnel connection")
|
||||
for _, domInfo := range domainsToNotify {
|
||||
if err := n.statusNotifier.NotifyStatus(ctx, string(accountID), domInfo.reverseProxyID, string(domInfo.domain), true); err != nil {
|
||||
n.logger.WithFields(log.Fields{
|
||||
"account_id": accountID,
|
||||
"domain": domInfo.domain,
|
||||
}).WithError(err).Warn("failed to notify tunnel connection status")
|
||||
} else {
|
||||
n.logger.WithFields(log.Fields{
|
||||
"account_id": accountID,
|
||||
"domain": domInfo.domain,
|
||||
}).Info("notified management about tunnel connection")
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
n.clientsMux.Lock()
|
||||
defer n.clientsMux.Unlock()
|
||||
n.clients[domain] = client
|
||||
return nil
|
||||
}
|
||||
|
||||
func (n *NetBird) RemovePeer(ctx context.Context, domain, accountID, reverseProxyID string) error {
|
||||
n.clientsMux.RLock()
|
||||
client, exists := n.clients[domain]
|
||||
n.clientsMux.RUnlock()
|
||||
// RemovePeer unregisters a domain from an account. The client is only stopped
|
||||
// when no domains are using it anymore.
|
||||
func (n *NetBird) RemovePeer(ctx context.Context, accountID types.AccountID, d domain.Domain) error {
|
||||
n.clientsMux.Lock()
|
||||
|
||||
entry, exists := n.clients[accountID]
|
||||
if !exists {
|
||||
// Mission failed successfully!
|
||||
n.clientsMux.Unlock()
|
||||
return nil
|
||||
}
|
||||
if err := client.Stop(ctx); err != nil {
|
||||
return fmt.Errorf("stop netbird client: %w", err)
|
||||
|
||||
// Get domain info before deleting
|
||||
domInfo, domainExists := entry.domains[d]
|
||||
if !domainExists {
|
||||
n.clientsMux.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
// Notify management that tunnel is disconnected
|
||||
delete(entry.domains, d)
|
||||
|
||||
// If there are still domains using this client, keep it running
|
||||
if len(entry.domains) > 0 {
|
||||
n.clientsMux.Unlock()
|
||||
|
||||
n.logger.WithFields(log.Fields{
|
||||
"account_id": accountID,
|
||||
"domain": d,
|
||||
"remaining_domains": len(entry.domains),
|
||||
}).Debug("unregistered domain, client still in use")
|
||||
|
||||
// Notify this domain as disconnected
|
||||
if n.statusNotifier != nil {
|
||||
if err := n.statusNotifier.NotifyStatus(ctx, string(accountID), domInfo.reverseProxyID, string(d), false); err != nil {
|
||||
n.logger.WithFields(log.Fields{
|
||||
"account_id": accountID,
|
||||
"domain": d,
|
||||
}).WithError(err).Warn("failed to notify tunnel disconnection status")
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// No more domains using this client, stop it
|
||||
n.logger.WithFields(log.Fields{
|
||||
"account_id": accountID,
|
||||
}).Info("stopping client, no more domains")
|
||||
|
||||
client := entry.client
|
||||
transport := entry.transport
|
||||
delete(n.clients, accountID)
|
||||
n.clientsMux.Unlock()
|
||||
|
||||
// Notify disconnection before stopping
|
||||
if n.statusNotifier != nil {
|
||||
if err := n.statusNotifier.NotifyStatus(ctx, accountID, reverseProxyID, domain, false); err != nil {
|
||||
n.logger.WithField("domain", domain).WithError(err).Warn("Failed to notify management about tunnel disconnection")
|
||||
} else {
|
||||
n.logger.WithField("domain", domain).Info("Successfully notified management about tunnel disconnection")
|
||||
if err := n.statusNotifier.NotifyStatus(ctx, string(accountID), domInfo.reverseProxyID, string(d), false); err != nil {
|
||||
n.logger.WithFields(log.Fields{
|
||||
"account_id": accountID,
|
||||
"domain": d,
|
||||
}).WithError(err).Warn("failed to notify tunnel disconnection status")
|
||||
}
|
||||
}
|
||||
|
||||
n.clientsMux.Lock()
|
||||
defer n.clientsMux.Unlock()
|
||||
delete(n.clients, domain)
|
||||
transport.CloseIdleConnections()
|
||||
|
||||
if err := client.Stop(ctx); err != nil {
|
||||
n.logger.WithFields(log.Fields{
|
||||
"account_id": accountID,
|
||||
}).WithError(err).Warn("failed to stop netbird client")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// RoundTrip implements http.RoundTripper. It looks up the client for the account
|
||||
// specified in the request context and uses it to dial the backend.
|
||||
func (n *NetBird) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
host, _, err := net.SplitHostPort(req.Host)
|
||||
if err != nil {
|
||||
host = req.Host
|
||||
accountID := AccountIDFromContext(req.Context())
|
||||
if accountID == "" {
|
||||
return nil, ErrNoAccountID
|
||||
}
|
||||
|
||||
// Copy references while holding lock, then unlock early to avoid blocking
|
||||
// other requests during the potentially slow RoundTrip.
|
||||
n.clientsMux.RLock()
|
||||
client, exists := n.clients[host]
|
||||
// Immediately unlock after retrieval here rather than defer to avoid
|
||||
// the call to client.Do blocking other clients being used whilst one
|
||||
// is in use.
|
||||
n.clientsMux.RUnlock()
|
||||
entry, exists := n.clients[accountID]
|
||||
if !exists {
|
||||
return nil, fmt.Errorf("no peer connection found for host: %s", host)
|
||||
n.clientsMux.RUnlock()
|
||||
return nil, fmt.Errorf("no peer connection found for account: %s", accountID)
|
||||
}
|
||||
client := entry.client
|
||||
transport := entry.transport
|
||||
n.clientsMux.RUnlock()
|
||||
|
||||
// Attempt to start the client, if the client is already running then
|
||||
// it will return an error that we ignore, if this hits a timeout then
|
||||
// this request is unprocessable.
|
||||
startCtx, cancel := context.WithTimeout(req.Context(), 3*time.Second)
|
||||
startCtx, cancel := context.WithTimeout(req.Context(), 10*time.Second)
|
||||
defer cancel()
|
||||
err = client.Start(startCtx)
|
||||
switch {
|
||||
case errors.Is(err, embed.ErrClientAlreadyStarted):
|
||||
break
|
||||
case err != nil:
|
||||
return nil, fmt.Errorf("start netbird client: %w", err)
|
||||
if err := client.Start(startCtx); err != nil {
|
||||
if !errors.Is(err, embed.ErrClientAlreadyStarted) {
|
||||
return nil, fmt.Errorf("start netbird client: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
n.logger.WithFields(log.Fields{
|
||||
"host": host,
|
||||
"account_id": accountID,
|
||||
"host": req.Host,
|
||||
"url": req.URL.String(),
|
||||
"requestURI": req.RequestURI,
|
||||
"method": req.Method,
|
||||
}).Debug("running roundtrip for peer connection")
|
||||
|
||||
// Create a new transport using the client dialer and perform the roundtrip.
|
||||
// We do this instead of using the client HTTPClient to avoid issues around
|
||||
// client request validation that do not work with the reverse proxied
|
||||
// requests.
|
||||
// Other values are simply copied from the http.DefaultTransport which the
|
||||
// standard reverse proxy implementation would have used.
|
||||
// TODO: tune this transport for our needs.
|
||||
return (&http.Transport{
|
||||
DialContext: client.DialContext,
|
||||
MaxIdleConns: 100,
|
||||
IdleConnTimeout: 90 * time.Second,
|
||||
TLSHandshakeTimeout: 10 * time.Second,
|
||||
ExpectContinueTimeout: 1 * time.Second,
|
||||
}).RoundTrip(req)
|
||||
return transport.RoundTrip(req)
|
||||
}
|
||||
|
||||
// StopAll stops all clients.
|
||||
func (n *NetBird) StopAll(ctx context.Context) error {
|
||||
n.clientsMux.Lock()
|
||||
defer n.clientsMux.Unlock()
|
||||
|
||||
var merr *multierror.Error
|
||||
for accountID, entry := range n.clients {
|
||||
entry.transport.CloseIdleConnections()
|
||||
if err := entry.client.Stop(ctx); err != nil {
|
||||
n.logger.WithFields(log.Fields{
|
||||
"account_id": accountID,
|
||||
}).WithError(err).Warn("failed to stop netbird client during shutdown")
|
||||
merr = multierror.Append(merr, err)
|
||||
}
|
||||
}
|
||||
maps.Clear(n.clients)
|
||||
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
// HasClient returns true if there is a client for the given account.
|
||||
func (n *NetBird) HasClient(accountID types.AccountID) bool {
|
||||
n.clientsMux.RLock()
|
||||
defer n.clientsMux.RUnlock()
|
||||
_, exists := n.clients[accountID]
|
||||
return exists
|
||||
}
|
||||
|
||||
// DomainCount returns the number of domains registered for the given account.
|
||||
// Returns 0 if the account has no client.
|
||||
func (n *NetBird) DomainCount(accountID types.AccountID) int {
|
||||
n.clientsMux.RLock()
|
||||
defer n.clientsMux.RUnlock()
|
||||
entry, exists := n.clients[accountID]
|
||||
if !exists {
|
||||
return 0
|
||||
}
|
||||
return len(entry.domains)
|
||||
}
|
||||
|
||||
// ClientCount returns the total number of active clients.
|
||||
func (n *NetBird) ClientCount() int {
|
||||
n.clientsMux.RLock()
|
||||
defer n.clientsMux.RUnlock()
|
||||
return len(n.clients)
|
||||
}
|
||||
|
||||
// GetClient returns the embed.Client for the given account ID.
|
||||
func (n *NetBird) GetClient(accountID types.AccountID) (*embed.Client, bool) {
|
||||
n.clientsMux.RLock()
|
||||
defer n.clientsMux.RUnlock()
|
||||
entry, exists := n.clients[accountID]
|
||||
if !exists {
|
||||
return nil, false
|
||||
}
|
||||
return entry.client, true
|
||||
}
|
||||
|
||||
// ClientDebugInfo contains debug information about a client.
|
||||
type ClientDebugInfo struct {
|
||||
AccountID types.AccountID
|
||||
DomainCount int
|
||||
Domains domain.List
|
||||
HasClient bool
|
||||
CreatedAt time.Time
|
||||
}
|
||||
|
||||
// ListClientsForDebug returns information about all clients for debug purposes.
|
||||
func (n *NetBird) ListClientsForDebug() map[types.AccountID]ClientDebugInfo {
|
||||
n.clientsMux.RLock()
|
||||
defer n.clientsMux.RUnlock()
|
||||
|
||||
result := make(map[types.AccountID]ClientDebugInfo)
|
||||
for accountID, entry := range n.clients {
|
||||
domains := make(domain.List, 0, len(entry.domains))
|
||||
for d := range entry.domains {
|
||||
domains = append(domains, d)
|
||||
}
|
||||
result[accountID] = ClientDebugInfo{
|
||||
AccountID: accountID,
|
||||
DomainCount: len(entry.domains),
|
||||
Domains: domains,
|
||||
HasClient: entry.client != nil,
|
||||
CreatedAt: entry.createdAt,
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// accountIDContextKey is the context key for storing the account ID.
|
||||
type accountIDContextKey struct{}
|
||||
|
||||
// WithAccountID adds the account ID to the context.
|
||||
func WithAccountID(ctx context.Context, accountID types.AccountID) context.Context {
|
||||
return context.WithValue(ctx, accountIDContextKey{}, accountID)
|
||||
}
|
||||
|
||||
// AccountIDFromContext retrieves the account ID from the context.
|
||||
func AccountIDFromContext(ctx context.Context) types.AccountID {
|
||||
v := ctx.Value(accountIDContextKey{})
|
||||
if v == nil {
|
||||
return ""
|
||||
}
|
||||
accountID, ok := v.(types.AccountID)
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
return accountID
|
||||
}
|
||||
|
||||
247
proxy/internal/roundtrip/netbird_test.go
Normal file
247
proxy/internal/roundtrip/netbird_test.go
Normal file
@@ -0,0 +1,247 @@
|
||||
package roundtrip
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/proxy/internal/types"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
)
|
||||
|
||||
// mockNetBird creates a NetBird instance for testing without actually connecting.
|
||||
// It uses an invalid management URL to prevent real connections.
|
||||
func mockNetBird() *NetBird {
|
||||
return NewNetBird("http://invalid.test:9999", "test-proxy", nil, nil)
|
||||
}
|
||||
|
||||
func TestNetBird_AddPeer_CreatesClientForNewAccount(t *testing.T) {
|
||||
nb := mockNetBird()
|
||||
accountID := types.AccountID("account-1")
|
||||
|
||||
// Initially no client exists.
|
||||
assert.False(t, nb.HasClient(accountID), "should not have client before AddPeer")
|
||||
assert.Equal(t, 0, nb.DomainCount(accountID), "domain count should be 0")
|
||||
|
||||
// Add first domain - this should create a new client.
|
||||
// Note: This will fail to actually connect since we use an invalid URL,
|
||||
// but the client entry should still be created.
|
||||
err := nb.AddPeer(context.Background(), accountID, domain.Domain("domain1.test"), "setup-key-1", "proxy-1")
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.True(t, nb.HasClient(accountID), "should have client after AddPeer")
|
||||
assert.Equal(t, 1, nb.DomainCount(accountID), "domain count should be 1")
|
||||
}
|
||||
|
||||
func TestNetBird_AddPeer_ReuseClientForSameAccount(t *testing.T) {
|
||||
nb := mockNetBird()
|
||||
accountID := types.AccountID("account-1")
|
||||
|
||||
// Add first domain.
|
||||
err := nb.AddPeer(context.Background(), accountID, domain.Domain("domain1.test"), "setup-key-1", "proxy-1")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 1, nb.DomainCount(accountID))
|
||||
|
||||
// Add second domain for the same account - should reuse existing client.
|
||||
err = nb.AddPeer(context.Background(), accountID, domain.Domain("domain2.test"), "setup-key-1", "proxy-2")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 2, nb.DomainCount(accountID), "domain count should be 2 after adding second domain")
|
||||
|
||||
// Add third domain.
|
||||
err = nb.AddPeer(context.Background(), accountID, domain.Domain("domain3.test"), "setup-key-1", "proxy-3")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 3, nb.DomainCount(accountID), "domain count should be 3 after adding third domain")
|
||||
|
||||
// Still only one client.
|
||||
assert.True(t, nb.HasClient(accountID))
|
||||
}
|
||||
|
||||
func TestNetBird_AddPeer_SeparateClientsForDifferentAccounts(t *testing.T) {
|
||||
nb := mockNetBird()
|
||||
account1 := types.AccountID("account-1")
|
||||
account2 := types.AccountID("account-2")
|
||||
|
||||
// Add domain for account 1.
|
||||
err := nb.AddPeer(context.Background(), account1, domain.Domain("domain1.test"), "setup-key-1", "proxy-1")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Add domain for account 2.
|
||||
err = nb.AddPeer(context.Background(), account2, domain.Domain("domain2.test"), "setup-key-2", "proxy-2")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Both accounts should have their own clients.
|
||||
assert.True(t, nb.HasClient(account1), "account1 should have client")
|
||||
assert.True(t, nb.HasClient(account2), "account2 should have client")
|
||||
assert.Equal(t, 1, nb.DomainCount(account1), "account1 domain count should be 1")
|
||||
assert.Equal(t, 1, nb.DomainCount(account2), "account2 domain count should be 1")
|
||||
}
|
||||
|
||||
func TestNetBird_RemovePeer_KeepsClientWhenDomainsRemain(t *testing.T) {
|
||||
nb := mockNetBird()
|
||||
accountID := types.AccountID("account-1")
|
||||
|
||||
// Add multiple domains.
|
||||
err := nb.AddPeer(context.Background(), accountID, domain.Domain("domain1.test"), "setup-key-1", "proxy-1")
|
||||
require.NoError(t, err)
|
||||
err = nb.AddPeer(context.Background(), accountID, domain.Domain("domain2.test"), "setup-key-1", "proxy-2")
|
||||
require.NoError(t, err)
|
||||
err = nb.AddPeer(context.Background(), accountID, domain.Domain("domain3.test"), "setup-key-1", "proxy-3")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 3, nb.DomainCount(accountID))
|
||||
|
||||
// Remove one domain - client should remain.
|
||||
err = nb.RemovePeer(context.Background(), accountID, "domain1.test")
|
||||
require.NoError(t, err)
|
||||
assert.True(t, nb.HasClient(accountID), "client should remain after removing one domain")
|
||||
assert.Equal(t, 2, nb.DomainCount(accountID), "domain count should be 2")
|
||||
|
||||
// Remove another domain - client should still remain.
|
||||
err = nb.RemovePeer(context.Background(), accountID, "domain2.test")
|
||||
require.NoError(t, err)
|
||||
assert.True(t, nb.HasClient(accountID), "client should remain after removing second domain")
|
||||
assert.Equal(t, 1, nb.DomainCount(accountID), "domain count should be 1")
|
||||
}
|
||||
|
||||
func TestNetBird_RemovePeer_RemovesClientWhenLastDomainRemoved(t *testing.T) {
|
||||
nb := mockNetBird()
|
||||
accountID := types.AccountID("account-1")
|
||||
|
||||
// Add single domain.
|
||||
err := nb.AddPeer(context.Background(), accountID, domain.Domain("domain1.test"), "setup-key-1", "proxy-1")
|
||||
require.NoError(t, err)
|
||||
assert.True(t, nb.HasClient(accountID))
|
||||
|
||||
// Remove the only domain - client should be removed.
|
||||
// Note: Stop() may fail since the client never actually connected,
|
||||
// but the entry should still be removed from the map.
|
||||
_ = nb.RemovePeer(context.Background(), accountID, "domain1.test")
|
||||
|
||||
// After removing all domains, client should be gone.
|
||||
assert.False(t, nb.HasClient(accountID), "client should be removed after removing last domain")
|
||||
assert.Equal(t, 0, nb.DomainCount(accountID), "domain count should be 0")
|
||||
}
|
||||
|
||||
func TestNetBird_RemovePeer_NonExistentAccountIsNoop(t *testing.T) {
|
||||
nb := mockNetBird()
|
||||
accountID := types.AccountID("nonexistent-account")
|
||||
|
||||
// Removing from non-existent account should not error.
|
||||
err := nb.RemovePeer(context.Background(), accountID, "domain1.test")
|
||||
assert.NoError(t, err, "removing from non-existent account should not error")
|
||||
}
|
||||
|
||||
func TestNetBird_RemovePeer_NonExistentDomainIsNoop(t *testing.T) {
|
||||
nb := mockNetBird()
|
||||
accountID := types.AccountID("account-1")
|
||||
|
||||
// Add one domain.
|
||||
err := nb.AddPeer(context.Background(), accountID, domain.Domain("domain1.test"), "setup-key-1", "proxy-1")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Remove non-existent domain - should not affect existing domain.
|
||||
err = nb.RemovePeer(context.Background(), accountID, domain.Domain("nonexistent.test"))
|
||||
require.NoError(t, err)
|
||||
|
||||
// Original domain should still be registered.
|
||||
assert.True(t, nb.HasClient(accountID))
|
||||
assert.Equal(t, 1, nb.DomainCount(accountID), "original domain should remain")
|
||||
}
|
||||
|
||||
func TestWithAccountID_AndAccountIDFromContext(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
accountID := types.AccountID("test-account")
|
||||
|
||||
// Initially no account ID in context.
|
||||
retrieved := AccountIDFromContext(ctx)
|
||||
assert.True(t, retrieved == "", "should be empty when not set")
|
||||
|
||||
// Add account ID to context.
|
||||
ctx = WithAccountID(ctx, accountID)
|
||||
retrieved = AccountIDFromContext(ctx)
|
||||
assert.Equal(t, accountID, retrieved, "should retrieve the same account ID")
|
||||
}
|
||||
|
||||
func TestAccountIDFromContext_ReturnsEmptyForWrongType(t *testing.T) {
|
||||
// Create context with wrong type for account ID key.
|
||||
ctx := context.WithValue(context.Background(), accountIDContextKey{}, "wrong-type-string")
|
||||
|
||||
retrieved := AccountIDFromContext(ctx)
|
||||
assert.True(t, retrieved == "", "should return empty for wrong type")
|
||||
}
|
||||
|
||||
func TestNetBird_StopAll_StopsAllClients(t *testing.T) {
|
||||
nb := mockNetBird()
|
||||
account1 := types.AccountID("account-1")
|
||||
account2 := types.AccountID("account-2")
|
||||
account3 := types.AccountID("account-3")
|
||||
|
||||
// Add domains for multiple accounts.
|
||||
err := nb.AddPeer(context.Background(), account1, domain.Domain("domain1.test"), "key-1", "proxy-1")
|
||||
require.NoError(t, err)
|
||||
err = nb.AddPeer(context.Background(), account2, domain.Domain("domain2.test"), "key-2", "proxy-2")
|
||||
require.NoError(t, err)
|
||||
err = nb.AddPeer(context.Background(), account3, domain.Domain("domain3.test"), "key-3", "proxy-3")
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, 3, nb.ClientCount(), "should have 3 clients")
|
||||
|
||||
// Stop all clients.
|
||||
// Note: StopAll may return errors since clients never actually connected,
|
||||
// but the clients should still be removed from the map.
|
||||
_ = nb.StopAll(context.Background())
|
||||
|
||||
assert.Equal(t, 0, nb.ClientCount(), "should have 0 clients after StopAll")
|
||||
assert.False(t, nb.HasClient(account1), "account1 should not have client")
|
||||
assert.False(t, nb.HasClient(account2), "account2 should not have client")
|
||||
assert.False(t, nb.HasClient(account3), "account3 should not have client")
|
||||
}
|
||||
|
||||
func TestNetBird_ClientCount(t *testing.T) {
|
||||
nb := mockNetBird()
|
||||
|
||||
assert.Equal(t, 0, nb.ClientCount(), "should start with 0 clients")
|
||||
|
||||
// Add clients for different accounts.
|
||||
err := nb.AddPeer(context.Background(), types.AccountID("account-1"), domain.Domain("domain1.test"), "key-1", "proxy-1")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 1, nb.ClientCount())
|
||||
|
||||
err = nb.AddPeer(context.Background(), types.AccountID("account-2"), domain.Domain("domain2.test"), "key-2", "proxy-2")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 2, nb.ClientCount())
|
||||
|
||||
// Adding domain to existing account should not increase count.
|
||||
err = nb.AddPeer(context.Background(), types.AccountID("account-1"), domain.Domain("domain1b.test"), "key-1", "proxy-1b")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 2, nb.ClientCount(), "adding domain to existing account should not increase client count")
|
||||
}
|
||||
|
||||
func TestNetBird_RoundTrip_RequiresAccountIDInContext(t *testing.T) {
|
||||
nb := mockNetBird()
|
||||
|
||||
// Create a request without account ID in context.
|
||||
req, err := http.NewRequest("GET", "http://example.com/", nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
// RoundTrip should fail because no account ID in context.
|
||||
_, err = nb.RoundTrip(req)
|
||||
require.ErrorIs(t, err, ErrNoAccountID)
|
||||
}
|
||||
|
||||
func TestNetBird_RoundTrip_RequiresExistingClient(t *testing.T) {
|
||||
nb := mockNetBird()
|
||||
accountID := types.AccountID("nonexistent-account")
|
||||
|
||||
// Create a request with account ID but no client exists.
|
||||
req, err := http.NewRequest("GET", "http://example.com/", nil)
|
||||
require.NoError(t, err)
|
||||
req = req.WithContext(WithAccountID(req.Context(), accountID))
|
||||
|
||||
// RoundTrip should fail because no client for this account.
|
||||
_, err = nb.RoundTrip(req)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "no peer connection found for account")
|
||||
}
|
||||
5
proxy/internal/types/types.go
Normal file
5
proxy/internal/types/types.go
Normal file
@@ -0,0 +1,5 @@
|
||||
// Package types defines common types used across the proxy package.
|
||||
package types
|
||||
|
||||
// AccountID represents a unique identifier for a NetBird account.
|
||||
type AccountID string
|
||||
@@ -31,8 +31,11 @@ import (
|
||||
"github.com/netbirdio/netbird/proxy/internal/accesslog"
|
||||
"github.com/netbirdio/netbird/proxy/internal/acme"
|
||||
"github.com/netbirdio/netbird/proxy/internal/auth"
|
||||
"github.com/netbirdio/netbird/proxy/internal/debug"
|
||||
"github.com/netbirdio/netbird/proxy/internal/proxy"
|
||||
"github.com/netbirdio/netbird/proxy/internal/roundtrip"
|
||||
"github.com/netbirdio/netbird/proxy/internal/types"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
"github.com/netbirdio/netbird/util/embeddedroots"
|
||||
)
|
||||
@@ -45,6 +48,7 @@ type Server struct {
|
||||
auth *auth.Middleware
|
||||
http *http.Server
|
||||
https *http.Server
|
||||
debug *http.Server
|
||||
|
||||
// Mostly used for debugging on management.
|
||||
startTime time.Time
|
||||
@@ -62,6 +66,11 @@ type Server struct {
|
||||
OIDCClientSecret string
|
||||
OIDCEndpoint string
|
||||
OIDCScopes []string
|
||||
|
||||
// DebugEndpointEnabled enables the debug HTTP endpoint.
|
||||
DebugEndpointEnabled bool
|
||||
// DebugEndpointAddress is the address for the debug HTTP endpoint (default: ":8444").
|
||||
DebugEndpointAddress string
|
||||
}
|
||||
|
||||
// NotifyStatus sends a status update to management about tunnel connectivity
|
||||
@@ -148,7 +157,7 @@ func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) {
|
||||
|
||||
// Initialize the netbird client, this is required to build peer connections
|
||||
// to proxy over.
|
||||
s.netbird = roundtrip.NewNetBird(s.ManagementAddress, s.Logger, s)
|
||||
s.netbird = roundtrip.NewNetBird(s.ManagementAddress, s.ID, s.Logger, s)
|
||||
|
||||
// When generating ACME certificates, start a challenge server.
|
||||
tlsConfig := &tls.Config{}
|
||||
@@ -204,6 +213,34 @@ func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) {
|
||||
// Configure Access logs to management server.
|
||||
accessLog := accesslog.NewLogger(s.mgmtClient, s.Logger)
|
||||
|
||||
if s.DebugEndpointEnabled {
|
||||
debugAddr := debugEndpointAddr(s.DebugEndpointAddress)
|
||||
debugHandler := debug.NewHandler(s.netbird, s.Logger)
|
||||
s.debug = &http.Server{
|
||||
Addr: debugAddr,
|
||||
Handler: debugHandler,
|
||||
}
|
||||
go func() {
|
||||
s.Logger.WithField("address", debugAddr).Info("starting debug endpoint")
|
||||
if err := s.debug.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
|
||||
s.Logger.Errorf("debug endpoint error: %v", err)
|
||||
}
|
||||
}()
|
||||
defer func() {
|
||||
if err := s.debug.Close(); err != nil {
|
||||
s.Logger.Debugf("debug endpoint close: %v", err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
defer func() {
|
||||
stopCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
if err := s.netbird.StopAll(stopCtx); err != nil {
|
||||
s.Logger.Warnf("failed to stop all netbird clients: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// Finally, start the reverse proxy.
|
||||
s.https = &http.Server{
|
||||
Addr: addr,
|
||||
@@ -306,15 +343,15 @@ func (s *Server) handleMappingStream(ctx context.Context, mappingClient proto.Pr
|
||||
}
|
||||
|
||||
func (s *Server) addMapping(ctx context.Context, mapping *proto.ProxyMapping) error {
|
||||
domain := mapping.GetDomain()
|
||||
accountID := mapping.GetAccountId()
|
||||
d := domain.Domain(mapping.GetDomain())
|
||||
accountID := types.AccountID(mapping.GetAccountId())
|
||||
reverseProxyID := mapping.GetId()
|
||||
|
||||
if err := s.netbird.AddPeer(ctx, domain, mapping.GetSetupKey(), accountID, reverseProxyID); err != nil {
|
||||
return fmt.Errorf("create peer for domain %q: %w", domain, err)
|
||||
if err := s.netbird.AddPeer(ctx, accountID, d, mapping.GetSetupKey(), reverseProxyID); err != nil {
|
||||
return fmt.Errorf("create peer for domain %q: %w", d, err)
|
||||
}
|
||||
if s.acme != nil {
|
||||
s.acme.AddDomain(domain, accountID, reverseProxyID)
|
||||
s.acme.AddDomain(string(d), string(accountID), reverseProxyID)
|
||||
}
|
||||
|
||||
// Pass the mapping through to the update function to avoid duplicating the
|
||||
@@ -360,10 +397,13 @@ func (s *Server) updateMapping(ctx context.Context, mapping *proto.ProxyMapping)
|
||||
}
|
||||
|
||||
func (s *Server) removeMapping(ctx context.Context, mapping *proto.ProxyMapping) {
|
||||
if err := s.netbird.RemovePeer(ctx, mapping.GetDomain(), mapping.GetAccountId(), mapping.GetId()); err != nil {
|
||||
d := domain.Domain(mapping.GetDomain())
|
||||
accountID := types.AccountID(mapping.GetAccountId())
|
||||
if err := s.netbird.RemovePeer(ctx, accountID, d); err != nil {
|
||||
s.Logger.WithFields(log.Fields{
|
||||
"domain": mapping.GetDomain(),
|
||||
"error": err,
|
||||
"account_id": accountID,
|
||||
"domain": d,
|
||||
"error": err,
|
||||
}).Error("Error removing NetBird peer connection for domain, continuing additional domain cleanup but peer connection may still exist")
|
||||
}
|
||||
if s.acme != nil {
|
||||
@@ -392,8 +432,17 @@ func (s *Server) protoToMapping(mapping *proto.ProxyMapping) proxy.Mapping {
|
||||
}
|
||||
return proxy.Mapping{
|
||||
ID: mapping.GetId(),
|
||||
AccountID: mapping.AccountId,
|
||||
AccountID: types.AccountID(mapping.GetAccountId()),
|
||||
Host: mapping.GetDomain(),
|
||||
Paths: paths,
|
||||
}
|
||||
}
|
||||
|
||||
// debugEndpointAddr returns the address for the debug endpoint.
|
||||
// If addr is empty, it defaults to localhost:8444 for security.
|
||||
func debugEndpointAddr(addr string) string {
|
||||
if addr == "" {
|
||||
return "localhost:8444"
|
||||
}
|
||||
return addr
|
||||
}
|
||||
|
||||
48
proxy/server_test.go
Normal file
48
proxy/server_test.go
Normal file
@@ -0,0 +1,48 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestDebugEndpointDisabledByDefault(t *testing.T) {
|
||||
s := &Server{}
|
||||
assert.False(t, s.DebugEndpointEnabled, "debug endpoint should be disabled by default")
|
||||
}
|
||||
|
||||
func TestDebugEndpointAddr(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "empty defaults to localhost",
|
||||
input: "",
|
||||
expected: "localhost:8444",
|
||||
},
|
||||
{
|
||||
name: "explicit localhost preserved",
|
||||
input: "localhost:9999",
|
||||
expected: "localhost:9999",
|
||||
},
|
||||
{
|
||||
name: "explicit address preserved",
|
||||
input: "0.0.0.0:8444",
|
||||
expected: "0.0.0.0:8444",
|
||||
},
|
||||
{
|
||||
name: "127.0.0.1 preserved",
|
||||
input: "127.0.0.1:8444",
|
||||
expected: "127.0.0.1:8444",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
result := debugEndpointAddr(tc.input)
|
||||
assert.Equal(t, tc.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user