mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-18 08:16:39 +00:00
Use a 1:1 mapping of netbird client to netbird account
- Add debug endpoint for monitoring netbird clients - Add types package with AccountID type - Refactor netbird roundtrip to key clients by AccountID - Multiple domains can share the same client per account - Add status notifier for tunnel connection updates - Add OIDC flags to CLI - Add tests for netbird client management
This commit is contained in:
166
proxy/cmd/proxy/cmd/debug.go
Normal file
166
proxy/cmd/proxy/cmd/debug.go
Normal file
@@ -0,0 +1,166 @@
|
|||||||
|
package cmd
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"strconv"
|
||||||
|
|
||||||
|
"github.com/spf13/cobra"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/proxy/internal/debug"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
debugAddr string
|
||||||
|
jsonOutput bool
|
||||||
|
|
||||||
|
// status filters
|
||||||
|
statusFilterByIPs []string
|
||||||
|
statusFilterByNames []string
|
||||||
|
statusFilterByStatus string
|
||||||
|
statusFilterByConnectionType string
|
||||||
|
)
|
||||||
|
|
||||||
|
var debugCmd = &cobra.Command{
|
||||||
|
Use: "debug",
|
||||||
|
Short: "Debug commands for inspecting proxy state",
|
||||||
|
Long: "Debug commands for inspecting the reverse proxy state via the debug HTTP endpoint.",
|
||||||
|
}
|
||||||
|
|
||||||
|
var debugHealthCmd = &cobra.Command{
|
||||||
|
Use: "health",
|
||||||
|
Short: "Show proxy health status",
|
||||||
|
RunE: runDebugHealth,
|
||||||
|
SilenceUsage: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
var debugClientsCmd = &cobra.Command{
|
||||||
|
Use: "clients",
|
||||||
|
Aliases: []string{"list"},
|
||||||
|
Short: "List all connected clients",
|
||||||
|
RunE: runDebugClients,
|
||||||
|
SilenceUsage: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
var debugStatusCmd = &cobra.Command{
|
||||||
|
Use: "status <account-id>",
|
||||||
|
Short: "Show client status",
|
||||||
|
Args: cobra.ExactArgs(1),
|
||||||
|
RunE: runDebugStatus,
|
||||||
|
SilenceUsage: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
var debugSyncCmd = &cobra.Command{
|
||||||
|
Use: "sync-response <account-id>",
|
||||||
|
Short: "Show client sync response",
|
||||||
|
Args: cobra.ExactArgs(1),
|
||||||
|
RunE: runDebugSync,
|
||||||
|
SilenceUsage: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
var pingTimeout string
|
||||||
|
|
||||||
|
var debugPingCmd = &cobra.Command{
|
||||||
|
Use: "ping <account-id> <host> [port]",
|
||||||
|
Short: "TCP ping through a client",
|
||||||
|
Long: "Perform a TCP ping through a client's network to test connectivity.\nPort defaults to 80 if not specified.",
|
||||||
|
Args: cobra.RangeArgs(2, 3),
|
||||||
|
RunE: runDebugPing,
|
||||||
|
SilenceUsage: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
var debugLogLevelCmd = &cobra.Command{
|
||||||
|
Use: "loglevel <account-id> <level>",
|
||||||
|
Short: "Set client log level",
|
||||||
|
Long: "Set the log level for a client (trace, debug, info, warn, error).",
|
||||||
|
Args: cobra.ExactArgs(2),
|
||||||
|
RunE: runDebugLogLevel,
|
||||||
|
SilenceUsage: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
var debugStartCmd = &cobra.Command{
|
||||||
|
Use: "start <account-id>",
|
||||||
|
Short: "Start a client",
|
||||||
|
Args: cobra.ExactArgs(1),
|
||||||
|
RunE: runDebugStart,
|
||||||
|
SilenceUsage: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
var debugStopCmd = &cobra.Command{
|
||||||
|
Use: "stop <account-id>",
|
||||||
|
Short: "Stop a client",
|
||||||
|
Args: cobra.ExactArgs(1),
|
||||||
|
RunE: runDebugStop,
|
||||||
|
SilenceUsage: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
debugCmd.PersistentFlags().StringVar(&debugAddr, "addr", envStringOrDefault("NB_PROXY_DEBUG_ADDRESS", "localhost:8444"), "Debug endpoint address")
|
||||||
|
debugCmd.PersistentFlags().BoolVar(&jsonOutput, "json", false, "Output JSON instead of pretty format")
|
||||||
|
|
||||||
|
debugStatusCmd.Flags().StringSliceVar(&statusFilterByIPs, "filter-by-ips", nil, "Filter by peer IPs (comma-separated)")
|
||||||
|
debugStatusCmd.Flags().StringSliceVar(&statusFilterByNames, "filter-by-names", nil, "Filter by peer names (comma-separated)")
|
||||||
|
debugStatusCmd.Flags().StringVar(&statusFilterByStatus, "filter-by-status", "", "Filter by status (idle|connecting|connected)")
|
||||||
|
debugStatusCmd.Flags().StringVar(&statusFilterByConnectionType, "filter-by-connection-type", "", "Filter by connection type (P2P|Relayed)")
|
||||||
|
|
||||||
|
debugPingCmd.Flags().StringVar(&pingTimeout, "timeout", "", "Ping timeout (e.g., 10s)")
|
||||||
|
|
||||||
|
debugCmd.AddCommand(debugHealthCmd)
|
||||||
|
debugCmd.AddCommand(debugClientsCmd)
|
||||||
|
debugCmd.AddCommand(debugStatusCmd)
|
||||||
|
debugCmd.AddCommand(debugSyncCmd)
|
||||||
|
debugCmd.AddCommand(debugPingCmd)
|
||||||
|
debugCmd.AddCommand(debugLogLevelCmd)
|
||||||
|
debugCmd.AddCommand(debugStartCmd)
|
||||||
|
debugCmd.AddCommand(debugStopCmd)
|
||||||
|
|
||||||
|
rootCmd.AddCommand(debugCmd)
|
||||||
|
}
|
||||||
|
|
||||||
|
func getDebugClient(cmd *cobra.Command) *debug.Client {
|
||||||
|
return debug.NewClient(debugAddr, jsonOutput, cmd.OutOrStdout())
|
||||||
|
}
|
||||||
|
|
||||||
|
func runDebugHealth(cmd *cobra.Command, _ []string) error {
|
||||||
|
return getDebugClient(cmd).Health(cmd.Context())
|
||||||
|
}
|
||||||
|
|
||||||
|
func runDebugClients(cmd *cobra.Command, _ []string) error {
|
||||||
|
return getDebugClient(cmd).ListClients(cmd.Context())
|
||||||
|
}
|
||||||
|
|
||||||
|
func runDebugStatus(cmd *cobra.Command, args []string) error {
|
||||||
|
return getDebugClient(cmd).ClientStatus(cmd.Context(), args[0], debug.StatusFilters{
|
||||||
|
IPs: statusFilterByIPs,
|
||||||
|
Names: statusFilterByNames,
|
||||||
|
Status: statusFilterByStatus,
|
||||||
|
ConnectionType: statusFilterByConnectionType,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func runDebugSync(cmd *cobra.Command, args []string) error {
|
||||||
|
return getDebugClient(cmd).ClientSyncResponse(cmd.Context(), args[0])
|
||||||
|
}
|
||||||
|
|
||||||
|
func runDebugPing(cmd *cobra.Command, args []string) error {
|
||||||
|
port := 80
|
||||||
|
if len(args) > 2 {
|
||||||
|
p, err := strconv.Atoi(args[2])
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("invalid port: %w", err)
|
||||||
|
}
|
||||||
|
port = p
|
||||||
|
}
|
||||||
|
return getDebugClient(cmd).PingTCP(cmd.Context(), args[0], args[1], port, pingTimeout)
|
||||||
|
}
|
||||||
|
|
||||||
|
func runDebugLogLevel(cmd *cobra.Command, args []string) error {
|
||||||
|
return getDebugClient(cmd).SetLogLevel(cmd.Context(), args[0], args[1])
|
||||||
|
}
|
||||||
|
|
||||||
|
func runDebugStart(cmd *cobra.Command, args []string) error {
|
||||||
|
return getDebugClient(cmd).StartClient(cmd.Context(), args[0])
|
||||||
|
}
|
||||||
|
|
||||||
|
func runDebugStop(cmd *cobra.Command, args []string) error {
|
||||||
|
return getDebugClient(cmd).StopClient(cmd.Context(), args[0])
|
||||||
|
}
|
||||||
137
proxy/cmd/proxy/cmd/root.go
Normal file
137
proxy/cmd/proxy/cmd/root.go
Normal file
@@ -0,0 +1,137 @@
|
|||||||
|
package cmd
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"os"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"github.com/spf13/cobra"
|
||||||
|
"golang.org/x/crypto/acme"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/proxy"
|
||||||
|
"github.com/netbirdio/netbird/util"
|
||||||
|
)
|
||||||
|
|
||||||
|
const DefaultManagementURL = "https://api.netbird.io:443"
|
||||||
|
|
||||||
|
var (
|
||||||
|
Version = "dev"
|
||||||
|
Commit = "unknown"
|
||||||
|
BuildDate = "unknown"
|
||||||
|
GoVersion = "unknown"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
debugLogs bool
|
||||||
|
mgmtAddr string
|
||||||
|
addr string
|
||||||
|
proxyURL string
|
||||||
|
certDir string
|
||||||
|
acmeCerts bool
|
||||||
|
acmeAddr string
|
||||||
|
acmeDir string
|
||||||
|
debugEndpoint bool
|
||||||
|
debugEndpointAddr string
|
||||||
|
oidcClientID string
|
||||||
|
oidcClientSecret string
|
||||||
|
oidcEndpoint string
|
||||||
|
oidcScopes string
|
||||||
|
)
|
||||||
|
|
||||||
|
var rootCmd = &cobra.Command{
|
||||||
|
Use: "proxy",
|
||||||
|
Short: "NetBird reverse proxy server",
|
||||||
|
Long: "NetBird reverse proxy server for proxying traffic to NetBird networks.",
|
||||||
|
Version: Version,
|
||||||
|
RunE: runServer,
|
||||||
|
}
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
rootCmd.PersistentFlags().BoolVar(&debugLogs, "debug", envBoolOrDefault("NB_PROXY_DEBUG_LOGS", false), "Enable debug logs")
|
||||||
|
rootCmd.Flags().StringVar(&mgmtAddr, "mgmt", envStringOrDefault("NB_PROXY_MANAGEMENT_ADDRESS", DefaultManagementURL), "Management address to connect to")
|
||||||
|
rootCmd.Flags().StringVar(&addr, "addr", envStringOrDefault("NB_PROXY_ADDRESS", ":443"), "Reverse proxy address to listen on")
|
||||||
|
rootCmd.Flags().StringVar(&proxyURL, "url", envStringOrDefault("NB_PROXY_URL", ""), "The URL at which this proxy will be reached")
|
||||||
|
rootCmd.Flags().StringVar(&certDir, "cert-dir", envStringOrDefault("NB_PROXY_CERTIFICATE_DIRECTORY", "./certs"), "Directory to store certificates")
|
||||||
|
rootCmd.Flags().BoolVar(&acmeCerts, "acme-certs", envBoolOrDefault("NB_PROXY_ACME_CERTIFICATES", false), "Generate ACME certificates using HTTP-01 challenges")
|
||||||
|
rootCmd.Flags().StringVar(&acmeAddr, "acme-addr", envStringOrDefault("NB_PROXY_ACME_ADDRESS", ":80"), "HTTP address for ACME HTTP-01 challenges")
|
||||||
|
rootCmd.Flags().StringVar(&acmeDir, "acme-dir", envStringOrDefault("NB_PROXY_ACME_DIRECTORY", acme.LetsEncryptURL), "URL of ACME challenge directory")
|
||||||
|
rootCmd.Flags().BoolVar(&debugEndpoint, "debug-endpoint", envBoolOrDefault("NB_PROXY_DEBUG_ENDPOINT", false), "Enable debug HTTP endpoint")
|
||||||
|
rootCmd.Flags().StringVar(&debugEndpointAddr, "debug-endpoint-addr", envStringOrDefault("NB_PROXY_DEBUG_ENDPOINT_ADDRESS", "localhost:8444"), "Address for the debug HTTP endpoint")
|
||||||
|
rootCmd.Flags().StringVar(&oidcClientID, "oidc-id", envStringOrDefault("NB_PROXY_OIDC_CLIENT_ID", "netbird-proxy"), "The OAuth2 Client ID for OIDC User Authentication")
|
||||||
|
rootCmd.Flags().StringVar(&oidcClientSecret, "oidc-secret", envStringOrDefault("NB_PROXY_OIDC_CLIENT_SECRET", ""), "The OAuth2 Client Secret for OIDC User Authentication")
|
||||||
|
rootCmd.Flags().StringVar(&oidcEndpoint, "oidc-endpoint", envStringOrDefault("NB_PROXY_OIDC_ENDPOINT", ""), "The OIDC Endpoint for OIDC User Authentication")
|
||||||
|
rootCmd.Flags().StringVar(&oidcScopes, "oidc-scopes", envStringOrDefault("NB_PROXY_OIDC_SCOPES", "openid,profile,email"), "The OAuth2 scopes for OIDC User Authentication, comma separated")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Execute runs the root command.
|
||||||
|
func Execute() {
|
||||||
|
if err := rootCmd.Execute(); err != nil {
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetVersionInfo sets version information for the CLI.
|
||||||
|
func SetVersionInfo(version, commit, buildDate, goVersion string) {
|
||||||
|
Version = version
|
||||||
|
Commit = commit
|
||||||
|
BuildDate = buildDate
|
||||||
|
GoVersion = goVersion
|
||||||
|
rootCmd.Version = version
|
||||||
|
rootCmd.SetVersionTemplate("Version: {{.Version}}, Commit: " + Commit + ", BuildDate: " + BuildDate + ", Go: " + GoVersion + "\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
func runServer(cmd *cobra.Command, args []string) error {
|
||||||
|
level := "error"
|
||||||
|
if debugLogs {
|
||||||
|
level = "debug"
|
||||||
|
}
|
||||||
|
logger := log.New()
|
||||||
|
|
||||||
|
_ = util.InitLogger(logger, level, util.LogConsole)
|
||||||
|
|
||||||
|
log.Infof("configured log level: %s", level)
|
||||||
|
|
||||||
|
srv := proxy.Server{
|
||||||
|
Logger: logger,
|
||||||
|
Version: Version,
|
||||||
|
ManagementAddress: mgmtAddr,
|
||||||
|
ProxyURL: proxyURL,
|
||||||
|
CertificateDirectory: certDir,
|
||||||
|
GenerateACMECertificates: acmeCerts,
|
||||||
|
ACMEChallengeAddress: acmeAddr,
|
||||||
|
ACMEDirectory: acmeDir,
|
||||||
|
DebugEndpointEnabled: debugEndpoint,
|
||||||
|
DebugEndpointAddress: debugEndpointAddr,
|
||||||
|
OIDCClientId: oidcClientID,
|
||||||
|
OIDCClientSecret: oidcClientSecret,
|
||||||
|
OIDCEndpoint: oidcEndpoint,
|
||||||
|
OIDCScopes: strings.Split(oidcScopes, ","),
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := srv.ListenAndServe(context.TODO(), addr); err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func envBoolOrDefault(key string, def bool) bool {
|
||||||
|
v, exists := os.LookupEnv(key)
|
||||||
|
if !exists {
|
||||||
|
return def
|
||||||
|
}
|
||||||
|
parsed, err := strconv.ParseBool(v)
|
||||||
|
if err != nil {
|
||||||
|
return def
|
||||||
|
}
|
||||||
|
return parsed
|
||||||
|
}
|
||||||
|
|
||||||
|
func envStringOrDefault(key string, def string) string {
|
||||||
|
v, exists := os.LookupEnv(key)
|
||||||
|
if !exists {
|
||||||
|
return def
|
||||||
|
}
|
||||||
|
return v
|
||||||
|
}
|
||||||
@@ -1,22 +1,11 @@
|
|||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"flag"
|
|
||||||
"fmt"
|
|
||||||
"os"
|
|
||||||
"runtime"
|
"runtime"
|
||||||
"strings"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/util"
|
"github.com/netbirdio/netbird/proxy/cmd/proxy/cmd"
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
"golang.org/x/crypto/acme"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/proxy"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const DefaultManagementURL = "https://api.netbird.io:443"
|
|
||||||
|
|
||||||
var (
|
var (
|
||||||
// Version is the application version (set via ldflags during build)
|
// Version is the application version (set via ldflags during build)
|
||||||
Version = "dev"
|
Version = "dev"
|
||||||
@@ -31,78 +20,7 @@ var (
|
|||||||
GoVersion = runtime.Version()
|
GoVersion = runtime.Version()
|
||||||
)
|
)
|
||||||
|
|
||||||
func envBoolOrDefault(key string, def bool) bool {
|
|
||||||
v, exists := os.LookupEnv(key)
|
|
||||||
if !exists {
|
|
||||||
return def
|
|
||||||
}
|
|
||||||
return v == strings.ToLower("true")
|
|
||||||
}
|
|
||||||
|
|
||||||
func envStringOrDefault(key string, def string) string {
|
|
||||||
v, exists := os.LookupEnv(key)
|
|
||||||
if !exists {
|
|
||||||
return def
|
|
||||||
}
|
|
||||||
return v
|
|
||||||
}
|
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
var (
|
cmd.SetVersionInfo(Version, Commit, BuildDate, GoVersion)
|
||||||
version, debug bool
|
cmd.Execute()
|
||||||
mgmtAddr, addr, url, certDir string
|
|
||||||
acmeCerts bool
|
|
||||||
acmeAddr, acmeDir string
|
|
||||||
oidcId, oidcSecret, oidcEndpoint, oidcScopes string
|
|
||||||
)
|
|
||||||
|
|
||||||
flag.BoolVar(&version, "v", false, "Print version and exit")
|
|
||||||
flag.BoolVar(&debug, "debug", envBoolOrDefault("NB_PROXY_DEBUG_LOGS", false), "Enable debug logs")
|
|
||||||
flag.StringVar(&mgmtAddr, "mgmt", envStringOrDefault("NB_PROXY_MANAGEMENT_ADDRESS", DefaultManagementURL), "Management address to connect to.")
|
|
||||||
flag.StringVar(&addr, "addr", envStringOrDefault("NB_PROXY_ADDRESS", ":443"), "Reverse proxy address to listen on.")
|
|
||||||
flag.StringVar(&url, "url", envStringOrDefault("NB_PROXY_URL", "proxy.netbird.io"), "The URL at which this proxy will be reached, where CNAME records for proxied endpoints will be directed.")
|
|
||||||
flag.StringVar(&certDir, "cert-dir", envStringOrDefault("NB_PROXY_CERTIFICATE_DIRECTORY", "./certs"), "Directory to store ")
|
|
||||||
flag.BoolVar(&acmeCerts, "acme-certs", envBoolOrDefault("NB_PROXY_ACME_CERTIFICATES", false), "Generate ACME certificates using HTTP-01 challenges.")
|
|
||||||
flag.StringVar(&acmeAddr, "acme-addr", envStringOrDefault("NB_PROXY_ACME_ADDRESS", ":80"), "HTTP address to listen on, used for ACME HTTP-01 certificate generation.")
|
|
||||||
flag.StringVar(&acmeDir, "acme-dir", envStringOrDefault("NB_PROXY_ACME_DIRECTORY", acme.LetsEncryptURL), "URL of ACME challenge directory.")
|
|
||||||
flag.StringVar(&oidcId, "oidc-id", envStringOrDefault("NB_PROXY_OIDC_CLIENT_ID", "netbird-proxy"), "The OAuth2 Client ID for OIDC User Authentication")
|
|
||||||
flag.StringVar(&oidcSecret, "oidc-secret", envStringOrDefault("NB_PROXY_OIDC_CLIENT_SECRET", ""), "The OAuth2 Client Secret for OIDC User Authentication")
|
|
||||||
flag.StringVar(&oidcEndpoint, "oidc-endpoint", envStringOrDefault("NB_PROXY_OIDC_ENDPOINT", ""), "The OIDC Endpoint for OIDC User Authentication")
|
|
||||||
flag.StringVar(&oidcScopes, "oidc-scopes", envStringOrDefault("NB_PROXY_OIDC_SCOPES", "openid,profile,email"), "The OAuth2 scopes for OIDC User Authentication, comma separated")
|
|
||||||
flag.Parse()
|
|
||||||
|
|
||||||
if version {
|
|
||||||
fmt.Printf("Version: %s, Commit: %s, BuildDate: %s, Go: %s", Version, Commit, BuildDate, GoVersion)
|
|
||||||
os.Exit(0)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Configure logrus.
|
|
||||||
level := "error"
|
|
||||||
if debug {
|
|
||||||
level = "debug"
|
|
||||||
}
|
|
||||||
logger := log.New()
|
|
||||||
|
|
||||||
_ = util.InitLogger(logger, level, util.LogConsole)
|
|
||||||
|
|
||||||
log.Infof("configured log level: %s", level)
|
|
||||||
|
|
||||||
srv := proxy.Server{
|
|
||||||
Logger: logger,
|
|
||||||
Version: Version,
|
|
||||||
ManagementAddress: mgmtAddr,
|
|
||||||
ProxyURL: url,
|
|
||||||
CertificateDirectory: certDir,
|
|
||||||
GenerateACMECertificates: acmeCerts,
|
|
||||||
ACMEChallengeAddress: acmeAddr,
|
|
||||||
ACMEDirectory: acmeDir,
|
|
||||||
OIDCClientId: oidcId,
|
|
||||||
OIDCClientSecret: oidcSecret,
|
|
||||||
OIDCEndpoint: oidcEndpoint,
|
|
||||||
OIDCScopes: strings.Split(oidcScopes, ","),
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := srv.ListenAndServe(context.TODO(), addr); err != nil {
|
|
||||||
log.Fatal(err)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -42,7 +42,7 @@ func (l *Logger) Middleware(next http.Handler) http.Handler {
|
|||||||
entry := logEntry{
|
entry := logEntry{
|
||||||
ID: xid.New().String(),
|
ID: xid.New().String(),
|
||||||
ServiceId: capturedData.GetServiceId(),
|
ServiceId: capturedData.GetServiceId(),
|
||||||
AccountID: capturedData.GetAccountId(),
|
AccountID: string(capturedData.GetAccountId()),
|
||||||
Host: host,
|
Host: host,
|
||||||
Path: r.URL.Path,
|
Path: r.URL.Path,
|
||||||
DurationMs: duration.Milliseconds(),
|
DurationMs: duration.Milliseconds(),
|
||||||
|
|||||||
307
proxy/internal/debug/client.go
Normal file
307
proxy/internal/debug/client.go
Normal file
@@ -0,0 +1,307 @@
|
|||||||
|
// Package debug provides HTTP debug endpoints and CLI client for the proxy server.
|
||||||
|
package debug
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// StatusFilters contains filter options for status queries.
|
||||||
|
type StatusFilters struct {
|
||||||
|
IPs []string
|
||||||
|
Names []string
|
||||||
|
Status string
|
||||||
|
ConnectionType string
|
||||||
|
}
|
||||||
|
|
||||||
|
// Client provides CLI access to debug endpoints.
|
||||||
|
type Client struct {
|
||||||
|
baseURL string
|
||||||
|
jsonOutput bool
|
||||||
|
httpClient *http.Client
|
||||||
|
out io.Writer
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewClient creates a new debug client.
|
||||||
|
func NewClient(baseURL string, jsonOutput bool, out io.Writer) *Client {
|
||||||
|
if !strings.HasPrefix(baseURL, "http://") && !strings.HasPrefix(baseURL, "https://") {
|
||||||
|
baseURL = "http://" + baseURL
|
||||||
|
}
|
||||||
|
baseURL = strings.TrimSuffix(baseURL, "/")
|
||||||
|
|
||||||
|
return &Client{
|
||||||
|
baseURL: baseURL,
|
||||||
|
jsonOutput: jsonOutput,
|
||||||
|
out: out,
|
||||||
|
httpClient: &http.Client{
|
||||||
|
Timeout: 30 * time.Second,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Health fetches the health status.
|
||||||
|
func (c *Client) Health(ctx context.Context) error {
|
||||||
|
return c.fetchAndPrint(ctx, "/debug/health", c.printHealth)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) printHealth(data map[string]any) {
|
||||||
|
_, _ = fmt.Fprintf(c.out, "Status: %v\n", data["status"])
|
||||||
|
_, _ = fmt.Fprintf(c.out, "Uptime: %v\n", data["uptime"])
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListClients fetches the list of all clients.
|
||||||
|
func (c *Client) ListClients(ctx context.Context) error {
|
||||||
|
return c.fetchAndPrint(ctx, "/debug/clients", c.printClients)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) printClients(data map[string]any) {
|
||||||
|
_, _ = fmt.Fprintf(c.out, "Uptime: %v\n", data["uptime"])
|
||||||
|
_, _ = fmt.Fprintf(c.out, "Clients: %v\n\n", data["client_count"])
|
||||||
|
|
||||||
|
clients, ok := data["clients"].([]any)
|
||||||
|
if !ok || len(clients) == 0 {
|
||||||
|
_, _ = fmt.Fprintln(c.out, "No clients connected.")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
_, _ = fmt.Fprintf(c.out, "%-38s %-12s %-40s %s\n", "ACCOUNT ID", "AGE", "DOMAINS", "HAS CLIENT")
|
||||||
|
_, _ = fmt.Fprintln(c.out, strings.Repeat("-", 110))
|
||||||
|
|
||||||
|
for _, item := range clients {
|
||||||
|
c.printClientRow(item)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) printClientRow(item any) {
|
||||||
|
client, ok := item.(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
domains := c.extractDomains(client)
|
||||||
|
hasClient := "no"
|
||||||
|
if hc, ok := client["has_client"].(bool); ok && hc {
|
||||||
|
hasClient = "yes"
|
||||||
|
}
|
||||||
|
|
||||||
|
_, _ = fmt.Fprintf(c.out, "%-38s %-12v %s %s\n",
|
||||||
|
client["account_id"],
|
||||||
|
client["age"],
|
||||||
|
domains,
|
||||||
|
hasClient,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) extractDomains(client map[string]any) string {
|
||||||
|
d, ok := client["domains"].([]any)
|
||||||
|
if !ok || len(d) == 0 {
|
||||||
|
return "-"
|
||||||
|
}
|
||||||
|
|
||||||
|
parts := make([]string, len(d))
|
||||||
|
for i, domain := range d {
|
||||||
|
parts[i] = fmt.Sprint(domain)
|
||||||
|
}
|
||||||
|
return strings.Join(parts, ", ")
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClientStatus fetches the status of a specific client.
|
||||||
|
func (c *Client) ClientStatus(ctx context.Context, accountID string, filters StatusFilters) error {
|
||||||
|
params := url.Values{}
|
||||||
|
if len(filters.IPs) > 0 {
|
||||||
|
params.Set("filter-by-ips", strings.Join(filters.IPs, ","))
|
||||||
|
}
|
||||||
|
if len(filters.Names) > 0 {
|
||||||
|
params.Set("filter-by-names", strings.Join(filters.Names, ","))
|
||||||
|
}
|
||||||
|
if filters.Status != "" {
|
||||||
|
params.Set("filter-by-status", filters.Status)
|
||||||
|
}
|
||||||
|
if filters.ConnectionType != "" {
|
||||||
|
params.Set("filter-by-connection-type", filters.ConnectionType)
|
||||||
|
}
|
||||||
|
|
||||||
|
path := "/debug/clients/" + url.PathEscape(accountID)
|
||||||
|
if len(params) > 0 {
|
||||||
|
path += "?" + params.Encode()
|
||||||
|
}
|
||||||
|
return c.fetchAndPrint(ctx, path, c.printClientStatus)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) printClientStatus(data map[string]any) {
|
||||||
|
_, _ = fmt.Fprintf(c.out, "Account: %v\n\n", data["account_id"])
|
||||||
|
if status, ok := data["status"].(string); ok {
|
||||||
|
_, _ = fmt.Fprint(c.out, status)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClientSyncResponse fetches the sync response of a specific client.
|
||||||
|
func (c *Client) ClientSyncResponse(ctx context.Context, accountID string) error {
|
||||||
|
path := "/debug/clients/" + url.PathEscape(accountID) + "/syncresponse"
|
||||||
|
return c.fetchAndPrintJSON(ctx, path)
|
||||||
|
}
|
||||||
|
|
||||||
|
// PingTCP performs a TCP ping through a client.
|
||||||
|
func (c *Client) PingTCP(ctx context.Context, accountID, host string, port int, timeout string) error {
|
||||||
|
params := url.Values{}
|
||||||
|
params.Set("host", host)
|
||||||
|
params.Set("port", fmt.Sprintf("%d", port))
|
||||||
|
if timeout != "" {
|
||||||
|
params.Set("timeout", timeout)
|
||||||
|
}
|
||||||
|
|
||||||
|
path := fmt.Sprintf("/debug/clients/%s/pingtcp?%s", url.PathEscape(accountID), params.Encode())
|
||||||
|
return c.fetchAndPrint(ctx, path, c.printPingResult)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) printPingResult(data map[string]any) {
|
||||||
|
success, _ := data["success"].(bool)
|
||||||
|
if success {
|
||||||
|
_, _ = fmt.Fprintf(c.out, "Success: %v:%v\n", data["host"], data["port"])
|
||||||
|
_, _ = fmt.Fprintf(c.out, "Latency: %v\n", data["latency"])
|
||||||
|
} else {
|
||||||
|
_, _ = fmt.Fprintf(c.out, "Failed: %v:%v\n", data["host"], data["port"])
|
||||||
|
c.printError(data)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetLogLevel sets the log level of a specific client.
|
||||||
|
func (c *Client) SetLogLevel(ctx context.Context, accountID, level string) error {
|
||||||
|
params := url.Values{}
|
||||||
|
params.Set("level", level)
|
||||||
|
|
||||||
|
path := fmt.Sprintf("/debug/clients/%s/loglevel?%s", url.PathEscape(accountID), params.Encode())
|
||||||
|
return c.fetchAndPrint(ctx, path, c.printLogLevelResult)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) printLogLevelResult(data map[string]any) {
|
||||||
|
success, _ := data["success"].(bool)
|
||||||
|
if success {
|
||||||
|
_, _ = fmt.Fprintf(c.out, "Log level set to: %v\n", data["level"])
|
||||||
|
} else {
|
||||||
|
_, _ = fmt.Fprintln(c.out, "Failed to set log level")
|
||||||
|
c.printError(data)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// StartClient starts a specific client.
|
||||||
|
func (c *Client) StartClient(ctx context.Context, accountID string) error {
|
||||||
|
path := "/debug/clients/" + url.PathEscape(accountID) + "/start"
|
||||||
|
return c.fetchAndPrint(ctx, path, c.printStartResult)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) printStartResult(data map[string]any) {
|
||||||
|
success, _ := data["success"].(bool)
|
||||||
|
if success {
|
||||||
|
_, _ = fmt.Fprintln(c.out, "Client started")
|
||||||
|
} else {
|
||||||
|
_, _ = fmt.Fprintln(c.out, "Failed to start client")
|
||||||
|
c.printError(data)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// StopClient stops a specific client.
|
||||||
|
func (c *Client) StopClient(ctx context.Context, accountID string) error {
|
||||||
|
path := "/debug/clients/" + url.PathEscape(accountID) + "/stop"
|
||||||
|
return c.fetchAndPrint(ctx, path, c.printStopResult)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) printStopResult(data map[string]any) {
|
||||||
|
success, _ := data["success"].(bool)
|
||||||
|
if success {
|
||||||
|
_, _ = fmt.Fprintln(c.out, "Client stopped")
|
||||||
|
} else {
|
||||||
|
_, _ = fmt.Fprintln(c.out, "Failed to stop client")
|
||||||
|
c.printError(data)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) printError(data map[string]any) {
|
||||||
|
if errMsg, ok := data["error"].(string); ok {
|
||||||
|
_, _ = fmt.Fprintf(c.out, "Error: %s\n", errMsg)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) fetchAndPrint(ctx context.Context, path string, printer func(map[string]any)) error {
|
||||||
|
data, raw, err := c.fetch(ctx, path)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if c.jsonOutput {
|
||||||
|
return c.writeJSON(data)
|
||||||
|
}
|
||||||
|
|
||||||
|
if data != nil {
|
||||||
|
printer(data)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
_, _ = fmt.Fprintln(c.out, string(raw))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) fetchAndPrintJSON(ctx context.Context, path string) error {
|
||||||
|
data, raw, err := c.fetch(ctx, path)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if data != nil {
|
||||||
|
return c.writeJSON(data)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, _ = fmt.Fprintln(c.out, string(raw))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) writeJSON(data map[string]any) error {
|
||||||
|
enc := json.NewEncoder(c.out)
|
||||||
|
enc.SetIndent("", " ")
|
||||||
|
return enc.Encode(data)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) fetch(ctx context.Context, path string) (map[string]any, []byte, error) {
|
||||||
|
fullURL := c.baseURL + path
|
||||||
|
if !strings.Contains(path, "format=json") {
|
||||||
|
if strings.Contains(path, "?") {
|
||||||
|
fullURL += "&format=json"
|
||||||
|
} else {
|
||||||
|
fullURL += "?format=json"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, fullURL, nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, fmt.Errorf("create request: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := c.httpClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, fmt.Errorf("request failed: %w", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
body, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, fmt.Errorf("read response: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.StatusCode >= 400 {
|
||||||
|
return nil, nil, fmt.Errorf("server error (%d): %s", resp.StatusCode, strings.TrimSpace(string(body)))
|
||||||
|
}
|
||||||
|
|
||||||
|
var data map[string]any
|
||||||
|
if err := json.Unmarshal(body, &data); err != nil {
|
||||||
|
return nil, body, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return data, body, nil
|
||||||
|
}
|
||||||
|
|
||||||
589
proxy/internal/debug/handler.go
Normal file
589
proxy/internal/debug/handler.go
Normal file
@@ -0,0 +1,589 @@
|
|||||||
|
// Package debug provides HTTP debug endpoints for the proxy server.
|
||||||
|
package debug
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"embed"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"html/template"
|
||||||
|
"net/http"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"google.golang.org/protobuf/encoding/protojson"
|
||||||
|
|
||||||
|
nbembed "github.com/netbirdio/netbird/client/embed"
|
||||||
|
nbstatus "github.com/netbirdio/netbird/client/status"
|
||||||
|
"github.com/netbirdio/netbird/proxy/internal/roundtrip"
|
||||||
|
"github.com/netbirdio/netbird/proxy/internal/types"
|
||||||
|
"github.com/netbirdio/netbird/version"
|
||||||
|
)
|
||||||
|
|
||||||
|
//go:embed templates/*.html
|
||||||
|
var templateFS embed.FS
|
||||||
|
|
||||||
|
const defaultPingTimeout = 10 * time.Second
|
||||||
|
|
||||||
|
// formatDuration formats a duration with 2 decimal places using appropriate units.
|
||||||
|
func formatDuration(d time.Duration) string {
|
||||||
|
switch {
|
||||||
|
case d >= time.Hour:
|
||||||
|
return fmt.Sprintf("%.2fh", d.Hours())
|
||||||
|
case d >= time.Minute:
|
||||||
|
return fmt.Sprintf("%.2fm", d.Minutes())
|
||||||
|
case d >= time.Second:
|
||||||
|
return fmt.Sprintf("%.2fs", d.Seconds())
|
||||||
|
case d >= time.Millisecond:
|
||||||
|
return fmt.Sprintf("%.2fms", float64(d.Microseconds())/1000)
|
||||||
|
case d >= time.Microsecond:
|
||||||
|
return fmt.Sprintf("%.2fµs", float64(d.Nanoseconds())/1000)
|
||||||
|
default:
|
||||||
|
return fmt.Sprintf("%dns", d.Nanoseconds())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClientProvider provides access to NetBird clients.
|
||||||
|
type ClientProvider interface {
|
||||||
|
GetClient(accountID types.AccountID) (*nbembed.Client, bool)
|
||||||
|
ListClientsForDebug() map[types.AccountID]roundtrip.ClientDebugInfo
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handler provides HTTP debug endpoints.
|
||||||
|
type Handler struct {
|
||||||
|
provider ClientProvider
|
||||||
|
logger *log.Logger
|
||||||
|
startTime time.Time
|
||||||
|
templates *template.Template
|
||||||
|
templateMu sync.RWMutex
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewHandler creates a new debug handler.
|
||||||
|
func NewHandler(provider ClientProvider, logger *log.Logger) *Handler {
|
||||||
|
if logger == nil {
|
||||||
|
logger = log.StandardLogger()
|
||||||
|
}
|
||||||
|
h := &Handler{
|
||||||
|
provider: provider,
|
||||||
|
logger: logger,
|
||||||
|
startTime: time.Now(),
|
||||||
|
}
|
||||||
|
if err := h.loadTemplates(); err != nil {
|
||||||
|
logger.Errorf("failed to load embedded templates: %v", err)
|
||||||
|
}
|
||||||
|
return h
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Handler) loadTemplates() error {
|
||||||
|
tmpl, err := template.ParseFS(templateFS, "templates/*.html")
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("parse embedded templates: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
h.templateMu.Lock()
|
||||||
|
h.templates = tmpl
|
||||||
|
h.templateMu.Unlock()
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Handler) getTemplates() *template.Template {
|
||||||
|
h.templateMu.RLock()
|
||||||
|
defer h.templateMu.RUnlock()
|
||||||
|
return h.templates
|
||||||
|
}
|
||||||
|
|
||||||
|
// ServeHTTP handles debug requests.
|
||||||
|
func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||||
|
path := r.URL.Path
|
||||||
|
wantJSON := r.URL.Query().Get("format") == "json" || strings.HasSuffix(path, "/json")
|
||||||
|
path = strings.TrimSuffix(path, "/json")
|
||||||
|
|
||||||
|
switch path {
|
||||||
|
case "/debug", "/debug/":
|
||||||
|
h.handleIndex(w, r, wantJSON)
|
||||||
|
case "/debug/clients":
|
||||||
|
h.handleListClients(w, r, wantJSON)
|
||||||
|
case "/debug/health":
|
||||||
|
h.handleHealth(w, r, wantJSON)
|
||||||
|
default:
|
||||||
|
if h.handleClientRoutes(w, r, path, wantJSON) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
http.NotFound(w, r)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Handler) handleClientRoutes(w http.ResponseWriter, r *http.Request, path string, wantJSON bool) bool {
|
||||||
|
if !strings.HasPrefix(path, "/debug/clients/") {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
rest := strings.TrimPrefix(path, "/debug/clients/")
|
||||||
|
parts := strings.SplitN(rest, "/", 2)
|
||||||
|
accountID := types.AccountID(parts[0])
|
||||||
|
|
||||||
|
if len(parts) == 1 {
|
||||||
|
h.handleClientStatus(w, r, accountID, wantJSON)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
switch parts[1] {
|
||||||
|
case "syncresponse":
|
||||||
|
h.handleClientSyncResponse(w, r, accountID, wantJSON)
|
||||||
|
case "tools":
|
||||||
|
h.handleClientTools(w, r, accountID)
|
||||||
|
case "pingtcp":
|
||||||
|
h.handlePingTCP(w, r, accountID)
|
||||||
|
case "loglevel":
|
||||||
|
h.handleLogLevel(w, r, accountID)
|
||||||
|
case "start":
|
||||||
|
h.handleClientStart(w, r, accountID)
|
||||||
|
case "stop":
|
||||||
|
h.handleClientStop(w, r, accountID)
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
type indexData struct {
|
||||||
|
Version string
|
||||||
|
Uptime string
|
||||||
|
ClientCount int
|
||||||
|
TotalDomains int
|
||||||
|
Clients []clientData
|
||||||
|
}
|
||||||
|
|
||||||
|
type clientData struct {
|
||||||
|
AccountID string
|
||||||
|
Domains string
|
||||||
|
Age string
|
||||||
|
Status string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Handler) handleIndex(w http.ResponseWriter, _ *http.Request, wantJSON bool) {
|
||||||
|
clients := h.provider.ListClientsForDebug()
|
||||||
|
|
||||||
|
totalDomains := 0
|
||||||
|
for _, info := range clients {
|
||||||
|
totalDomains += info.DomainCount
|
||||||
|
}
|
||||||
|
|
||||||
|
if wantJSON {
|
||||||
|
clientsJSON := make([]map[string]interface{}, 0, len(clients))
|
||||||
|
for _, info := range clients {
|
||||||
|
clientsJSON = append(clientsJSON, map[string]interface{}{
|
||||||
|
"account_id": info.AccountID,
|
||||||
|
"domain_count": info.DomainCount,
|
||||||
|
"domains": info.Domains,
|
||||||
|
"has_client": info.HasClient,
|
||||||
|
"created_at": info.CreatedAt,
|
||||||
|
"age": time.Since(info.CreatedAt).Round(time.Second).String(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
h.writeJSON(w, map[string]interface{}{
|
||||||
|
"version": version.NetbirdVersion(),
|
||||||
|
"uptime": time.Since(h.startTime).Round(time.Second).String(),
|
||||||
|
"client_count": len(clients),
|
||||||
|
"total_domains": totalDomains,
|
||||||
|
"clients": clientsJSON,
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
data := indexData{
|
||||||
|
Version: version.NetbirdVersion(),
|
||||||
|
Uptime: time.Since(h.startTime).Round(time.Second).String(),
|
||||||
|
ClientCount: len(clients),
|
||||||
|
TotalDomains: totalDomains,
|
||||||
|
Clients: make([]clientData, 0, len(clients)),
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, info := range clients {
|
||||||
|
domains := info.Domains.SafeString()
|
||||||
|
if domains == "" {
|
||||||
|
domains = "-"
|
||||||
|
}
|
||||||
|
status := "No client"
|
||||||
|
if info.HasClient {
|
||||||
|
status = "Active"
|
||||||
|
}
|
||||||
|
data.Clients = append(data.Clients, clientData{
|
||||||
|
AccountID: string(info.AccountID),
|
||||||
|
Domains: domains,
|
||||||
|
Age: time.Since(info.CreatedAt).Round(time.Second).String(),
|
||||||
|
Status: status,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
h.renderTemplate(w, "index", data)
|
||||||
|
}
|
||||||
|
|
||||||
|
type clientsData struct {
|
||||||
|
Uptime string
|
||||||
|
Clients []clientData
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Handler) handleListClients(w http.ResponseWriter, _ *http.Request, wantJSON bool) {
|
||||||
|
clients := h.provider.ListClientsForDebug()
|
||||||
|
|
||||||
|
if wantJSON {
|
||||||
|
clientsJSON := make([]map[string]interface{}, 0, len(clients))
|
||||||
|
for _, info := range clients {
|
||||||
|
clientsJSON = append(clientsJSON, map[string]interface{}{
|
||||||
|
"account_id": info.AccountID,
|
||||||
|
"domain_count": info.DomainCount,
|
||||||
|
"domains": info.Domains,
|
||||||
|
"has_client": info.HasClient,
|
||||||
|
"created_at": info.CreatedAt,
|
||||||
|
"age": time.Since(info.CreatedAt).Round(time.Second).String(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
h.writeJSON(w, map[string]interface{}{
|
||||||
|
"uptime": time.Since(h.startTime).Round(time.Second).String(),
|
||||||
|
"client_count": len(clients),
|
||||||
|
"clients": clientsJSON,
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
data := clientsData{
|
||||||
|
Uptime: time.Since(h.startTime).Round(time.Second).String(),
|
||||||
|
Clients: make([]clientData, 0, len(clients)),
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, info := range clients {
|
||||||
|
domains := info.Domains.SafeString()
|
||||||
|
if domains == "" {
|
||||||
|
domains = "-"
|
||||||
|
}
|
||||||
|
status := "No client"
|
||||||
|
if info.HasClient {
|
||||||
|
status = "Active"
|
||||||
|
}
|
||||||
|
data.Clients = append(data.Clients, clientData{
|
||||||
|
AccountID: string(info.AccountID),
|
||||||
|
Domains: domains,
|
||||||
|
Age: time.Since(info.CreatedAt).Round(time.Second).String(),
|
||||||
|
Status: status,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
h.renderTemplate(w, "clients", data)
|
||||||
|
}
|
||||||
|
|
||||||
|
type clientDetailData struct {
|
||||||
|
AccountID string
|
||||||
|
ActiveTab string
|
||||||
|
Content string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Handler) handleClientStatus(w http.ResponseWriter, r *http.Request, accountID types.AccountID, wantJSON bool) {
|
||||||
|
client, ok := h.provider.GetClient(accountID)
|
||||||
|
if !ok {
|
||||||
|
http.Error(w, "Client not found: "+string(accountID), http.StatusNotFound)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
fullStatus, err := client.Status()
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, "Error getting status: "+err.Error(), http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse filter parameters
|
||||||
|
query := r.URL.Query()
|
||||||
|
statusFilter := query.Get("filter-by-status")
|
||||||
|
connectionTypeFilter := query.Get("filter-by-connection-type")
|
||||||
|
|
||||||
|
var prefixNamesFilter []string
|
||||||
|
var prefixNamesFilterMap map[string]struct{}
|
||||||
|
if names := query.Get("filter-by-names"); names != "" {
|
||||||
|
prefixNamesFilter = strings.Split(names, ",")
|
||||||
|
prefixNamesFilterMap = make(map[string]struct{})
|
||||||
|
for _, name := range prefixNamesFilter {
|
||||||
|
prefixNamesFilterMap[strings.ToLower(strings.TrimSpace(name))] = struct{}{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var ipsFilterMap map[string]struct{}
|
||||||
|
if ips := query.Get("filter-by-ips"); ips != "" {
|
||||||
|
ipsFilterMap = make(map[string]struct{})
|
||||||
|
for _, ip := range strings.Split(ips, ",") {
|
||||||
|
ipsFilterMap[strings.TrimSpace(ip)] = struct{}{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pbStatus := nbstatus.ToProtoFullStatus(fullStatus)
|
||||||
|
overview := nbstatus.ConvertToStatusOutputOverview(
|
||||||
|
pbStatus,
|
||||||
|
false,
|
||||||
|
version.NetbirdVersion(),
|
||||||
|
statusFilter,
|
||||||
|
prefixNamesFilter,
|
||||||
|
prefixNamesFilterMap,
|
||||||
|
ipsFilterMap,
|
||||||
|
connectionTypeFilter,
|
||||||
|
"",
|
||||||
|
)
|
||||||
|
|
||||||
|
if wantJSON {
|
||||||
|
h.writeJSON(w, map[string]interface{}{
|
||||||
|
"account_id": accountID,
|
||||||
|
"status": overview.FullDetailSummary(),
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
data := clientDetailData{
|
||||||
|
AccountID: string(accountID),
|
||||||
|
ActiveTab: "status",
|
||||||
|
Content: overview.FullDetailSummary(),
|
||||||
|
}
|
||||||
|
|
||||||
|
h.renderTemplate(w, "clientDetail", data)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Handler) handleClientSyncResponse(w http.ResponseWriter, _ *http.Request, accountID types.AccountID, wantJSON bool) {
|
||||||
|
client, ok := h.provider.GetClient(accountID)
|
||||||
|
if !ok {
|
||||||
|
http.Error(w, "Client not found: "+string(accountID), http.StatusNotFound)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
syncResp, err := client.GetLatestSyncResponse()
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, "Error getting sync response: "+err.Error(), http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if syncResp == nil {
|
||||||
|
http.Error(w, "No sync response available for client: "+string(accountID), http.StatusNotFound)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
opts := protojson.MarshalOptions{
|
||||||
|
EmitUnpopulated: true,
|
||||||
|
UseProtoNames: true,
|
||||||
|
Indent: " ",
|
||||||
|
AllowPartial: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
jsonBytes, err := opts.Marshal(syncResp)
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, "Error marshaling sync response: "+err.Error(), http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if wantJSON {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
_, _ = w.Write(jsonBytes)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
data := clientDetailData{
|
||||||
|
AccountID: string(accountID),
|
||||||
|
ActiveTab: "syncresponse",
|
||||||
|
Content: string(jsonBytes),
|
||||||
|
}
|
||||||
|
|
||||||
|
h.renderTemplate(w, "clientDetail", data)
|
||||||
|
}
|
||||||
|
|
||||||
|
type toolsData struct {
|
||||||
|
AccountID string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Handler) handleClientTools(w http.ResponseWriter, _ *http.Request, accountID types.AccountID) {
|
||||||
|
_, ok := h.provider.GetClient(accountID)
|
||||||
|
if !ok {
|
||||||
|
http.Error(w, "Client not found: "+string(accountID), http.StatusNotFound)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
data := toolsData{
|
||||||
|
AccountID: string(accountID),
|
||||||
|
}
|
||||||
|
|
||||||
|
h.renderTemplate(w, "tools", data)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Handler) handlePingTCP(w http.ResponseWriter, r *http.Request, accountID types.AccountID) {
|
||||||
|
client, ok := h.provider.GetClient(accountID)
|
||||||
|
if !ok {
|
||||||
|
h.writeJSON(w, map[string]interface{}{"error": "client not found"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
host := r.URL.Query().Get("host")
|
||||||
|
portStr := r.URL.Query().Get("port")
|
||||||
|
if host == "" || portStr == "" {
|
||||||
|
h.writeJSON(w, map[string]interface{}{"error": "host and port parameters required"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
port, err := strconv.Atoi(portStr)
|
||||||
|
if err != nil || port < 1 || port > 65535 {
|
||||||
|
h.writeJSON(w, map[string]interface{}{"error": "invalid port"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
timeout := defaultPingTimeout
|
||||||
|
if t := r.URL.Query().Get("timeout"); t != "" {
|
||||||
|
if d, err := time.ParseDuration(t); err == nil {
|
||||||
|
timeout = d
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(r.Context(), timeout)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
address := fmt.Sprintf("%s:%d", host, port)
|
||||||
|
start := time.Now()
|
||||||
|
|
||||||
|
conn, err := client.Dial(ctx, "tcp", address)
|
||||||
|
if err != nil {
|
||||||
|
h.writeJSON(w, map[string]interface{}{
|
||||||
|
"success": false,
|
||||||
|
"host": host,
|
||||||
|
"port": port,
|
||||||
|
"error": err.Error(),
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := conn.Close(); err != nil {
|
||||||
|
h.logger.Debugf("close tcp ping connection: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
latency := time.Since(start)
|
||||||
|
h.writeJSON(w, map[string]interface{}{
|
||||||
|
"success": true,
|
||||||
|
"host": host,
|
||||||
|
"port": port,
|
||||||
|
"latency_ms": latency.Milliseconds(),
|
||||||
|
"latency": formatDuration(latency),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Handler) handleLogLevel(w http.ResponseWriter, r *http.Request, accountID types.AccountID) {
|
||||||
|
client, ok := h.provider.GetClient(accountID)
|
||||||
|
if !ok {
|
||||||
|
h.writeJSON(w, map[string]interface{}{"error": "client not found"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
level := r.URL.Query().Get("level")
|
||||||
|
if level == "" {
|
||||||
|
h.writeJSON(w, map[string]interface{}{"error": "level parameter required (trace, debug, info, warn, error)"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := client.SetLogLevel(level); err != nil {
|
||||||
|
h.writeJSON(w, map[string]interface{}{
|
||||||
|
"success": false,
|
||||||
|
"error": err.Error(),
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
h.writeJSON(w, map[string]interface{}{
|
||||||
|
"success": true,
|
||||||
|
"level": level,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
const clientActionTimeout = 30 * time.Second
|
||||||
|
|
||||||
|
func (h *Handler) handleClientStart(w http.ResponseWriter, r *http.Request, accountID types.AccountID) {
|
||||||
|
client, ok := h.provider.GetClient(accountID)
|
||||||
|
if !ok {
|
||||||
|
h.writeJSON(w, map[string]interface{}{"error": "client not found"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(r.Context(), clientActionTimeout)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
if err := client.Start(ctx); err != nil {
|
||||||
|
h.writeJSON(w, map[string]interface{}{
|
||||||
|
"success": false,
|
||||||
|
"error": err.Error(),
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
h.writeJSON(w, map[string]interface{}{
|
||||||
|
"success": true,
|
||||||
|
"message": "client started",
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Handler) handleClientStop(w http.ResponseWriter, r *http.Request, accountID types.AccountID) {
|
||||||
|
client, ok := h.provider.GetClient(accountID)
|
||||||
|
if !ok {
|
||||||
|
h.writeJSON(w, map[string]interface{}{"error": "client not found"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(r.Context(), clientActionTimeout)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
if err := client.Stop(ctx); err != nil {
|
||||||
|
h.writeJSON(w, map[string]interface{}{
|
||||||
|
"success": false,
|
||||||
|
"error": err.Error(),
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
h.writeJSON(w, map[string]interface{}{
|
||||||
|
"success": true,
|
||||||
|
"message": "client stopped",
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
type healthData struct {
|
||||||
|
Uptime string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Handler) handleHealth(w http.ResponseWriter, _ *http.Request, wantJSON bool) {
|
||||||
|
if wantJSON {
|
||||||
|
h.writeJSON(w, map[string]interface{}{
|
||||||
|
"status": "ok",
|
||||||
|
"uptime": time.Since(h.startTime).Round(10 * time.Millisecond).String(),
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
data := healthData{
|
||||||
|
Uptime: time.Since(h.startTime).Round(time.Second).String(),
|
||||||
|
}
|
||||||
|
|
||||||
|
h.renderTemplate(w, "health", data)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Handler) renderTemplate(w http.ResponseWriter, name string, data interface{}) {
|
||||||
|
w.Header().Set("Content-Type", "text/html; charset=utf-8")
|
||||||
|
tmpl := h.getTemplates()
|
||||||
|
if tmpl == nil {
|
||||||
|
http.Error(w, "Templates not loaded", http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := tmpl.ExecuteTemplate(w, name, data); err != nil {
|
||||||
|
h.logger.Errorf("execute template %s: %v", name, err)
|
||||||
|
http.Error(w, "Template error", http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Handler) writeJSON(w http.ResponseWriter, v interface{}) {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
enc := json.NewEncoder(w)
|
||||||
|
enc.SetIndent("", " ")
|
||||||
|
if err := enc.Encode(v); err != nil {
|
||||||
|
h.logger.Errorf("encode JSON response: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
101
proxy/internal/debug/templates/base.html
Normal file
101
proxy/internal/debug/templates/base.html
Normal file
@@ -0,0 +1,101 @@
|
|||||||
|
{{define "style"}}
|
||||||
|
body {
|
||||||
|
font-family: monospace;
|
||||||
|
margin: 20px;
|
||||||
|
background: #1a1a1a;
|
||||||
|
color: #eee;
|
||||||
|
}
|
||||||
|
a {
|
||||||
|
color: #6cf;
|
||||||
|
}
|
||||||
|
h1, h2, h3 {
|
||||||
|
color: #fff;
|
||||||
|
}
|
||||||
|
.info {
|
||||||
|
color: #aaa;
|
||||||
|
}
|
||||||
|
table {
|
||||||
|
border-collapse: collapse;
|
||||||
|
margin: 10px 0;
|
||||||
|
}
|
||||||
|
th, td {
|
||||||
|
border: 1px solid #444;
|
||||||
|
padding: 8px;
|
||||||
|
text-align: left;
|
||||||
|
}
|
||||||
|
th {
|
||||||
|
background: #333;
|
||||||
|
}
|
||||||
|
.nav {
|
||||||
|
margin-bottom: 20px;
|
||||||
|
}
|
||||||
|
.nav a {
|
||||||
|
margin-right: 15px;
|
||||||
|
padding: 8px 16px;
|
||||||
|
background: #333;
|
||||||
|
text-decoration: none;
|
||||||
|
border-radius: 4px;
|
||||||
|
}
|
||||||
|
.nav a.active {
|
||||||
|
background: #6cf;
|
||||||
|
color: #000;
|
||||||
|
}
|
||||||
|
pre {
|
||||||
|
background: #222;
|
||||||
|
padding: 15px;
|
||||||
|
border-radius: 4px;
|
||||||
|
overflow-x: auto;
|
||||||
|
white-space: pre-wrap;
|
||||||
|
}
|
||||||
|
input, select, textarea {
|
||||||
|
background: #333;
|
||||||
|
color: #eee;
|
||||||
|
border: 1px solid #555;
|
||||||
|
padding: 8px;
|
||||||
|
border-radius: 4px;
|
||||||
|
font-family: monospace;
|
||||||
|
}
|
||||||
|
input:focus, select:focus, textarea:focus {
|
||||||
|
outline: none;
|
||||||
|
border-color: #6cf;
|
||||||
|
}
|
||||||
|
button {
|
||||||
|
background: #6cf;
|
||||||
|
color: #000;
|
||||||
|
border: none;
|
||||||
|
padding: 8px 16px;
|
||||||
|
border-radius: 4px;
|
||||||
|
cursor: pointer;
|
||||||
|
font-family: monospace;
|
||||||
|
}
|
||||||
|
button:hover {
|
||||||
|
background: #5be;
|
||||||
|
}
|
||||||
|
button:disabled {
|
||||||
|
background: #555;
|
||||||
|
color: #888;
|
||||||
|
cursor: not-allowed;
|
||||||
|
}
|
||||||
|
.form-group {
|
||||||
|
margin-bottom: 15px;
|
||||||
|
}
|
||||||
|
.form-group label {
|
||||||
|
display: block;
|
||||||
|
margin-bottom: 5px;
|
||||||
|
color: #aaa;
|
||||||
|
}
|
||||||
|
.form-row {
|
||||||
|
display: flex;
|
||||||
|
gap: 10px;
|
||||||
|
align-items: flex-end;
|
||||||
|
}
|
||||||
|
.result {
|
||||||
|
margin-top: 20px;
|
||||||
|
}
|
||||||
|
.success {
|
||||||
|
color: #5f5;
|
||||||
|
}
|
||||||
|
.error {
|
||||||
|
color: #f55;
|
||||||
|
}
|
||||||
|
{{end}}
|
||||||
19
proxy/internal/debug/templates/client_detail.html
Normal file
19
proxy/internal/debug/templates/client_detail.html
Normal file
@@ -0,0 +1,19 @@
|
|||||||
|
{{define "clientDetail"}}
|
||||||
|
<!DOCTYPE html>
|
||||||
|
<html>
|
||||||
|
<head>
|
||||||
|
<title>Client {{.AccountID}}</title>
|
||||||
|
<style>{{template "style"}}</style>
|
||||||
|
</head>
|
||||||
|
<body>
|
||||||
|
<h1>Client: {{.AccountID}}</h1>
|
||||||
|
<div class="nav">
|
||||||
|
<a href="/debug">← Back</a>
|
||||||
|
<a href="/debug/clients/{{.AccountID}}/tools"{{if eq .ActiveTab "tools"}} class="active"{{end}}>Tools</a>
|
||||||
|
<a href="/debug/clients/{{.AccountID}}"{{if eq .ActiveTab "status"}} class="active"{{end}}>Status</a>
|
||||||
|
<a href="/debug/clients/{{.AccountID}}/syncresponse"{{if eq .ActiveTab "syncresponse"}} class="active"{{end}}>Sync Response</a>
|
||||||
|
</div>
|
||||||
|
<pre>{{.Content}}</pre>
|
||||||
|
</body>
|
||||||
|
</html>
|
||||||
|
{{end}}
|
||||||
33
proxy/internal/debug/templates/clients.html
Normal file
33
proxy/internal/debug/templates/clients.html
Normal file
@@ -0,0 +1,33 @@
|
|||||||
|
{{define "clients"}}
|
||||||
|
<!DOCTYPE html>
|
||||||
|
<html>
|
||||||
|
<head>
|
||||||
|
<title>Clients</title>
|
||||||
|
<style>{{template "style"}}</style>
|
||||||
|
</head>
|
||||||
|
<body>
|
||||||
|
<h1>All Clients</h1>
|
||||||
|
<p class="info">Uptime: {{.Uptime}} | <a href="/debug">← Back</a></p>
|
||||||
|
{{if .Clients}}
|
||||||
|
<table>
|
||||||
|
<tr>
|
||||||
|
<th>Account ID</th>
|
||||||
|
<th>Domains</th>
|
||||||
|
<th>Age</th>
|
||||||
|
<th>Status</th>
|
||||||
|
</tr>
|
||||||
|
{{range .Clients}}
|
||||||
|
<tr>
|
||||||
|
<td><a href="/debug/clients/{{.AccountID}}/tools">{{.AccountID}}</a></td>
|
||||||
|
<td>{{.Domains}}</td>
|
||||||
|
<td>{{.Age}}</td>
|
||||||
|
<td>{{.Status}}</td>
|
||||||
|
</tr>
|
||||||
|
{{end}}
|
||||||
|
</table>
|
||||||
|
{{else}}
|
||||||
|
<p>No clients connected</p>
|
||||||
|
{{end}}
|
||||||
|
</body>
|
||||||
|
</html>
|
||||||
|
{{end}}
|
||||||
14
proxy/internal/debug/templates/health.html
Normal file
14
proxy/internal/debug/templates/health.html
Normal file
@@ -0,0 +1,14 @@
|
|||||||
|
{{define "health"}}
|
||||||
|
<!DOCTYPE html>
|
||||||
|
<html>
|
||||||
|
<head>
|
||||||
|
<title>Health</title>
|
||||||
|
<style>{{template "style"}}</style>
|
||||||
|
</head>
|
||||||
|
<body>
|
||||||
|
<h1>OK</h1>
|
||||||
|
<p>Uptime: {{.Uptime}}</p>
|
||||||
|
<p><a href="/debug">← Back</a></p>
|
||||||
|
</body>
|
||||||
|
</html>
|
||||||
|
{{end}}
|
||||||
40
proxy/internal/debug/templates/index.html
Normal file
40
proxy/internal/debug/templates/index.html
Normal file
@@ -0,0 +1,40 @@
|
|||||||
|
{{define "index"}}
|
||||||
|
<!DOCTYPE html>
|
||||||
|
<html>
|
||||||
|
<head>
|
||||||
|
<title>NetBird Proxy Debug</title>
|
||||||
|
<style>{{template "style"}}</style>
|
||||||
|
</head>
|
||||||
|
<body>
|
||||||
|
<h1>NetBird Proxy Debug</h1>
|
||||||
|
<p class="info">Version: {{.Version}} | Uptime: {{.Uptime}}</p>
|
||||||
|
<h2>Clients ({{.ClientCount}}) | Domains ({{.TotalDomains}})</h2>
|
||||||
|
{{if .Clients}}
|
||||||
|
<table>
|
||||||
|
<tr>
|
||||||
|
<th>Account ID</th>
|
||||||
|
<th>Domains</th>
|
||||||
|
<th>Age</th>
|
||||||
|
<th>Status</th>
|
||||||
|
</tr>
|
||||||
|
{{range .Clients}}
|
||||||
|
<tr>
|
||||||
|
<td><a href="/debug/clients/{{.AccountID}}/tools">{{.AccountID}}</a></td>
|
||||||
|
<td>{{.Domains}}</td>
|
||||||
|
<td>{{.Age}}</td>
|
||||||
|
<td>{{.Status}}</td>
|
||||||
|
</tr>
|
||||||
|
{{end}}
|
||||||
|
</table>
|
||||||
|
{{else}}
|
||||||
|
<p>No clients connected</p>
|
||||||
|
{{end}}
|
||||||
|
<h2>Endpoints</h2>
|
||||||
|
<ul>
|
||||||
|
<li><a href="/debug/clients">/debug/clients</a> - all clients detail</li>
|
||||||
|
<li><a href="/debug/health">/debug/health</a> - health check</li>
|
||||||
|
</ul>
|
||||||
|
<p class="info">Add ?format=json or /json suffix for JSON output</p>
|
||||||
|
</body>
|
||||||
|
</html>
|
||||||
|
{{end}}
|
||||||
142
proxy/internal/debug/templates/tools.html
Normal file
142
proxy/internal/debug/templates/tools.html
Normal file
@@ -0,0 +1,142 @@
|
|||||||
|
{{define "tools"}}
|
||||||
|
<!DOCTYPE html>
|
||||||
|
<html>
|
||||||
|
<head>
|
||||||
|
<title>Client {{.AccountID}} - Tools</title>
|
||||||
|
<style>{{template "style"}}</style>
|
||||||
|
</head>
|
||||||
|
<body>
|
||||||
|
<h1>Client: {{.AccountID}}</h1>
|
||||||
|
<div class="nav">
|
||||||
|
<a href="/debug">← Back</a>
|
||||||
|
<a href="/debug/clients/{{.AccountID}}/tools" class="active">Tools</a>
|
||||||
|
<a href="/debug/clients/{{.AccountID}}">Status</a>
|
||||||
|
<a href="/debug/clients/{{.AccountID}}/syncresponse">Sync Response</a>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<h2>Client Control</h2>
|
||||||
|
<div class="form-row">
|
||||||
|
<div class="form-group">
|
||||||
|
<label> </label>
|
||||||
|
<button onclick="startClient()">Start</button>
|
||||||
|
</div>
|
||||||
|
<div class="form-group">
|
||||||
|
<label> </label>
|
||||||
|
<button onclick="stopClient()">Stop</button>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
<div id="client-result" class="result"></div>
|
||||||
|
|
||||||
|
<h2>Log Level</h2>
|
||||||
|
<div class="form-row">
|
||||||
|
<div class="form-group">
|
||||||
|
<label>Level</label>
|
||||||
|
<select id="log-level" style="width: 120px;">
|
||||||
|
<option value="trace">trace</option>
|
||||||
|
<option value="debug">debug</option>
|
||||||
|
<option value="info">info</option>
|
||||||
|
<option value="warn" selected>warn</option>
|
||||||
|
<option value="error">error</option>
|
||||||
|
</select>
|
||||||
|
</div>
|
||||||
|
<div class="form-group">
|
||||||
|
<label> </label>
|
||||||
|
<button onclick="setLogLevel()">Set Level</button>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
<div id="log-result" class="result"></div>
|
||||||
|
|
||||||
|
<h2>TCP Ping</h2>
|
||||||
|
<div class="form-row">
|
||||||
|
<div class="form-group">
|
||||||
|
<label>Host</label>
|
||||||
|
<input type="text" id="tcp-host" placeholder="100.0.0.1 or hostname.netbird.cloud" style="width: 300px;">
|
||||||
|
</div>
|
||||||
|
<div class="form-group">
|
||||||
|
<label>Port</label>
|
||||||
|
<input type="number" id="tcp-port" placeholder="80" style="width: 80px;">
|
||||||
|
</div>
|
||||||
|
<div class="form-group">
|
||||||
|
<label> </label>
|
||||||
|
<button onclick="doTcpPing()">Connect</button>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
<div id="tcp-result" class="result"></div>
|
||||||
|
|
||||||
|
<script>
|
||||||
|
const accountID = "{{.AccountID}}";
|
||||||
|
|
||||||
|
async function startClient() {
|
||||||
|
const resultDiv = document.getElementById('client-result');
|
||||||
|
resultDiv.innerHTML = '<span class="info">Starting client...</span>';
|
||||||
|
try {
|
||||||
|
const resp = await fetch('/debug/clients/' + accountID + '/start');
|
||||||
|
const data = await resp.json();
|
||||||
|
if (data.success) {
|
||||||
|
resultDiv.innerHTML = '<span class="success">✓ ' + data.message + '</span>';
|
||||||
|
} else {
|
||||||
|
resultDiv.innerHTML = '<span class="error">✗ ' + data.error + '</span>';
|
||||||
|
}
|
||||||
|
} catch (e) {
|
||||||
|
resultDiv.innerHTML = '<span class="error">Error: ' + e.message + '</span>';
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async function stopClient() {
|
||||||
|
const resultDiv = document.getElementById('client-result');
|
||||||
|
resultDiv.innerHTML = '<span class="info">Stopping client...</span>';
|
||||||
|
try {
|
||||||
|
const resp = await fetch('/debug/clients/' + accountID + '/stop');
|
||||||
|
const data = await resp.json();
|
||||||
|
if (data.success) {
|
||||||
|
resultDiv.innerHTML = '<span class="success">✓ ' + data.message + '</span>';
|
||||||
|
} else {
|
||||||
|
resultDiv.innerHTML = '<span class="error">✗ ' + data.error + '</span>';
|
||||||
|
}
|
||||||
|
} catch (e) {
|
||||||
|
resultDiv.innerHTML = '<span class="error">Error: ' + e.message + '</span>';
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async function setLogLevel() {
|
||||||
|
const level = document.getElementById('log-level').value;
|
||||||
|
const resultDiv = document.getElementById('log-result');
|
||||||
|
resultDiv.innerHTML = '<span class="info">Setting log level...</span>';
|
||||||
|
try {
|
||||||
|
const resp = await fetch('/debug/clients/' + accountID + '/loglevel?level=' + level);
|
||||||
|
const data = await resp.json();
|
||||||
|
if (data.success) {
|
||||||
|
resultDiv.innerHTML = '<span class="success">✓ Log level set to: ' + data.level + '</span>';
|
||||||
|
} else {
|
||||||
|
resultDiv.innerHTML = '<span class="error">✗ ' + data.error + '</span>';
|
||||||
|
}
|
||||||
|
} catch (e) {
|
||||||
|
resultDiv.innerHTML = '<span class="error">Error: ' + e.message + '</span>';
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async function doTcpPing() {
|
||||||
|
const host = document.getElementById('tcp-host').value;
|
||||||
|
const port = document.getElementById('tcp-port').value;
|
||||||
|
if (!host || !port) {
|
||||||
|
alert('Host and port required');
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
const resultDiv = document.getElementById('tcp-result');
|
||||||
|
resultDiv.innerHTML = '<span class="info">Connecting...</span>';
|
||||||
|
try {
|
||||||
|
const resp = await fetch('/debug/clients/' + accountID + '/pingtcp?host=' + encodeURIComponent(host) + '&port=' + port);
|
||||||
|
const data = await resp.json();
|
||||||
|
if (data.success) {
|
||||||
|
resultDiv.innerHTML = '<span class="success">✓ ' + data.host + ':' + data.port + ' connected in ' + data.latency + '</span>';
|
||||||
|
} else {
|
||||||
|
resultDiv.innerHTML = '<span class="error">✗ ' + data.host + ':' + data.port + ': ' + data.error + '</span>';
|
||||||
|
}
|
||||||
|
} catch (e) {
|
||||||
|
resultDiv.innerHTML = '<span class="error">Error: ' + e.message + '</span>';
|
||||||
|
}
|
||||||
|
}
|
||||||
|
</script>
|
||||||
|
</body>
|
||||||
|
</html>
|
||||||
|
{{end}}
|
||||||
@@ -3,6 +3,8 @@ package proxy
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/proxy/internal/types"
|
||||||
)
|
)
|
||||||
|
|
||||||
type requestContextKey string
|
type requestContextKey string
|
||||||
@@ -18,7 +20,7 @@ const (
|
|||||||
type CapturedData struct {
|
type CapturedData struct {
|
||||||
mu sync.RWMutex
|
mu sync.RWMutex
|
||||||
ServiceId string
|
ServiceId string
|
||||||
AccountId string
|
AccountId types.AccountID
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetServiceId safely sets the service ID
|
// SetServiceId safely sets the service ID
|
||||||
@@ -36,14 +38,14 @@ func (c *CapturedData) GetServiceId() string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// SetAccountId safely sets the account ID
|
// SetAccountId safely sets the account ID
|
||||||
func (c *CapturedData) SetAccountId(accountId string) {
|
func (c *CapturedData) SetAccountId(accountId types.AccountID) {
|
||||||
c.mu.Lock()
|
c.mu.Lock()
|
||||||
defer c.mu.Unlock()
|
defer c.mu.Unlock()
|
||||||
c.AccountId = accountId
|
c.AccountId = accountId
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetAccountId safely gets the account ID
|
// GetAccountId safely gets the account ID
|
||||||
func (c *CapturedData) GetAccountId() string {
|
func (c *CapturedData) GetAccountId() types.AccountID {
|
||||||
c.mu.RLock()
|
c.mu.RLock()
|
||||||
defer c.mu.RUnlock()
|
defer c.mu.RUnlock()
|
||||||
return c.AccountId
|
return c.AccountId
|
||||||
@@ -76,13 +78,13 @@ func ServiceIdFromContext(ctx context.Context) string {
|
|||||||
}
|
}
|
||||||
return serviceId
|
return serviceId
|
||||||
}
|
}
|
||||||
func withAccountId(ctx context.Context, accountId string) context.Context {
|
func withAccountId(ctx context.Context, accountId types.AccountID) context.Context {
|
||||||
return context.WithValue(ctx, accountIdKey, accountId)
|
return context.WithValue(ctx, accountIdKey, accountId)
|
||||||
}
|
}
|
||||||
|
|
||||||
func AccountIdFromContext(ctx context.Context) string {
|
func AccountIdFromContext(ctx context.Context) types.AccountID {
|
||||||
v := ctx.Value(accountIdKey)
|
v := ctx.Value(accountIdKey)
|
||||||
accountId, ok := v.(string)
|
accountId, ok := v.(types.AccountID)
|
||||||
if !ok {
|
if !ok {
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -4,6 +4,8 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httputil"
|
"net/http/httputil"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/proxy/internal/roundtrip"
|
||||||
)
|
)
|
||||||
|
|
||||||
type ReverseProxy struct {
|
type ReverseProxy struct {
|
||||||
@@ -36,8 +38,10 @@ func (p *ReverseProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|||||||
|
|
||||||
// Set the serviceId in the context for later retrieval.
|
// Set the serviceId in the context for later retrieval.
|
||||||
ctx := withServiceId(r.Context(), serviceId)
|
ctx := withServiceId(r.Context(), serviceId)
|
||||||
// Set the accountId in the context for later retrieval.
|
// Set the accountId in the context for later retrieval (for middleware).
|
||||||
ctx = withAccountId(ctx, accountID)
|
ctx = withAccountId(ctx, accountID)
|
||||||
|
// Set the accountId in the context for the roundtripper to use.
|
||||||
|
ctx = roundtrip.WithAccountID(ctx, accountID)
|
||||||
|
|
||||||
// Also populate captured data if it exists (allows middleware to read after handler completes).
|
// Also populate captured data if it exists (allows middleware to read after handler completes).
|
||||||
// This solves the problem of passing data UP the middleware chain: we put a mutable struct
|
// This solves the problem of passing data UP the middleware chain: we put a mutable struct
|
||||||
|
|||||||
@@ -6,16 +6,18 @@ import (
|
|||||||
"net/url"
|
"net/url"
|
||||||
"sort"
|
"sort"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/proxy/internal/types"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Mapping struct {
|
type Mapping struct {
|
||||||
ID string
|
ID string
|
||||||
AccountID string
|
AccountID types.AccountID
|
||||||
Host string
|
Host string
|
||||||
Paths map[string]*url.URL
|
Paths map[string]*url.URL
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *ReverseProxy) findTargetForRequest(req *http.Request) (*url.URL, string, string, bool) {
|
func (p *ReverseProxy) findTargetForRequest(req *http.Request) (*url.URL, string, types.AccountID, bool) {
|
||||||
p.mappingsMux.RLock()
|
p.mappingsMux.RLock()
|
||||||
if p.mappings == nil {
|
if p.mappings == nil {
|
||||||
p.mappingsMux.RUnlock()
|
p.mappingsMux.RUnlock()
|
||||||
@@ -27,10 +29,12 @@ func (p *ReverseProxy) findTargetForRequest(req *http.Request) (*url.URL, string
|
|||||||
}
|
}
|
||||||
defer p.mappingsMux.RUnlock()
|
defer p.mappingsMux.RUnlock()
|
||||||
|
|
||||||
host, _, err := net.SplitHostPort(req.Host)
|
// Strip port from host if present (e.g., "external.test:8443" -> "external.test")
|
||||||
if err != nil {
|
host := req.Host
|
||||||
host = req.Host
|
if h, _, err := net.SplitHostPort(host); err == nil {
|
||||||
|
host = h
|
||||||
}
|
}
|
||||||
|
|
||||||
m, exists := p.mappings[host]
|
m, exists := p.mappings[host]
|
||||||
if !exists {
|
if !exists {
|
||||||
return nil, "", "", false
|
return nil, "", "", false
|
||||||
|
|||||||
@@ -4,18 +4,39 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
|
||||||
"net"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/hashicorp/go-multierror"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
"golang.org/x/exp/maps"
|
||||||
|
|
||||||
|
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||||
"github.com/netbirdio/netbird/client/embed"
|
"github.com/netbirdio/netbird/client/embed"
|
||||||
|
"github.com/netbirdio/netbird/proxy/internal/types"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/domain"
|
||||||
|
"github.com/netbirdio/netbird/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
const deviceNamePrefix = "ingress-"
|
const deviceNamePrefix = "ingress-proxy-"
|
||||||
|
|
||||||
|
// ErrNoAccountID is returned when a request context is missing the account ID.
|
||||||
|
var ErrNoAccountID = errors.New("no account ID in request context")
|
||||||
|
|
||||||
|
// domainInfo holds metadata about a registered domain.
|
||||||
|
type domainInfo struct {
|
||||||
|
reverseProxyID string
|
||||||
|
}
|
||||||
|
|
||||||
|
// clientEntry holds an embedded NetBird client and tracks which domains use it.
|
||||||
|
type clientEntry struct {
|
||||||
|
client *embed.Client
|
||||||
|
transport *http.Transport
|
||||||
|
domains map[domain.Domain]domainInfo
|
||||||
|
createdAt time.Time
|
||||||
|
started bool
|
||||||
|
}
|
||||||
|
|
||||||
type statusNotifier interface {
|
type statusNotifier interface {
|
||||||
NotifyStatus(ctx context.Context, accountID, reverseProxyID, domain string, connected bool) error
|
NotifyStatus(ctx context.Context, accountID, reverseProxyID, domain string, connected bool) error
|
||||||
@@ -23,147 +44,389 @@ type statusNotifier interface {
|
|||||||
|
|
||||||
// NetBird provides an http.RoundTripper implementation
|
// NetBird provides an http.RoundTripper implementation
|
||||||
// backed by underlying NetBird connections.
|
// backed by underlying NetBird connections.
|
||||||
|
// Clients are keyed by AccountID, allowing multiple domains to share the same connection.
|
||||||
type NetBird struct {
|
type NetBird struct {
|
||||||
mgmtAddr string
|
mgmtAddr string
|
||||||
|
proxyID string
|
||||||
logger *log.Logger
|
logger *log.Logger
|
||||||
|
|
||||||
clientsMux sync.RWMutex
|
clientsMux sync.RWMutex
|
||||||
clients map[string]*embed.Client
|
clients map[types.AccountID]*clientEntry
|
||||||
|
initLogOnce sync.Once
|
||||||
statusNotifier statusNotifier
|
statusNotifier statusNotifier
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewNetBird(mgmtAddr string, logger *log.Logger, notifier statusNotifier) *NetBird {
|
// NewNetBird creates a new NetBird transport.
|
||||||
|
func NewNetBird(mgmtAddr, proxyID string, logger *log.Logger, notifier statusNotifier) *NetBird {
|
||||||
if logger == nil {
|
if logger == nil {
|
||||||
logger = log.StandardLogger()
|
logger = log.StandardLogger()
|
||||||
}
|
}
|
||||||
return &NetBird{
|
return &NetBird{
|
||||||
mgmtAddr: mgmtAddr,
|
mgmtAddr: mgmtAddr,
|
||||||
|
proxyID: proxyID,
|
||||||
logger: logger,
|
logger: logger,
|
||||||
clients: make(map[string]*embed.Client),
|
clients: make(map[types.AccountID]*clientEntry),
|
||||||
statusNotifier: notifier,
|
statusNotifier: notifier,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (n *NetBird) AddPeer(ctx context.Context, domain, key, accountID, reverseProxyID string) error {
|
// AddPeer registers a domain for an account. If the account doesn't have a client yet,
|
||||||
|
// one is created using the provided setup key. Multiple domains can share the same client.
|
||||||
|
func (n *NetBird) AddPeer(ctx context.Context, accountID types.AccountID, d domain.Domain, key, reverseProxyID string) error {
|
||||||
|
n.clientsMux.Lock()
|
||||||
|
|
||||||
|
entry, exists := n.clients[accountID]
|
||||||
|
if exists {
|
||||||
|
// Client already exists for this account, just register the domain
|
||||||
|
entry.domains[d] = domainInfo{reverseProxyID: reverseProxyID}
|
||||||
|
started := entry.started
|
||||||
|
n.clientsMux.Unlock()
|
||||||
|
|
||||||
|
n.logger.WithFields(log.Fields{
|
||||||
|
"account_id": accountID,
|
||||||
|
"domain": d,
|
||||||
|
}).Debug("registered domain with existing client")
|
||||||
|
|
||||||
|
// If client is already started, notify this domain as connected immediately
|
||||||
|
if started && n.statusNotifier != nil {
|
||||||
|
if err := n.statusNotifier.NotifyStatus(ctx, string(accountID), reverseProxyID, string(d), true); err != nil {
|
||||||
|
n.logger.WithFields(log.Fields{
|
||||||
|
"account_id": accountID,
|
||||||
|
"domain": d,
|
||||||
|
}).WithError(err).Warn("failed to notify status for existing client")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
n.initLogOnce.Do(func() {
|
||||||
|
if err := util.InitLog(log.WarnLevel.String(), util.LogConsole); err != nil {
|
||||||
|
n.logger.WithField("account_id", accountID).Warnf("failed to initialize embedded client logging: %v", err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
wgPort := 0
|
||||||
client, err := embed.New(embed.Options{
|
client, err := embed.New(embed.Options{
|
||||||
DeviceName: deviceNamePrefix + domain,
|
DeviceName: deviceNamePrefix + n.proxyID,
|
||||||
ManagementURL: n.mgmtAddr,
|
ManagementURL: n.mgmtAddr,
|
||||||
SetupKey: key,
|
SetupKey: key,
|
||||||
LogOutput: io.Discard,
|
LogLevel: log.WarnLevel.String(),
|
||||||
BlockInbound: true,
|
BlockInbound: true,
|
||||||
|
WireguardPort: &wgPort,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
n.clientsMux.Unlock()
|
||||||
return fmt.Errorf("create netbird client: %w", err)
|
return fmt.Errorf("create netbird client: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Create a transport using the client dialer. We do this instead of using
|
||||||
|
// the client's HTTPClient to avoid issues with request validation that do
|
||||||
|
// not work with reverse proxied requests.
|
||||||
|
entry = &clientEntry{
|
||||||
|
client: client,
|
||||||
|
domains: map[domain.Domain]domainInfo{d: {reverseProxyID: reverseProxyID}},
|
||||||
|
transport: &http.Transport{
|
||||||
|
DialContext: client.DialContext,
|
||||||
|
ForceAttemptHTTP2: true,
|
||||||
|
MaxIdleConns: 100,
|
||||||
|
IdleConnTimeout: 90 * time.Second,
|
||||||
|
TLSHandshakeTimeout: 10 * time.Second,
|
||||||
|
ExpectContinueTimeout: 1 * time.Second,
|
||||||
|
},
|
||||||
|
createdAt: time.Now(),
|
||||||
|
started: false,
|
||||||
|
}
|
||||||
|
n.clients[accountID] = entry
|
||||||
|
n.clientsMux.Unlock()
|
||||||
|
|
||||||
|
n.logger.WithFields(log.Fields{
|
||||||
|
"account_id": accountID,
|
||||||
|
"domain": d,
|
||||||
|
}).Info("created new client for account")
|
||||||
|
|
||||||
// Attempt to start the client in the background, if this fails
|
// Attempt to start the client in the background, if this fails
|
||||||
// then it is not ideal, but it isn't the end of the world because
|
// then it is not ideal, but it isn't the end of the world because
|
||||||
// we will try to start the client again before we use it.
|
// we will try to start the client again before we use it.
|
||||||
go func() {
|
go func() {
|
||||||
startCtx, cancel := context.WithTimeout(ctx, 3*time.Second)
|
startCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
err = client.Start(startCtx)
|
|
||||||
switch {
|
if err := client.Start(startCtx); err != nil {
|
||||||
case errors.Is(err, context.DeadlineExceeded):
|
if errors.Is(err, context.DeadlineExceeded) {
|
||||||
n.logger.Debug("netbird client timed out")
|
n.logger.WithFields(log.Fields{
|
||||||
// This is not ideal, but we will try again later.
|
"account_id": accountID,
|
||||||
return
|
}).Debug("netbird client start timed out, will retry on first request")
|
||||||
case err != nil:
|
} else {
|
||||||
n.logger.WithField("domain", domain).WithError(err).Error("Unable to start netbird client, will try again later.")
|
n.logger.WithFields(log.Fields{
|
||||||
|
"account_id": accountID,
|
||||||
|
}).WithError(err).Error("failed to start netbird client")
|
||||||
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Notify management that tunnel is now active
|
// Mark client as started and notify all registered domains
|
||||||
|
n.clientsMux.Lock()
|
||||||
|
entry, exists := n.clients[accountID]
|
||||||
|
if exists {
|
||||||
|
entry.started = true
|
||||||
|
}
|
||||||
|
// Copy domain info while holding lock
|
||||||
|
var domainsToNotify []struct {
|
||||||
|
domain domain.Domain
|
||||||
|
reverseProxyID string
|
||||||
|
}
|
||||||
|
if exists {
|
||||||
|
for dom, info := range entry.domains {
|
||||||
|
domainsToNotify = append(domainsToNotify, struct {
|
||||||
|
domain domain.Domain
|
||||||
|
reverseProxyID string
|
||||||
|
}{domain: dom, reverseProxyID: info.reverseProxyID})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
n.clientsMux.Unlock()
|
||||||
|
|
||||||
|
// Notify all domains that they're connected
|
||||||
if n.statusNotifier != nil {
|
if n.statusNotifier != nil {
|
||||||
if err := n.statusNotifier.NotifyStatus(ctx, accountID, reverseProxyID, domain, true); err != nil {
|
for _, domInfo := range domainsToNotify {
|
||||||
n.logger.WithField("domain", domain).WithError(err).Warn("Failed to notify management about tunnel connection")
|
if err := n.statusNotifier.NotifyStatus(ctx, string(accountID), domInfo.reverseProxyID, string(domInfo.domain), true); err != nil {
|
||||||
} else {
|
n.logger.WithFields(log.Fields{
|
||||||
n.logger.WithField("domain", domain).Info("Successfully notified management about tunnel connection")
|
"account_id": accountID,
|
||||||
|
"domain": domInfo.domain,
|
||||||
|
}).WithError(err).Warn("failed to notify tunnel connection status")
|
||||||
|
} else {
|
||||||
|
n.logger.WithFields(log.Fields{
|
||||||
|
"account_id": accountID,
|
||||||
|
"domain": domInfo.domain,
|
||||||
|
}).Info("notified management about tunnel connection")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
n.clientsMux.Lock()
|
|
||||||
defer n.clientsMux.Unlock()
|
|
||||||
n.clients[domain] = client
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (n *NetBird) RemovePeer(ctx context.Context, domain, accountID, reverseProxyID string) error {
|
// RemovePeer unregisters a domain from an account. The client is only stopped
|
||||||
n.clientsMux.RLock()
|
// when no domains are using it anymore.
|
||||||
client, exists := n.clients[domain]
|
func (n *NetBird) RemovePeer(ctx context.Context, accountID types.AccountID, d domain.Domain) error {
|
||||||
n.clientsMux.RUnlock()
|
n.clientsMux.Lock()
|
||||||
|
|
||||||
|
entry, exists := n.clients[accountID]
|
||||||
if !exists {
|
if !exists {
|
||||||
// Mission failed successfully!
|
n.clientsMux.Unlock()
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
if err := client.Stop(ctx); err != nil {
|
|
||||||
return fmt.Errorf("stop netbird client: %w", err)
|
// Get domain info before deleting
|
||||||
|
domInfo, domainExists := entry.domains[d]
|
||||||
|
if !domainExists {
|
||||||
|
n.clientsMux.Unlock()
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Notify management that tunnel is disconnected
|
delete(entry.domains, d)
|
||||||
|
|
||||||
|
// If there are still domains using this client, keep it running
|
||||||
|
if len(entry.domains) > 0 {
|
||||||
|
n.clientsMux.Unlock()
|
||||||
|
|
||||||
|
n.logger.WithFields(log.Fields{
|
||||||
|
"account_id": accountID,
|
||||||
|
"domain": d,
|
||||||
|
"remaining_domains": len(entry.domains),
|
||||||
|
}).Debug("unregistered domain, client still in use")
|
||||||
|
|
||||||
|
// Notify this domain as disconnected
|
||||||
|
if n.statusNotifier != nil {
|
||||||
|
if err := n.statusNotifier.NotifyStatus(ctx, string(accountID), domInfo.reverseProxyID, string(d), false); err != nil {
|
||||||
|
n.logger.WithFields(log.Fields{
|
||||||
|
"account_id": accountID,
|
||||||
|
"domain": d,
|
||||||
|
}).WithError(err).Warn("failed to notify tunnel disconnection status")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// No more domains using this client, stop it
|
||||||
|
n.logger.WithFields(log.Fields{
|
||||||
|
"account_id": accountID,
|
||||||
|
}).Info("stopping client, no more domains")
|
||||||
|
|
||||||
|
client := entry.client
|
||||||
|
transport := entry.transport
|
||||||
|
delete(n.clients, accountID)
|
||||||
|
n.clientsMux.Unlock()
|
||||||
|
|
||||||
|
// Notify disconnection before stopping
|
||||||
if n.statusNotifier != nil {
|
if n.statusNotifier != nil {
|
||||||
if err := n.statusNotifier.NotifyStatus(ctx, accountID, reverseProxyID, domain, false); err != nil {
|
if err := n.statusNotifier.NotifyStatus(ctx, string(accountID), domInfo.reverseProxyID, string(d), false); err != nil {
|
||||||
n.logger.WithField("domain", domain).WithError(err).Warn("Failed to notify management about tunnel disconnection")
|
n.logger.WithFields(log.Fields{
|
||||||
} else {
|
"account_id": accountID,
|
||||||
n.logger.WithField("domain", domain).Info("Successfully notified management about tunnel disconnection")
|
"domain": d,
|
||||||
|
}).WithError(err).Warn("failed to notify tunnel disconnection status")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
n.clientsMux.Lock()
|
transport.CloseIdleConnections()
|
||||||
defer n.clientsMux.Unlock()
|
|
||||||
delete(n.clients, domain)
|
if err := client.Stop(ctx); err != nil {
|
||||||
|
n.logger.WithFields(log.Fields{
|
||||||
|
"account_id": accountID,
|
||||||
|
}).WithError(err).Warn("failed to stop netbird client")
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RoundTrip implements http.RoundTripper. It looks up the client for the account
|
||||||
|
// specified in the request context and uses it to dial the backend.
|
||||||
func (n *NetBird) RoundTrip(req *http.Request) (*http.Response, error) {
|
func (n *NetBird) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||||
host, _, err := net.SplitHostPort(req.Host)
|
accountID := AccountIDFromContext(req.Context())
|
||||||
if err != nil {
|
if accountID == "" {
|
||||||
host = req.Host
|
return nil, ErrNoAccountID
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Copy references while holding lock, then unlock early to avoid blocking
|
||||||
|
// other requests during the potentially slow RoundTrip.
|
||||||
n.clientsMux.RLock()
|
n.clientsMux.RLock()
|
||||||
client, exists := n.clients[host]
|
entry, exists := n.clients[accountID]
|
||||||
// Immediately unlock after retrieval here rather than defer to avoid
|
|
||||||
// the call to client.Do blocking other clients being used whilst one
|
|
||||||
// is in use.
|
|
||||||
n.clientsMux.RUnlock()
|
|
||||||
if !exists {
|
if !exists {
|
||||||
return nil, fmt.Errorf("no peer connection found for host: %s", host)
|
n.clientsMux.RUnlock()
|
||||||
|
return nil, fmt.Errorf("no peer connection found for account: %s", accountID)
|
||||||
}
|
}
|
||||||
|
client := entry.client
|
||||||
|
transport := entry.transport
|
||||||
|
n.clientsMux.RUnlock()
|
||||||
|
|
||||||
// Attempt to start the client, if the client is already running then
|
// Attempt to start the client, if the client is already running then
|
||||||
// it will return an error that we ignore, if this hits a timeout then
|
// it will return an error that we ignore, if this hits a timeout then
|
||||||
// this request is unprocessable.
|
// this request is unprocessable.
|
||||||
startCtx, cancel := context.WithTimeout(req.Context(), 3*time.Second)
|
startCtx, cancel := context.WithTimeout(req.Context(), 10*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
err = client.Start(startCtx)
|
if err := client.Start(startCtx); err != nil {
|
||||||
switch {
|
if !errors.Is(err, embed.ErrClientAlreadyStarted) {
|
||||||
case errors.Is(err, embed.ErrClientAlreadyStarted):
|
return nil, fmt.Errorf("start netbird client: %w", err)
|
||||||
break
|
}
|
||||||
case err != nil:
|
|
||||||
return nil, fmt.Errorf("start netbird client: %w", err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
n.logger.WithFields(log.Fields{
|
n.logger.WithFields(log.Fields{
|
||||||
"host": host,
|
"account_id": accountID,
|
||||||
|
"host": req.Host,
|
||||||
"url": req.URL.String(),
|
"url": req.URL.String(),
|
||||||
"requestURI": req.RequestURI,
|
"requestURI": req.RequestURI,
|
||||||
"method": req.Method,
|
"method": req.Method,
|
||||||
}).Debug("running roundtrip for peer connection")
|
}).Debug("running roundtrip for peer connection")
|
||||||
|
|
||||||
// Create a new transport using the client dialer and perform the roundtrip.
|
return transport.RoundTrip(req)
|
||||||
// We do this instead of using the client HTTPClient to avoid issues around
|
}
|
||||||
// client request validation that do not work with the reverse proxied
|
|
||||||
// requests.
|
// StopAll stops all clients.
|
||||||
// Other values are simply copied from the http.DefaultTransport which the
|
func (n *NetBird) StopAll(ctx context.Context) error {
|
||||||
// standard reverse proxy implementation would have used.
|
n.clientsMux.Lock()
|
||||||
// TODO: tune this transport for our needs.
|
defer n.clientsMux.Unlock()
|
||||||
return (&http.Transport{
|
|
||||||
DialContext: client.DialContext,
|
var merr *multierror.Error
|
||||||
MaxIdleConns: 100,
|
for accountID, entry := range n.clients {
|
||||||
IdleConnTimeout: 90 * time.Second,
|
entry.transport.CloseIdleConnections()
|
||||||
TLSHandshakeTimeout: 10 * time.Second,
|
if err := entry.client.Stop(ctx); err != nil {
|
||||||
ExpectContinueTimeout: 1 * time.Second,
|
n.logger.WithFields(log.Fields{
|
||||||
}).RoundTrip(req)
|
"account_id": accountID,
|
||||||
|
}).WithError(err).Warn("failed to stop netbird client during shutdown")
|
||||||
|
merr = multierror.Append(merr, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
maps.Clear(n.clients)
|
||||||
|
|
||||||
|
return nberrors.FormatErrorOrNil(merr)
|
||||||
|
}
|
||||||
|
|
||||||
|
// HasClient returns true if there is a client for the given account.
|
||||||
|
func (n *NetBird) HasClient(accountID types.AccountID) bool {
|
||||||
|
n.clientsMux.RLock()
|
||||||
|
defer n.clientsMux.RUnlock()
|
||||||
|
_, exists := n.clients[accountID]
|
||||||
|
return exists
|
||||||
|
}
|
||||||
|
|
||||||
|
// DomainCount returns the number of domains registered for the given account.
|
||||||
|
// Returns 0 if the account has no client.
|
||||||
|
func (n *NetBird) DomainCount(accountID types.AccountID) int {
|
||||||
|
n.clientsMux.RLock()
|
||||||
|
defer n.clientsMux.RUnlock()
|
||||||
|
entry, exists := n.clients[accountID]
|
||||||
|
if !exists {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
return len(entry.domains)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClientCount returns the total number of active clients.
|
||||||
|
func (n *NetBird) ClientCount() int {
|
||||||
|
n.clientsMux.RLock()
|
||||||
|
defer n.clientsMux.RUnlock()
|
||||||
|
return len(n.clients)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetClient returns the embed.Client for the given account ID.
|
||||||
|
func (n *NetBird) GetClient(accountID types.AccountID) (*embed.Client, bool) {
|
||||||
|
n.clientsMux.RLock()
|
||||||
|
defer n.clientsMux.RUnlock()
|
||||||
|
entry, exists := n.clients[accountID]
|
||||||
|
if !exists {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
return entry.client, true
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClientDebugInfo contains debug information about a client.
|
||||||
|
type ClientDebugInfo struct {
|
||||||
|
AccountID types.AccountID
|
||||||
|
DomainCount int
|
||||||
|
Domains domain.List
|
||||||
|
HasClient bool
|
||||||
|
CreatedAt time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListClientsForDebug returns information about all clients for debug purposes.
|
||||||
|
func (n *NetBird) ListClientsForDebug() map[types.AccountID]ClientDebugInfo {
|
||||||
|
n.clientsMux.RLock()
|
||||||
|
defer n.clientsMux.RUnlock()
|
||||||
|
|
||||||
|
result := make(map[types.AccountID]ClientDebugInfo)
|
||||||
|
for accountID, entry := range n.clients {
|
||||||
|
domains := make(domain.List, 0, len(entry.domains))
|
||||||
|
for d := range entry.domains {
|
||||||
|
domains = append(domains, d)
|
||||||
|
}
|
||||||
|
result[accountID] = ClientDebugInfo{
|
||||||
|
AccountID: accountID,
|
||||||
|
DomainCount: len(entry.domains),
|
||||||
|
Domains: domains,
|
||||||
|
HasClient: entry.client != nil,
|
||||||
|
CreatedAt: entry.createdAt,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
// accountIDContextKey is the context key for storing the account ID.
|
||||||
|
type accountIDContextKey struct{}
|
||||||
|
|
||||||
|
// WithAccountID adds the account ID to the context.
|
||||||
|
func WithAccountID(ctx context.Context, accountID types.AccountID) context.Context {
|
||||||
|
return context.WithValue(ctx, accountIDContextKey{}, accountID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// AccountIDFromContext retrieves the account ID from the context.
|
||||||
|
func AccountIDFromContext(ctx context.Context) types.AccountID {
|
||||||
|
v := ctx.Value(accountIDContextKey{})
|
||||||
|
if v == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
accountID, ok := v.(types.AccountID)
|
||||||
|
if !ok {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return accountID
|
||||||
}
|
}
|
||||||
|
|||||||
247
proxy/internal/roundtrip/netbird_test.go
Normal file
247
proxy/internal/roundtrip/netbird_test.go
Normal file
@@ -0,0 +1,247 @@
|
|||||||
|
package roundtrip
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net/http"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/proxy/internal/types"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/domain"
|
||||||
|
)
|
||||||
|
|
||||||
|
// mockNetBird creates a NetBird instance for testing without actually connecting.
|
||||||
|
// It uses an invalid management URL to prevent real connections.
|
||||||
|
func mockNetBird() *NetBird {
|
||||||
|
return NewNetBird("http://invalid.test:9999", "test-proxy", nil, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNetBird_AddPeer_CreatesClientForNewAccount(t *testing.T) {
|
||||||
|
nb := mockNetBird()
|
||||||
|
accountID := types.AccountID("account-1")
|
||||||
|
|
||||||
|
// Initially no client exists.
|
||||||
|
assert.False(t, nb.HasClient(accountID), "should not have client before AddPeer")
|
||||||
|
assert.Equal(t, 0, nb.DomainCount(accountID), "domain count should be 0")
|
||||||
|
|
||||||
|
// Add first domain - this should create a new client.
|
||||||
|
// Note: This will fail to actually connect since we use an invalid URL,
|
||||||
|
// but the client entry should still be created.
|
||||||
|
err := nb.AddPeer(context.Background(), accountID, domain.Domain("domain1.test"), "setup-key-1", "proxy-1")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
assert.True(t, nb.HasClient(accountID), "should have client after AddPeer")
|
||||||
|
assert.Equal(t, 1, nb.DomainCount(accountID), "domain count should be 1")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNetBird_AddPeer_ReuseClientForSameAccount(t *testing.T) {
|
||||||
|
nb := mockNetBird()
|
||||||
|
accountID := types.AccountID("account-1")
|
||||||
|
|
||||||
|
// Add first domain.
|
||||||
|
err := nb.AddPeer(context.Background(), accountID, domain.Domain("domain1.test"), "setup-key-1", "proxy-1")
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, 1, nb.DomainCount(accountID))
|
||||||
|
|
||||||
|
// Add second domain for the same account - should reuse existing client.
|
||||||
|
err = nb.AddPeer(context.Background(), accountID, domain.Domain("domain2.test"), "setup-key-1", "proxy-2")
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, 2, nb.DomainCount(accountID), "domain count should be 2 after adding second domain")
|
||||||
|
|
||||||
|
// Add third domain.
|
||||||
|
err = nb.AddPeer(context.Background(), accountID, domain.Domain("domain3.test"), "setup-key-1", "proxy-3")
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, 3, nb.DomainCount(accountID), "domain count should be 3 after adding third domain")
|
||||||
|
|
||||||
|
// Still only one client.
|
||||||
|
assert.True(t, nb.HasClient(accountID))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNetBird_AddPeer_SeparateClientsForDifferentAccounts(t *testing.T) {
|
||||||
|
nb := mockNetBird()
|
||||||
|
account1 := types.AccountID("account-1")
|
||||||
|
account2 := types.AccountID("account-2")
|
||||||
|
|
||||||
|
// Add domain for account 1.
|
||||||
|
err := nb.AddPeer(context.Background(), account1, domain.Domain("domain1.test"), "setup-key-1", "proxy-1")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Add domain for account 2.
|
||||||
|
err = nb.AddPeer(context.Background(), account2, domain.Domain("domain2.test"), "setup-key-2", "proxy-2")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Both accounts should have their own clients.
|
||||||
|
assert.True(t, nb.HasClient(account1), "account1 should have client")
|
||||||
|
assert.True(t, nb.HasClient(account2), "account2 should have client")
|
||||||
|
assert.Equal(t, 1, nb.DomainCount(account1), "account1 domain count should be 1")
|
||||||
|
assert.Equal(t, 1, nb.DomainCount(account2), "account2 domain count should be 1")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNetBird_RemovePeer_KeepsClientWhenDomainsRemain(t *testing.T) {
|
||||||
|
nb := mockNetBird()
|
||||||
|
accountID := types.AccountID("account-1")
|
||||||
|
|
||||||
|
// Add multiple domains.
|
||||||
|
err := nb.AddPeer(context.Background(), accountID, domain.Domain("domain1.test"), "setup-key-1", "proxy-1")
|
||||||
|
require.NoError(t, err)
|
||||||
|
err = nb.AddPeer(context.Background(), accountID, domain.Domain("domain2.test"), "setup-key-1", "proxy-2")
|
||||||
|
require.NoError(t, err)
|
||||||
|
err = nb.AddPeer(context.Background(), accountID, domain.Domain("domain3.test"), "setup-key-1", "proxy-3")
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, 3, nb.DomainCount(accountID))
|
||||||
|
|
||||||
|
// Remove one domain - client should remain.
|
||||||
|
err = nb.RemovePeer(context.Background(), accountID, "domain1.test")
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.True(t, nb.HasClient(accountID), "client should remain after removing one domain")
|
||||||
|
assert.Equal(t, 2, nb.DomainCount(accountID), "domain count should be 2")
|
||||||
|
|
||||||
|
// Remove another domain - client should still remain.
|
||||||
|
err = nb.RemovePeer(context.Background(), accountID, "domain2.test")
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.True(t, nb.HasClient(accountID), "client should remain after removing second domain")
|
||||||
|
assert.Equal(t, 1, nb.DomainCount(accountID), "domain count should be 1")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNetBird_RemovePeer_RemovesClientWhenLastDomainRemoved(t *testing.T) {
|
||||||
|
nb := mockNetBird()
|
||||||
|
accountID := types.AccountID("account-1")
|
||||||
|
|
||||||
|
// Add single domain.
|
||||||
|
err := nb.AddPeer(context.Background(), accountID, domain.Domain("domain1.test"), "setup-key-1", "proxy-1")
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.True(t, nb.HasClient(accountID))
|
||||||
|
|
||||||
|
// Remove the only domain - client should be removed.
|
||||||
|
// Note: Stop() may fail since the client never actually connected,
|
||||||
|
// but the entry should still be removed from the map.
|
||||||
|
_ = nb.RemovePeer(context.Background(), accountID, "domain1.test")
|
||||||
|
|
||||||
|
// After removing all domains, client should be gone.
|
||||||
|
assert.False(t, nb.HasClient(accountID), "client should be removed after removing last domain")
|
||||||
|
assert.Equal(t, 0, nb.DomainCount(accountID), "domain count should be 0")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNetBird_RemovePeer_NonExistentAccountIsNoop(t *testing.T) {
|
||||||
|
nb := mockNetBird()
|
||||||
|
accountID := types.AccountID("nonexistent-account")
|
||||||
|
|
||||||
|
// Removing from non-existent account should not error.
|
||||||
|
err := nb.RemovePeer(context.Background(), accountID, "domain1.test")
|
||||||
|
assert.NoError(t, err, "removing from non-existent account should not error")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNetBird_RemovePeer_NonExistentDomainIsNoop(t *testing.T) {
|
||||||
|
nb := mockNetBird()
|
||||||
|
accountID := types.AccountID("account-1")
|
||||||
|
|
||||||
|
// Add one domain.
|
||||||
|
err := nb.AddPeer(context.Background(), accountID, domain.Domain("domain1.test"), "setup-key-1", "proxy-1")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Remove non-existent domain - should not affect existing domain.
|
||||||
|
err = nb.RemovePeer(context.Background(), accountID, domain.Domain("nonexistent.test"))
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Original domain should still be registered.
|
||||||
|
assert.True(t, nb.HasClient(accountID))
|
||||||
|
assert.Equal(t, 1, nb.DomainCount(accountID), "original domain should remain")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWithAccountID_AndAccountIDFromContext(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
accountID := types.AccountID("test-account")
|
||||||
|
|
||||||
|
// Initially no account ID in context.
|
||||||
|
retrieved := AccountIDFromContext(ctx)
|
||||||
|
assert.True(t, retrieved == "", "should be empty when not set")
|
||||||
|
|
||||||
|
// Add account ID to context.
|
||||||
|
ctx = WithAccountID(ctx, accountID)
|
||||||
|
retrieved = AccountIDFromContext(ctx)
|
||||||
|
assert.Equal(t, accountID, retrieved, "should retrieve the same account ID")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAccountIDFromContext_ReturnsEmptyForWrongType(t *testing.T) {
|
||||||
|
// Create context with wrong type for account ID key.
|
||||||
|
ctx := context.WithValue(context.Background(), accountIDContextKey{}, "wrong-type-string")
|
||||||
|
|
||||||
|
retrieved := AccountIDFromContext(ctx)
|
||||||
|
assert.True(t, retrieved == "", "should return empty for wrong type")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNetBird_StopAll_StopsAllClients(t *testing.T) {
|
||||||
|
nb := mockNetBird()
|
||||||
|
account1 := types.AccountID("account-1")
|
||||||
|
account2 := types.AccountID("account-2")
|
||||||
|
account3 := types.AccountID("account-3")
|
||||||
|
|
||||||
|
// Add domains for multiple accounts.
|
||||||
|
err := nb.AddPeer(context.Background(), account1, domain.Domain("domain1.test"), "key-1", "proxy-1")
|
||||||
|
require.NoError(t, err)
|
||||||
|
err = nb.AddPeer(context.Background(), account2, domain.Domain("domain2.test"), "key-2", "proxy-2")
|
||||||
|
require.NoError(t, err)
|
||||||
|
err = nb.AddPeer(context.Background(), account3, domain.Domain("domain3.test"), "key-3", "proxy-3")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
assert.Equal(t, 3, nb.ClientCount(), "should have 3 clients")
|
||||||
|
|
||||||
|
// Stop all clients.
|
||||||
|
// Note: StopAll may return errors since clients never actually connected,
|
||||||
|
// but the clients should still be removed from the map.
|
||||||
|
_ = nb.StopAll(context.Background())
|
||||||
|
|
||||||
|
assert.Equal(t, 0, nb.ClientCount(), "should have 0 clients after StopAll")
|
||||||
|
assert.False(t, nb.HasClient(account1), "account1 should not have client")
|
||||||
|
assert.False(t, nb.HasClient(account2), "account2 should not have client")
|
||||||
|
assert.False(t, nb.HasClient(account3), "account3 should not have client")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNetBird_ClientCount(t *testing.T) {
|
||||||
|
nb := mockNetBird()
|
||||||
|
|
||||||
|
assert.Equal(t, 0, nb.ClientCount(), "should start with 0 clients")
|
||||||
|
|
||||||
|
// Add clients for different accounts.
|
||||||
|
err := nb.AddPeer(context.Background(), types.AccountID("account-1"), domain.Domain("domain1.test"), "key-1", "proxy-1")
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, 1, nb.ClientCount())
|
||||||
|
|
||||||
|
err = nb.AddPeer(context.Background(), types.AccountID("account-2"), domain.Domain("domain2.test"), "key-2", "proxy-2")
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, 2, nb.ClientCount())
|
||||||
|
|
||||||
|
// Adding domain to existing account should not increase count.
|
||||||
|
err = nb.AddPeer(context.Background(), types.AccountID("account-1"), domain.Domain("domain1b.test"), "key-1", "proxy-1b")
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, 2, nb.ClientCount(), "adding domain to existing account should not increase client count")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNetBird_RoundTrip_RequiresAccountIDInContext(t *testing.T) {
|
||||||
|
nb := mockNetBird()
|
||||||
|
|
||||||
|
// Create a request without account ID in context.
|
||||||
|
req, err := http.NewRequest("GET", "http://example.com/", nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// RoundTrip should fail because no account ID in context.
|
||||||
|
_, err = nb.RoundTrip(req)
|
||||||
|
require.ErrorIs(t, err, ErrNoAccountID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNetBird_RoundTrip_RequiresExistingClient(t *testing.T) {
|
||||||
|
nb := mockNetBird()
|
||||||
|
accountID := types.AccountID("nonexistent-account")
|
||||||
|
|
||||||
|
// Create a request with account ID but no client exists.
|
||||||
|
req, err := http.NewRequest("GET", "http://example.com/", nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
req = req.WithContext(WithAccountID(req.Context(), accountID))
|
||||||
|
|
||||||
|
// RoundTrip should fail because no client for this account.
|
||||||
|
_, err = nb.RoundTrip(req)
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "no peer connection found for account")
|
||||||
|
}
|
||||||
5
proxy/internal/types/types.go
Normal file
5
proxy/internal/types/types.go
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
// Package types defines common types used across the proxy package.
|
||||||
|
package types
|
||||||
|
|
||||||
|
// AccountID represents a unique identifier for a NetBird account.
|
||||||
|
type AccountID string
|
||||||
@@ -31,8 +31,11 @@ import (
|
|||||||
"github.com/netbirdio/netbird/proxy/internal/accesslog"
|
"github.com/netbirdio/netbird/proxy/internal/accesslog"
|
||||||
"github.com/netbirdio/netbird/proxy/internal/acme"
|
"github.com/netbirdio/netbird/proxy/internal/acme"
|
||||||
"github.com/netbirdio/netbird/proxy/internal/auth"
|
"github.com/netbirdio/netbird/proxy/internal/auth"
|
||||||
|
"github.com/netbirdio/netbird/proxy/internal/debug"
|
||||||
"github.com/netbirdio/netbird/proxy/internal/proxy"
|
"github.com/netbirdio/netbird/proxy/internal/proxy"
|
||||||
"github.com/netbirdio/netbird/proxy/internal/roundtrip"
|
"github.com/netbirdio/netbird/proxy/internal/roundtrip"
|
||||||
|
"github.com/netbirdio/netbird/proxy/internal/types"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/domain"
|
||||||
"github.com/netbirdio/netbird/shared/management/proto"
|
"github.com/netbirdio/netbird/shared/management/proto"
|
||||||
"github.com/netbirdio/netbird/util/embeddedroots"
|
"github.com/netbirdio/netbird/util/embeddedroots"
|
||||||
)
|
)
|
||||||
@@ -45,6 +48,7 @@ type Server struct {
|
|||||||
auth *auth.Middleware
|
auth *auth.Middleware
|
||||||
http *http.Server
|
http *http.Server
|
||||||
https *http.Server
|
https *http.Server
|
||||||
|
debug *http.Server
|
||||||
|
|
||||||
// Mostly used for debugging on management.
|
// Mostly used for debugging on management.
|
||||||
startTime time.Time
|
startTime time.Time
|
||||||
@@ -62,6 +66,11 @@ type Server struct {
|
|||||||
OIDCClientSecret string
|
OIDCClientSecret string
|
||||||
OIDCEndpoint string
|
OIDCEndpoint string
|
||||||
OIDCScopes []string
|
OIDCScopes []string
|
||||||
|
|
||||||
|
// DebugEndpointEnabled enables the debug HTTP endpoint.
|
||||||
|
DebugEndpointEnabled bool
|
||||||
|
// DebugEndpointAddress is the address for the debug HTTP endpoint (default: ":8444").
|
||||||
|
DebugEndpointAddress string
|
||||||
}
|
}
|
||||||
|
|
||||||
// NotifyStatus sends a status update to management about tunnel connectivity
|
// NotifyStatus sends a status update to management about tunnel connectivity
|
||||||
@@ -148,7 +157,7 @@ func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) {
|
|||||||
|
|
||||||
// Initialize the netbird client, this is required to build peer connections
|
// Initialize the netbird client, this is required to build peer connections
|
||||||
// to proxy over.
|
// to proxy over.
|
||||||
s.netbird = roundtrip.NewNetBird(s.ManagementAddress, s.Logger, s)
|
s.netbird = roundtrip.NewNetBird(s.ManagementAddress, s.ID, s.Logger, s)
|
||||||
|
|
||||||
// When generating ACME certificates, start a challenge server.
|
// When generating ACME certificates, start a challenge server.
|
||||||
tlsConfig := &tls.Config{}
|
tlsConfig := &tls.Config{}
|
||||||
@@ -204,6 +213,34 @@ func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) {
|
|||||||
// Configure Access logs to management server.
|
// Configure Access logs to management server.
|
||||||
accessLog := accesslog.NewLogger(s.mgmtClient, s.Logger)
|
accessLog := accesslog.NewLogger(s.mgmtClient, s.Logger)
|
||||||
|
|
||||||
|
if s.DebugEndpointEnabled {
|
||||||
|
debugAddr := debugEndpointAddr(s.DebugEndpointAddress)
|
||||||
|
debugHandler := debug.NewHandler(s.netbird, s.Logger)
|
||||||
|
s.debug = &http.Server{
|
||||||
|
Addr: debugAddr,
|
||||||
|
Handler: debugHandler,
|
||||||
|
}
|
||||||
|
go func() {
|
||||||
|
s.Logger.WithField("address", debugAddr).Info("starting debug endpoint")
|
||||||
|
if err := s.debug.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
|
||||||
|
s.Logger.Errorf("debug endpoint error: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
defer func() {
|
||||||
|
if err := s.debug.Close(); err != nil {
|
||||||
|
s.Logger.Debugf("debug endpoint close: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
stopCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
if err := s.netbird.StopAll(stopCtx); err != nil {
|
||||||
|
s.Logger.Warnf("failed to stop all netbird clients: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
// Finally, start the reverse proxy.
|
// Finally, start the reverse proxy.
|
||||||
s.https = &http.Server{
|
s.https = &http.Server{
|
||||||
Addr: addr,
|
Addr: addr,
|
||||||
@@ -306,15 +343,15 @@ func (s *Server) handleMappingStream(ctx context.Context, mappingClient proto.Pr
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) addMapping(ctx context.Context, mapping *proto.ProxyMapping) error {
|
func (s *Server) addMapping(ctx context.Context, mapping *proto.ProxyMapping) error {
|
||||||
domain := mapping.GetDomain()
|
d := domain.Domain(mapping.GetDomain())
|
||||||
accountID := mapping.GetAccountId()
|
accountID := types.AccountID(mapping.GetAccountId())
|
||||||
reverseProxyID := mapping.GetId()
|
reverseProxyID := mapping.GetId()
|
||||||
|
|
||||||
if err := s.netbird.AddPeer(ctx, domain, mapping.GetSetupKey(), accountID, reverseProxyID); err != nil {
|
if err := s.netbird.AddPeer(ctx, accountID, d, mapping.GetSetupKey(), reverseProxyID); err != nil {
|
||||||
return fmt.Errorf("create peer for domain %q: %w", domain, err)
|
return fmt.Errorf("create peer for domain %q: %w", d, err)
|
||||||
}
|
}
|
||||||
if s.acme != nil {
|
if s.acme != nil {
|
||||||
s.acme.AddDomain(domain, accountID, reverseProxyID)
|
s.acme.AddDomain(string(d), string(accountID), reverseProxyID)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Pass the mapping through to the update function to avoid duplicating the
|
// Pass the mapping through to the update function to avoid duplicating the
|
||||||
@@ -360,10 +397,13 @@ func (s *Server) updateMapping(ctx context.Context, mapping *proto.ProxyMapping)
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) removeMapping(ctx context.Context, mapping *proto.ProxyMapping) {
|
func (s *Server) removeMapping(ctx context.Context, mapping *proto.ProxyMapping) {
|
||||||
if err := s.netbird.RemovePeer(ctx, mapping.GetDomain(), mapping.GetAccountId(), mapping.GetId()); err != nil {
|
d := domain.Domain(mapping.GetDomain())
|
||||||
|
accountID := types.AccountID(mapping.GetAccountId())
|
||||||
|
if err := s.netbird.RemovePeer(ctx, accountID, d); err != nil {
|
||||||
s.Logger.WithFields(log.Fields{
|
s.Logger.WithFields(log.Fields{
|
||||||
"domain": mapping.GetDomain(),
|
"account_id": accountID,
|
||||||
"error": err,
|
"domain": d,
|
||||||
|
"error": err,
|
||||||
}).Error("Error removing NetBird peer connection for domain, continuing additional domain cleanup but peer connection may still exist")
|
}).Error("Error removing NetBird peer connection for domain, continuing additional domain cleanup but peer connection may still exist")
|
||||||
}
|
}
|
||||||
if s.acme != nil {
|
if s.acme != nil {
|
||||||
@@ -392,8 +432,17 @@ func (s *Server) protoToMapping(mapping *proto.ProxyMapping) proxy.Mapping {
|
|||||||
}
|
}
|
||||||
return proxy.Mapping{
|
return proxy.Mapping{
|
||||||
ID: mapping.GetId(),
|
ID: mapping.GetId(),
|
||||||
AccountID: mapping.AccountId,
|
AccountID: types.AccountID(mapping.GetAccountId()),
|
||||||
Host: mapping.GetDomain(),
|
Host: mapping.GetDomain(),
|
||||||
Paths: paths,
|
Paths: paths,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// debugEndpointAddr returns the address for the debug endpoint.
|
||||||
|
// If addr is empty, it defaults to localhost:8444 for security.
|
||||||
|
func debugEndpointAddr(addr string) string {
|
||||||
|
if addr == "" {
|
||||||
|
return "localhost:8444"
|
||||||
|
}
|
||||||
|
return addr
|
||||||
|
}
|
||||||
|
|||||||
48
proxy/server_test.go
Normal file
48
proxy/server_test.go
Normal file
@@ -0,0 +1,48 @@
|
|||||||
|
package proxy
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestDebugEndpointDisabledByDefault(t *testing.T) {
|
||||||
|
s := &Server{}
|
||||||
|
assert.False(t, s.DebugEndpointEnabled, "debug endpoint should be disabled by default")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDebugEndpointAddr(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input string
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "empty defaults to localhost",
|
||||||
|
input: "",
|
||||||
|
expected: "localhost:8444",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "explicit localhost preserved",
|
||||||
|
input: "localhost:9999",
|
||||||
|
expected: "localhost:9999",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "explicit address preserved",
|
||||||
|
input: "0.0.0.0:8444",
|
||||||
|
expected: "0.0.0.0:8444",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "127.0.0.1 preserved",
|
||||||
|
input: "127.0.0.1:8444",
|
||||||
|
expected: "127.0.0.1:8444",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range tests {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
result := debugEndpointAddr(tc.input)
|
||||||
|
assert.Equal(t, tc.expected, result)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user