diff --git a/proxy/cmd/root.go b/proxy/cmd/root.go index 4a02e3003..c982eb6ab 100644 --- a/proxy/cmd/root.go +++ b/proxy/cmd/root.go @@ -28,7 +28,6 @@ var ( func init() { rootCmd.PersistentFlags().StringVarP(&configFile, "config", "c", "", "path to JSON configuration file (optional, can use env vars instead)") - // Set version information rootCmd.Version = version.Short() rootCmd.SetVersionTemplate("{{.Version}}\n") } @@ -39,14 +38,12 @@ func Execute() error { } func run(cmd *cobra.Command, args []string) error { - // Load configuration from file or environment variables config, err := proxy.LoadFromFileOrEnv(configFile) if err != nil { log.Fatalf("Failed to load configuration: %v", err) return err } - // Set log level setupLogging(config.LogLevel) log.Infof("Starting Netbird Proxy - %s", version.Short()) @@ -55,14 +52,12 @@ func run(cmd *cobra.Command, args []string) error { log.Infof("Listen Address: %s", config.ReverseProxy.ListenAddress) log.Infof("Log Level: %s", config.LogLevel) - // Create server instance server, err := proxy.NewServer(config) if err != nil { log.Fatalf("Failed to create server: %v", err) return err } - // Start server in a goroutine serverErrors := make(chan error, 1) go func() { if err := server.Start(); err != nil { @@ -70,11 +65,9 @@ func run(cmd *cobra.Command, args []string) error { } }() - // Set up signal handler for graceful shutdown quit := make(chan os.Signal, 1) signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM) - // Wait for either an error or shutdown signal select { case err := <-serverErrors: log.Fatalf("Server error: %v", err) @@ -83,11 +76,9 @@ func run(cmd *cobra.Command, args []string) error { log.Infof("Received signal: %v", sig) } - // Create shutdown context with timeout ctx, cancel := context.WithTimeout(context.Background(), config.ShutdownTimeout) defer cancel() - // Gracefully stop the server if err := server.Stop(ctx); err != nil { log.Fatalf("Failed to stop server gracefully: %v", err) return err @@ -98,13 +89,11 @@ func run(cmd *cobra.Command, args []string) error { } func setupLogging(level string) { - // Set log format log.SetFormatter(&log.TextFormatter{ FullTimestamp: true, TimestampFormat: "2006-01-02 15:04:05", }) - // Set log level switch level { case "debug": log.SetLevel(log.DebugLevel) diff --git a/proxy/internal/auth/constants.go b/proxy/internal/auth/constants.go deleted file mode 100644 index 3c83d7c77..000000000 --- a/proxy/internal/auth/constants.go +++ /dev/null @@ -1,9 +0,0 @@ -package auth - -const ( - // DefaultSessionCookieName is the default cookie name for session storage - DefaultSessionCookieName = "auth_session" - - // ErrorInternalServer is the default internal server error message - ErrorInternalServer = "Internal Server Error" -) diff --git a/proxy/internal/auth/methods/basic_auth.go b/proxy/internal/auth/methods/basic_auth.go index cdee99c51..1f7f7f78a 100644 --- a/proxy/internal/auth/methods/basic_auth.go +++ b/proxy/internal/auth/methods/basic_auth.go @@ -18,7 +18,6 @@ func (c *BasicAuthConfig) Validate(r *http.Request) bool { return false } - // Use constant-time comparison to prevent timing attacks usernameMatch := subtle.ConstantTimeCompare([]byte(username), []byte(c.Username)) == 1 passwordMatch := subtle.ConstantTimeCompare([]byte(password), []byte(c.Password)) == 1 diff --git a/proxy/internal/auth/methods/bearer_auth.go b/proxy/internal/auth/methods/bearer_auth.go index 723e72f7f..c13b8d725 100644 --- a/proxy/internal/auth/methods/bearer_auth.go +++ b/proxy/internal/auth/methods/bearer_auth.go @@ -4,7 +4,5 @@ package methods // The actual OIDC/JWT configuration comes from the global proxy Config.OIDCConfig // This just enables Bearer auth for a specific route type BearerConfig struct { - // Enable bearer token authentication for this route - // Uses the global OIDC configuration from proxy Config Enabled bool } diff --git a/proxy/internal/auth/methods/pin_auth.go b/proxy/internal/auth/methods/pin_auth.go index 21ce4b4f0..e860005fc 100644 --- a/proxy/internal/auth/methods/pin_auth.go +++ b/proxy/internal/auth/methods/pin_auth.go @@ -13,7 +13,7 @@ const ( // PINConfig holds PIN authentication settings type PINConfig struct { PIN string - Header string // Header name (default: "X-PIN") + Header string } // Validate checks PIN from the request header @@ -28,6 +28,5 @@ func (c *PINConfig) Validate(r *http.Request) bool { return false } - // Use constant-time comparison to prevent timing attacks return subtle.ConstantTimeCompare([]byte(providedPIN), []byte(c.PIN)) == 1 } diff --git a/proxy/internal/auth/middleware.go b/proxy/internal/auth/middleware.go index 13ba20276..1127dc67a 100644 --- a/proxy/internal/auth/middleware.go +++ b/proxy/internal/auth/middleware.go @@ -28,26 +28,22 @@ type authResult struct { // ServeHTTP implements the http.Handler interface func (m *Middleware) ServeHTTP(w http.ResponseWriter, r *http.Request) { - // If no auth configured, allow request if m.config.IsEmpty() { m.allowWithoutAuth(w, r) return } - // Try to authenticate the request result := m.authenticate(w, r) if result == nil { // Authentication triggered a redirect (e.g., OIDC flow) return } - // Reject if authentication failed if !result.authenticated { m.rejectRequest(w, r) return } - // Authentication successful - continue to next handler m.continueWithAuth(w, r, result) } @@ -65,17 +61,14 @@ func (m *Middleware) allowWithoutAuth(w http.ResponseWriter, r *http.Request) { // authenticate attempts to authenticate the request using configured methods // Returns nil if a redirect occurred (e.g., OIDC flow initiated) func (m *Middleware) authenticate(w http.ResponseWriter, r *http.Request) *authResult { - // Try Basic Auth if result := m.tryBasicAuth(r); result.authenticated { return result } - // Try PIN Auth if result := m.tryPINAuth(r); result.authenticated { return result } - // Try Bearer/OIDC Auth return m.tryBearerAuth(w, r) } @@ -94,7 +87,6 @@ func (m *Middleware) tryBasicAuth(r *http.Request) *authResult { method: "basic", } - // Extract username from Basic Auth if username, _, ok := r.BasicAuth(); ok { result.userID = username } @@ -128,24 +120,20 @@ func (m *Middleware) tryBearerAuth(w http.ResponseWriter, r *http.Request) *auth cookieName := m.oidcHandler.SessionCookieName() - // Handle auth token in query parameter (from OIDC callback) if m.handleAuthTokenParameter(w, r, cookieName) { - return nil // Redirect occurred + return nil } - // Try session cookie if result := m.trySessionCookie(r, cookieName); result.authenticated { return result } - // Try Authorization header if result := m.tryAuthorizationHeader(r); result.authenticated { return result } - // No valid auth - redirect to OIDC provider m.oidcHandler.RedirectToProvider(w, r, m.routeID) - return nil // Redirect occurred + return nil } // handleAuthTokenParameter processes the _auth_token query parameter from OIDC callback @@ -161,7 +149,6 @@ func (m *Middleware) handleAuthTokenParameter(w http.ResponseWriter, r *http.Req "host": r.Host, }).Info("Found auth token in query parameter, setting cookie and redirecting") - // Validate the token before setting cookie if !m.oidcHandler.ValidateJWT(authToken) { log.WithFields(log.Fields{ "route_id": m.routeID, @@ -169,7 +156,6 @@ func (m *Middleware) handleAuthTokenParameter(w http.ResponseWriter, r *http.Req return false } - // Set session cookie cookie := &http.Cookie{ Name: cookieName, Value: authToken, @@ -288,7 +274,7 @@ func (m *Middleware) continueWithAuth(w http.ResponseWriter, r *http.Request, re "path": r.URL.Path, }).Debug("Authentication successful") - // Store auth info in headers for logging + // TODO: Find other means of auth logging than headers r.Header.Set("X-Auth-Method", result.method) r.Header.Set("X-Auth-User-ID", result.userID) @@ -299,7 +285,7 @@ func (m *Middleware) continueWithAuth(w http.ResponseWriter, r *http.Request, re // Wrap wraps an HTTP handler with authentication middleware func Wrap(next http.Handler, authConfig *Config, routeID string, rejectResponse func(w http.ResponseWriter, r *http.Request), oidcHandler *oidc.Handler) http.Handler { if authConfig == nil { - authConfig = &Config{} // Empty config = no auth + authConfig = &Config{} } return &Middleware{ diff --git a/proxy/internal/auth/oidc/config.go b/proxy/internal/auth/oidc/config.go index 0060d7d3f..2e79b1402 100644 --- a/proxy/internal/auth/oidc/config.go +++ b/proxy/internal/auth/oidc/config.go @@ -2,19 +2,16 @@ package oidc // Config holds the global OIDC/OAuth configuration type Config struct { - // OIDC Provider settings - ProviderURL string `env:"NB_OIDC_PROVIDER_URL" json:"provider_url"` // Identity provider URL (e.g., "https://accounts.google.com") - ClientID string `env:"NB_OIDC_CLIENT_ID" json:"client_id"` // OAuth client ID - ClientSecret string `env:"NB_OIDC_CLIENT_SECRET" json:"client_secret"` // OAuth client secret (empty for public clients) - RedirectURL string `env:"NB_OIDC_REDIRECT_URL" json:"redirect_url"` // Redirect URL after auth (e.g., "http://localhost:54321/auth/callback") - Scopes []string `env:"NB_OIDC_SCOPES" json:"scopes"` // Requested scopes (default: ["openid", "profile", "email"]) + ProviderURL string `env:"NB_OIDC_PROVIDER_URL" json:"provider_url"` + ClientID string `env:"NB_OIDC_CLIENT_ID" json:"client_id"` + ClientSecret string `env:"NB_OIDC_CLIENT_SECRET" json:"client_secret"` + RedirectURL string `env:"NB_OIDC_REDIRECT_URL" json:"redirect_url"` + Scopes []string `env:"NB_OIDC_SCOPES" json:"scopes"` - // JWT Validation settings - JWTKeysLocation string `env:"NB_OIDC_JWT_KEYS_LOCATION" json:"jwt_keys_location"` // JWKS URL for fetching public keys - JWTIssuer string `env:"NB_OIDC_JWT_ISSUER" json:"jwt_issuer"` // Expected issuer claim - JWTAudience []string `env:"NB_OIDC_JWT_AUDIENCE" json:"jwt_audience"` // Expected audience claims - JWTIdpSignkeyRefreshEnabled bool `env:"NB_OIDC_JWT_IDP_SIGNKEY_REFRESH_ENABLED" json:"jwt_idp_signkey_refresh_enabled"` // Enable automatic refresh of signing keys + JWTKeysLocation string `env:"NB_OIDC_JWT_KEYS_LOCATION" json:"jwt_keys_location"` + JWTIssuer string `env:"NB_OIDC_JWT_ISSUER" json:"jwt_issuer"` + JWTAudience []string `env:"NB_OIDC_JWT_AUDIENCE" json:"jwt_audience"` + JWTIdpSignkeyRefreshEnabled bool `env:"NB_OIDC_JWT_IDP_SIGNKEY_REFRESH_ENABLED" json:"jwt_idp_signkey_refresh_enabled"` - // Session settings - SessionCookieName string `env:"NB_OIDC_SESSION_COOKIE_NAME" json:"session_cookie_name"` // Cookie name for storing session (default: "auth_session") + SessionCookieName string `env:"NB_OIDC_SESSION_COOKIE_NAME" json:"session_cookie_name"` } diff --git a/proxy/internal/auth/oidc/handler.go b/proxy/internal/auth/oidc/handler.go index 6689bcaf5..4896aceda 100644 --- a/proxy/internal/auth/oidc/handler.go +++ b/proxy/internal/auth/oidc/handler.go @@ -24,7 +24,6 @@ type Handler struct { // NewHandler creates a new OIDC handler func NewHandler(config *Config, stateStore *StateStore) *Handler { - // Initialize JWT validator var jwtValidator *jwt.Validator if config.JWTKeysLocation != "" { jwtValidator = jwt.NewValidator( @@ -67,7 +66,6 @@ func (h *Handler) RedirectToProvider(w http.ResponseWriter, r *http.Request, rou scopes = []string{"openid", "profile", "email"} } - // Build authorization URL authURL, err := url.Parse(h.config.ProviderURL) if err != nil { log.WithError(err).Error("Invalid OIDC provider URL") @@ -80,7 +78,6 @@ func (h *Handler) RedirectToProvider(w http.ResponseWriter, r *http.Request, rou authURL.Path = strings.TrimSuffix(authURL.Path, "/") + "/authorize" } - // Build query parameters params := url.Values{} params.Set("client_id", h.config.ClientID) params.Set("redirect_uri", h.config.RedirectURL) @@ -88,8 +85,6 @@ func (h *Handler) RedirectToProvider(w http.ResponseWriter, r *http.Request, rou params.Set("scope", strings.Join(scopes, " ")) params.Set("state", state) - // Add audience parameter to get an access token for the API - // This ensures we get a proper JWT for the API audience, not just an ID token if len(h.config.JWTAudience) > 0 && h.config.JWTAudience[0] != h.config.ClientID { params.Set("audience", h.config.JWTAudience[0]) } @@ -103,7 +98,6 @@ func (h *Handler) RedirectToProvider(w http.ResponseWriter, r *http.Request, rou "state": state, }).Info("Redirecting to OIDC provider for authentication") - // Redirect user to identity provider login page http.Redirect(w, r, authURL.String(), http.StatusFound) } diff --git a/proxy/internal/auth/oidc/state_store.go b/proxy/internal/auth/oidc/state_store.go index 6420c16a4..6d6c4a4d5 100644 --- a/proxy/internal/auth/oidc/state_store.go +++ b/proxy/internal/auth/oidc/state_store.go @@ -34,7 +34,6 @@ func (s *StateStore) Store(stateToken, originalURL, routeID string) { RouteID: routeID, } - // Clean up expired states s.cleanup() } diff --git a/proxy/internal/reverseproxy/certmanager/letsencrypt.go b/proxy/internal/reverseproxy/certmanager/letsencrypt.go index ed75379cf..ac36c0707 100644 --- a/proxy/internal/reverseproxy/certmanager/letsencrypt.go +++ b/proxy/internal/reverseproxy/certmanager/letsencrypt.go @@ -39,7 +39,7 @@ func NewLetsEncrypt(config LetsEncryptConfig) *LetsEncryptManager { HostPolicy: m.hostPolicy, Cache: autocert.DirCache(config.CertCacheDir), Email: config.Email, - RenewBefore: 0, // Use default + RenewBefore: 0, // Use default 30 days prior to expiration } log.Info("Let's Encrypt certificate manager initialized") @@ -71,8 +71,6 @@ func (m *LetsEncryptManager) RemoveDomain(domain string) { func (m *LetsEncryptManager) IssueCertificate(ctx context.Context, domain string) error { log.Infof("Issuing Let's Encrypt certificate for domain: %s", domain) - // Use GetCertificate to trigger certificate issuance - // This will go through the ACME challenge flow hello := &tls.ClientHelloInfo{ ServerName: domain, } diff --git a/proxy/internal/reverseproxy/certmanager/selfsigned_manager.go b/proxy/internal/reverseproxy/certmanager/selfsigned_manager.go index 6b1894743..6a20a9bbe 100644 --- a/proxy/internal/reverseproxy/certmanager/selfsigned_manager.go +++ b/proxy/internal/reverseproxy/certmanager/selfsigned_manager.go @@ -54,19 +54,16 @@ func (m *SelfSignedManager) IssueCertificate(ctx context.Context, domain string) m.mu.Lock() defer m.mu.Unlock() - // Check if we already have a certificate for this domain if _, exists := m.certificates[domain]; exists { log.Debugf("Self-signed certificate already exists for domain: %s", domain) return nil } - // Generate self-signed certificate cert, err := m.generateCertificate(domain) if err != nil { return err } - // Cache the certificate m.certificates[domain] = cert return nil @@ -94,7 +91,6 @@ func (m *SelfSignedManager) getCertificate(hello *tls.ClientHelloInfo) (*tls.Cer return cert, nil } - // Generate certificate on-demand if not cached log.Infof("Generating self-signed certificate on-demand for: %s", hello.ServerName) newCert, err := m.generateCertificate(hello.ServerName) @@ -102,7 +98,6 @@ func (m *SelfSignedManager) getCertificate(hello *tls.ClientHelloInfo) (*tls.Cer return nil, err } - // Cache it m.mu.Lock() m.certificates[hello.ServerName] = newCert m.mu.Unlock() @@ -112,13 +107,11 @@ func (m *SelfSignedManager) getCertificate(hello *tls.ClientHelloInfo) (*tls.Cer // generateCertificate generates a self-signed certificate for a domain func (m *SelfSignedManager) generateCertificate(domain string) (*tls.Certificate, error) { - // Generate private key priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) if err != nil { return nil, fmt.Errorf("failed to generate private key: %w", err) } - // Create certificate template notBefore := time.Now() notAfter := notBefore.Add(365 * 24 * time.Hour) // Valid for 1 year @@ -141,13 +134,11 @@ func (m *SelfSignedManager) generateCertificate(domain string) (*tls.Certificate DNSNames: []string{domain}, } - // Create self-signed certificate certDER, err := x509.CreateCertificate(rand.Reader, &template, &template, &priv.PublicKey, priv) if err != nil { return nil, fmt.Errorf("failed to create certificate: %w", err) } - // Parse certificate cert, err := x509.ParseCertificate(certDER) if err != nil { return nil, fmt.Errorf("failed to parse certificate: %w", err) diff --git a/proxy/internal/reverseproxy/handler.go b/proxy/internal/reverseproxy/handler.go index be572365b..16eda9e64 100644 --- a/proxy/internal/reverseproxy/handler.go +++ b/proxy/internal/reverseproxy/handler.go @@ -53,18 +53,13 @@ func (p *Proxy) handleProxyRequest(w http.ResponseWriter, r *http.Request) { host = host[:idx] } - // Get auth info from headers set by auth middleware + // TODO: extract logging data authMechanism := r.Header.Get("X-Auth-Method") if authMechanism == "" { authMechanism = "none" } - userID := r.Header.Get("X-Auth-User-ID") - - // Determine auth success based on status code authSuccess := rw.statusCode != http.StatusUnauthorized && rw.statusCode != http.StatusForbidden - - // Extract source IP directly sourceIP := extractSourceIP(r) data := RequestData{ @@ -89,29 +84,22 @@ func (p *Proxy) findRoute(host, path string) *routeEntry { p.mu.RLock() defer p.mu.RUnlock() - // Strip port from host if idx := strings.LastIndex(host, ":"); idx != -1 { host = host[:idx] } - // O(1) lookup by host routeConfig, exists := p.routes[host] if !exists { return nil } - // Build list of route entries sorted by path specificity var entries []*routeEntry - // Create entries for each path mapping for routePath, target := range routeConfig.PathMappings { proxy := p.createProxy(routeConfig, target) - // ALWAYS wrap proxy with auth middleware (even if no auth configured) - // This ensures consistent auth handling and logging handler := auth.Wrap(proxy, routeConfig.AuthConfig, routeConfig.ID, routeConfig.AuthRejectResponse, p.oidcHandler) - // Log auth configuration if routeConfig.AuthConfig != nil && !routeConfig.AuthConfig.IsEmpty() { var authType string if routeConfig.AuthConfig.BasicAuth != nil { @@ -169,11 +157,9 @@ func (p *Proxy) findRoute(host, path string) *routeEntry { // createProxy creates a reverse proxy for a target with the route's connection func (p *Proxy) createProxy(routeConfig *RouteConfig, target string) *httputil.ReverseProxy { - // Parse target URL targetURL, err := url.Parse("http://" + target) if err != nil { log.Errorf("Failed to parse target URL %s: %v", target, err) - // Return a proxy that returns 502 return &httputil.ReverseProxy{ Director: func(req *http.Request) {}, ErrorHandler: func(w http.ResponseWriter, r *http.Request, err error) { @@ -182,21 +168,18 @@ func (p *Proxy) createProxy(routeConfig *RouteConfig, target string) *httputil.R } } - // Create reverse proxy proxy := httputil.NewSingleHostReverseProxy(targetURL) - // Configure transport to use the provided connection (WireGuard, etc.) proxy.Transport = &http.Transport{ DialContext: routeConfig.nbClient.DialContext, MaxIdleConns: 1, MaxIdleConnsPerHost: 1, - IdleConnTimeout: 0, // Keep alive indefinitely + IdleConnTimeout: 0, DisableKeepAlives: false, TLSHandshakeTimeout: 10 * time.Second, ExpectContinueTimeout: 1 * time.Second, } - // Custom error handler proxy.ErrorHandler = func(w http.ResponseWriter, r *http.Request, err error) { log.Errorf("Proxy error for %s%s: %v", r.Host, r.URL.Path, err) http.Error(w, "Bad Gateway", http.StatusBadGateway) @@ -207,14 +190,12 @@ func (p *Proxy) createProxy(routeConfig *RouteConfig, target string) *httputil.R // handleOIDCCallback handles the global /auth/callback endpoint for all routes func (p *Proxy) handleOIDCCallback(w http.ResponseWriter, r *http.Request) { - // Check if OIDC handler is available if p.oidcHandler == nil { log.Error("OIDC callback received but no OIDC handler configured") http.Error(w, "Authentication not configured", http.StatusInternalServerError) return } - // Use the OIDC handler's callback method handler := p.oidcHandler.HandleCallback() handler(w, r) } diff --git a/proxy/internal/reverseproxy/proxy.go b/proxy/internal/reverseproxy/proxy.go index d5d9ed20a..b07536358 100644 --- a/proxy/internal/reverseproxy/proxy.go +++ b/proxy/internal/reverseproxy/proxy.go @@ -27,7 +27,6 @@ type Proxy struct { // New creates a new reverse proxy func New(config Config) (*Proxy, error) { - // Set defaults if config.ListenAddress == "" { config.ListenAddress = ":443" } @@ -38,22 +37,18 @@ func New(config Config) (*Proxy, error) { config.CertCacheDir = "./certs" } - // Set default cert mode if config.CertMode == "" { config.CertMode = "letsencrypt" } - // Validate config based on cert mode if config.CertMode == "letsencrypt" && config.TLSEmail == "" { return nil, fmt.Errorf("TLSEmail is required for letsencrypt mode") } - // Set default OIDC session cookie name if not provided if config.OIDCConfig != nil && config.OIDCConfig.SessionCookieName == "" { config.OIDCConfig.SessionCookieName = "auth_session" } - // Initialize certificate manager based on mode var certMgr certmanager.Manager if config.CertMode == "selfsigned" { // HTTPS with self-signed certificates (for local testing) @@ -73,8 +68,6 @@ func New(config Config) (*Proxy, error) { isRunning: false, } - // Initialize OIDC handler if OIDC is configured - // The handler internally creates and manages its own state store if config.OIDCConfig != nil { stateStore := oidc.NewStateStore() p.oidcHandler = oidc.NewHandler(config.OIDCConfig, stateStore) @@ -93,28 +86,25 @@ func (p *Proxy) Start() error { p.isRunning = true p.mu.Unlock() - // Build the main HTTP handler handler := p.buildHandler() return p.startHTTPS(handler) } -// startHTTPS starts the proxy with HTTPS (non-blocking) +// startHTTPS starts the proxy with HTTPS func (p *Proxy) startHTTPS(handler http.Handler) error { - // Start HTTP server for ACME challenges (Let's Encrypt HTTP-01) p.httpServer = &http.Server{ Addr: p.config.HTTPListenAddress, Handler: p.certManager.HTTPHandler(nil), } go func() { - log.Infof("Starting HTTP server for ACME challenges on %s", p.config.HTTPListenAddress) + log.Infof("Starting HTTP server on %s", p.config.HTTPListenAddress) if err := p.httpServer.ListenAndServe(); err != nil && err != http.ErrServerClosed { log.Errorf("HTTP server error: %v", err) } }() - // Start HTTPS server in background p.server = &http.Server{ Addr: p.config.ListenAddress, Handler: handler, @@ -143,14 +133,12 @@ func (p *Proxy) Stop(ctx context.Context) error { log.Info("Stopping reverse proxy server...") - // Stop HTTP server (for ACME challenges) if p.httpServer != nil { if err := p.httpServer.Shutdown(ctx); err != nil { log.Errorf("Error shutting down HTTP server: %v", err) } } - // Stop main server if p.server != nil { if err := p.server.Shutdown(ctx); err != nil { return fmt.Errorf("error shutting down server: %w", err) diff --git a/proxy/internal/reverseproxy/routes.go b/proxy/internal/reverseproxy/routes.go index 16b6eb1bd..fd59286ff 100644 --- a/proxy/internal/reverseproxy/routes.go +++ b/proxy/internal/reverseproxy/routes.go @@ -36,7 +36,6 @@ func (p *Proxy) AddRoute(route *RouteConfig) error { p.mu.Lock() defer p.mu.Unlock() - // Check if route already exists for this domain if _, exists := p.routes[route.Domain]; exists { return fmt.Errorf("route for domain %s already exists", route.Domain) } @@ -54,10 +53,8 @@ func (p *Proxy) AddRoute(route *RouteConfig) error { route.nbClient = client - // Add route with domain as key p.routes[route.Domain] = route - // Register domain with certificate manager p.certManager.AddDomain(route.Domain) log.WithFields(log.Fields{ @@ -66,13 +63,13 @@ func (p *Proxy) AddRoute(route *RouteConfig) error { "paths": len(route.PathMappings), }).Info("Added route") - // Eagerly issue certificate in background go func(domain string) { ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute) defer cancel() if err := p.certManager.IssueCertificate(ctx, domain); err != nil { log.Errorf("Failed to issue certificate: %v", err) + // TODO: Better error feedback mechanism } }(route.Domain) @@ -84,15 +81,12 @@ func (p *Proxy) RemoveRoute(domain string) error { p.mu.Lock() defer p.mu.Unlock() - // Check if route exists if _, exists := p.routes[domain]; !exists { return fmt.Errorf("route for domain %s not found", domain) } - // Remove route delete(p.routes, domain) - // Unregister domain from certificate manager p.certManager.RemoveDomain(domain) log.Infof("Removed route for domain: %s", domain) @@ -114,12 +108,10 @@ func (p *Proxy) UpdateRoute(route *RouteConfig) error { p.mu.Lock() defer p.mu.Unlock() - // Check if route exists for this domain if _, exists := p.routes[route.Domain]; !exists { return fmt.Errorf("route for domain %s not found", route.Domain) } - // Update route using domain as key p.routes[route.Domain] = route log.WithFields(log.Fields{ diff --git a/proxy/pkg/errors/common.go b/proxy/pkg/errors/common.go deleted file mode 100644 index 1b06d43a2..000000000 --- a/proxy/pkg/errors/common.go +++ /dev/null @@ -1,73 +0,0 @@ -package errors - -import "fmt" - -// Configuration errors - -func NewConfigInvalid(message string) *AppError { - return New(CodeConfigInvalid, message) -} - -func NewConfigNotFound(path string) *AppError { - return New(CodeConfigNotFound, fmt.Sprintf("configuration file not found: %s", path)) -} - -func WrapConfigParseFailed(err error, path string) *AppError { - return Wrap(CodeConfigParseFailed, fmt.Sprintf("failed to parse configuration file: %s", path), err) -} - -// Server errors - -func NewServerStartFailed(err error, reason string) *AppError { - return Wrap(CodeServerStartFailed, fmt.Sprintf("server start failed: %s", reason), err) -} - -func NewServerStopFailed(err error) *AppError { - return Wrap(CodeServerStopFailed, "server shutdown failed", err) -} - -func NewServerAlreadyRunning() *AppError { - return New(CodeServerAlreadyRunning, "server is already running") -} - -func NewServerNotRunning() *AppError { - return New(CodeServerNotRunning, "server is not running") -} - -// Proxy errors - -func NewProxyBackendUnavailable(backend string, err error) *AppError { - return Wrap(CodeProxyBackendUnavailable, fmt.Sprintf("backend unavailable: %s", backend), err) -} - -func NewProxyTimeout(backend string) *AppError { - return New(CodeProxyTimeout, fmt.Sprintf("request to backend timed out: %s", backend)) -} - -func NewProxyInvalidTarget(target string, err error) *AppError { - return Wrap(CodeProxyInvalidTarget, fmt.Sprintf("invalid proxy target: %s", target), err) -} - -// Network errors - -func NewNetworkTimeout(operation string) *AppError { - return New(CodeNetworkTimeout, fmt.Sprintf("network timeout: %s", operation)) -} - -func NewNetworkUnreachable(host string) *AppError { - return New(CodeNetworkUnreachable, fmt.Sprintf("network unreachable: %s", host)) -} - -func NewNetworkRefused(host string) *AppError { - return New(CodeNetworkRefused, fmt.Sprintf("connection refused: %s", host)) -} - -// Internal errors - -func NewInternalError(message string) *AppError { - return New(CodeInternalError, message) -} - -func WrapInternalError(err error, message string) *AppError { - return Wrap(CodeInternalError, message, err) -} diff --git a/proxy/pkg/errors/errors.go b/proxy/pkg/errors/errors.go deleted file mode 100644 index 1fa15aa48..000000000 --- a/proxy/pkg/errors/errors.go +++ /dev/null @@ -1,138 +0,0 @@ -package errors - -import ( - "errors" - "fmt" -) - -// Error codes for categorizing errors -type Code string - -const ( - // Configuration errors - CodeConfigInvalid Code = "CONFIG_INVALID" - CodeConfigNotFound Code = "CONFIG_NOT_FOUND" - CodeConfigParseFailed Code = "CONFIG_PARSE_FAILED" - - // Server errors - CodeServerStartFailed Code = "SERVER_START_FAILED" - CodeServerStopFailed Code = "SERVER_STOP_FAILED" - CodeServerAlreadyRunning Code = "SERVER_ALREADY_RUNNING" - CodeServerNotRunning Code = "SERVER_NOT_RUNNING" - - // Proxy errors - CodeProxyBackendUnavailable Code = "PROXY_BACKEND_UNAVAILABLE" - CodeProxyTimeout Code = "PROXY_TIMEOUT" - CodeProxyInvalidTarget Code = "PROXY_INVALID_TARGET" - - // Network errors - CodeNetworkTimeout Code = "NETWORK_TIMEOUT" - CodeNetworkUnreachable Code = "NETWORK_UNREACHABLE" - CodeNetworkRefused Code = "NETWORK_REFUSED" - - // Internal errors - CodeInternalError Code = "INTERNAL_ERROR" - CodeUnknownError Code = "UNKNOWN_ERROR" -) - -// AppError represents a structured application error -type AppError struct { - Code Code // Error code for categorization - Message string // Human-readable error message - Cause error // Underlying error (if any) -} - -// Error implements the error interface -func (e *AppError) Error() string { - if e.Cause != nil { - return fmt.Sprintf("[%s] %s: %v", e.Code, e.Message, e.Cause) - } - return fmt.Sprintf("[%s] %s", e.Code, e.Message) -} - -// Unwrap returns the underlying error (for errors.Is and errors.As) -func (e *AppError) Unwrap() error { - return e.Cause -} - -// Is checks if the error matches the target -func (e *AppError) Is(target error) bool { - t, ok := target.(*AppError) - if !ok { - return false - } - return e.Code == t.Code -} - -// New creates a new AppError -func New(code Code, message string) *AppError { - return &AppError{ - Code: code, - Message: message, - } -} - -// Wrap wraps an existing error with additional context -func Wrap(code Code, message string, cause error) *AppError { - return &AppError{ - Code: code, - Message: message, - Cause: cause, - } -} - -// Wrapf wraps an error with a formatted message -func Wrapf(code Code, cause error, format string, args ...interface{}) *AppError { - return &AppError{ - Code: code, - Message: fmt.Sprintf(format, args...), - Cause: cause, - } -} - -// GetCode extracts the error code from an error -func GetCode(err error) Code { - var appErr *AppError - if errors.As(err, &appErr) { - return appErr.Code - } - return CodeUnknownError -} - -// HasCode checks if an error has a specific code -func HasCode(err error, code Code) bool { - return GetCode(err) == code -} - -// IsConfigError checks if an error is configuration-related -func IsConfigError(err error) bool { - code := GetCode(err) - return code == CodeConfigInvalid || - code == CodeConfigNotFound || - code == CodeConfigParseFailed -} - -// IsServerError checks if an error is server-related -func IsServerError(err error) bool { - code := GetCode(err) - return code == CodeServerStartFailed || - code == CodeServerStopFailed || - code == CodeServerAlreadyRunning || - code == CodeServerNotRunning -} - -// IsProxyError checks if an error is proxy-related -func IsProxyError(err error) bool { - code := GetCode(err) - return code == CodeProxyBackendUnavailable || - code == CodeProxyTimeout || - code == CodeProxyInvalidTarget -} - -// IsNetworkError checks if an error is network-related -func IsNetworkError(err error) bool { - code := GetCode(err) - return code == CodeNetworkTimeout || - code == CodeNetworkUnreachable || - code == CodeNetworkRefused -} diff --git a/proxy/pkg/errors/errors_test.go b/proxy/pkg/errors/errors_test.go deleted file mode 100644 index cb2f0a827..000000000 --- a/proxy/pkg/errors/errors_test.go +++ /dev/null @@ -1,160 +0,0 @@ -package errors - -import ( - "errors" - "testing" -) - -func TestAppError_Error(t *testing.T) { - tests := []struct { - name string - err *AppError - expected string - }{ - { - name: "error without cause", - err: New(CodeConfigInvalid, "invalid configuration"), - expected: "[CONFIG_INVALID] invalid configuration", - }, - { - name: "error with cause", - err: Wrap(CodeServerStartFailed, "failed to bind port", errors.New("address already in use")), - expected: "[SERVER_START_FAILED] failed to bind port: address already in use", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if got := tt.err.Error(); got != tt.expected { - t.Errorf("Error() = %v, want %v", got, tt.expected) - } - }) - } -} - -func TestGetCode(t *testing.T) { - tests := []struct { - name string - err error - expected Code - }{ - { - name: "app error", - err: New(CodeConfigInvalid, "test"), - expected: CodeConfigInvalid, - }, - { - name: "wrapped app error", - err: Wrap(CodeServerStartFailed, "test", errors.New("cause")), - expected: CodeServerStartFailed, - }, - { - name: "standard error", - err: errors.New("standard error"), - expected: CodeUnknownError, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if got := GetCode(tt.err); got != tt.expected { - t.Errorf("GetCode() = %v, want %v", got, tt.expected) - } - }) - } -} - -func TestHasCode(t *testing.T) { - err := New(CodeConfigInvalid, "invalid config") - - if !HasCode(err, CodeConfigInvalid) { - t.Error("HasCode() should return true for matching code") - } - - if HasCode(err, CodeServerStartFailed) { - t.Error("HasCode() should return false for non-matching code") - } -} - -func TestIsConfigError(t *testing.T) { - tests := []struct { - name string - err error - expected bool - }{ - { - name: "config invalid error", - err: New(CodeConfigInvalid, "test"), - expected: true, - }, - { - name: "config not found error", - err: New(CodeConfigNotFound, "test"), - expected: true, - }, - { - name: "server error", - err: New(CodeServerStartFailed, "test"), - expected: false, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if got := IsConfigError(tt.err); got != tt.expected { - t.Errorf("IsConfigError() = %v, want %v", got, tt.expected) - } - }) - } -} - -func TestErrorUnwrap(t *testing.T) { - cause := errors.New("root cause") - err := Wrap(CodeInternalError, "wrapped error", cause) - - unwrapped := errors.Unwrap(err) - if unwrapped != cause { - t.Errorf("Unwrap() = %v, want %v", unwrapped, cause) - } -} - -func TestErrorIs(t *testing.T) { - err1 := New(CodeConfigInvalid, "test1") - err2 := New(CodeConfigInvalid, "test2") - err3 := New(CodeServerStartFailed, "test3") - - if !errors.Is(err1, err2) { - t.Error("errors.Is() should return true for same error code") - } - - if errors.Is(err1, err3) { - t.Error("errors.Is() should return false for different error codes") - } -} - -func TestCommonConstructors(t *testing.T) { - t.Run("NewConfigNotFound", func(t *testing.T) { - err := NewConfigNotFound("/path/to/config") - if GetCode(err) != CodeConfigNotFound { - t.Error("NewConfigNotFound should create CONFIG_NOT_FOUND error") - } - }) - - t.Run("NewServerAlreadyRunning", func(t *testing.T) { - err := NewServerAlreadyRunning() - if GetCode(err) != CodeServerAlreadyRunning { - t.Error("NewServerAlreadyRunning should create SERVER_ALREADY_RUNNING error") - } - }) - - t.Run("NewProxyBackendUnavailable", func(t *testing.T) { - cause := errors.New("connection refused") - err := NewProxyBackendUnavailable("http://backend", cause) - if GetCode(err) != CodeProxyBackendUnavailable { - t.Error("NewProxyBackendUnavailable should create PROXY_BACKEND_UNAVAILABLE error") - } - if !errors.Is(err.Unwrap(), cause) { - t.Error("NewProxyBackendUnavailable should wrap the cause") - } - }) -} diff --git a/proxy/pkg/grpc/server.go b/proxy/pkg/grpc/server.go index 15e3285c4..21c3d68ba 100644 --- a/proxy/pkg/grpc/server.go +++ b/proxy/pkg/grpc/server.go @@ -77,7 +77,6 @@ func (s *Server) Start() error { return fmt.Errorf("failed to listen: %w", err) } - // Configure gRPC server with keepalive s.grpcServer = grpc.NewServer( grpc.KeepaliveParams(keepalive.ServerParameters{ Time: 30 * time.Second, @@ -114,7 +113,6 @@ func (s *Server) Stop(ctx context.Context) error { log.Info("Stopping gRPC server...") - // Cancel all active streams s.mu.Lock() for _, streamCtx := range s.streams { streamCtx.cancel() @@ -123,7 +121,6 @@ func (s *Server) Stop(ctx context.Context) error { s.streams = make(map[string]*StreamContext) s.mu.Unlock() - // Graceful stop with timeout stopped := make(chan struct{}) go func() { s.grpcServer.GracefulStop() @@ -154,7 +151,6 @@ func (s *Server) Stream(stream pb.ProxyService_StreamServer) error { controlID := fmt.Sprintf("control-%d", time.Now().Unix()) - // Create stream context streamCtx := &StreamContext{ stream: stream, sendChan: make(chan *pb.ProxyMessage, 100), @@ -163,22 +159,18 @@ func (s *Server) Stream(stream pb.ProxyService_StreamServer) error { controlID: controlID, } - // Register stream s.mu.Lock() s.streams[controlID] = streamCtx s.mu.Unlock() log.Infof("Control service connected: %s", controlID) - // Start goroutine to send ProxyMessages to control service sendDone := make(chan error, 1) go s.sendLoop(streamCtx, sendDone) - // Start goroutine to receive ControlMessages from control service recvDone := make(chan error, 1) go s.receiveLoop(streamCtx, recvDone) - // Wait for either send or receive to complete select { case err := <-sendDone: log.Infof("Control service %s send loop ended: %v", controlID, err) @@ -202,7 +194,6 @@ func (s *Server) sendLoop(streamCtx *StreamContext, done chan<- error) { return } - // Send ProxyMessage to control service if err := streamCtx.stream.Send(msg); err != nil { log.Errorf("Failed to send message to control service: %v", err) done <- err @@ -219,7 +210,6 @@ func (s *Server) sendLoop(streamCtx *StreamContext, done chan<- error) { // receiveLoop handles receiving ControlMessages from the control service func (s *Server) receiveLoop(streamCtx *StreamContext, done chan<- error) { for { - // Receive ControlMessage from control service (client) controlMsg, err := streamCtx.stream.Recv() if err != nil { log.Debugf("Stream receive error: %v", err) @@ -227,7 +217,6 @@ func (s *Server) receiveLoop(streamCtx *StreamContext, done chan<- error) { return } - // Handle different ControlMessage types switch m := controlMsg.Message.(type) { case *pb.ControlMessage_Event: if s.handler != nil { @@ -271,7 +260,6 @@ func (s *Server) SendProxyMessage(msg *pb.ProxyMessage) { for _, streamCtx := range s.streams { select { case streamCtx.sendChan <- msg: - // Message queued successfully default: log.Warn("Send channel full, dropping message") } diff --git a/proxy/pkg/proxy/config.go b/proxy/pkg/proxy/config.go index 6f7c86902..0bd7a13ce 100644 --- a/proxy/pkg/proxy/config.go +++ b/proxy/pkg/proxy/config.go @@ -63,7 +63,7 @@ type Config struct { // LogLevel sets the logging verbosity (debug, info, warn, error) LogLevel string `env:"NB_PROXY_LOG_LEVEL" envDefault:"info" json:"log_level"` - // GRPCListenAddress is the address for the gRPC control server (empty to disable) + // GRPCListenAddress is the address for the gRPC control server GRPCListenAddress string `env:"NB_PROXY_GRPC_LISTEN_ADDRESS" envDefault:":50051" json:"grpc_listen_address"` // ProxyID is a unique identifier for this proxy instance @@ -111,7 +111,6 @@ func LoadFromFile(path string) (Config, error) { } // LoadFromFileOrEnv loads configuration from a file if path is provided, otherwise from environment variables -// Environment variables will override file-based configuration if both are present func LoadFromFileOrEnv(configPath string) (Config, error) { var cfg Config @@ -123,7 +122,6 @@ func LoadFromFileOrEnv(configPath string) (Config, error) { } cfg = fileCfg } else { - // Parse environment variables (will override file config with any set env vars) if err := env.Parse(&cfg); err != nil { return Config{}, fmt.Errorf("%w: %s", ErrFailedToParseConfig, err) } @@ -137,30 +135,24 @@ func LoadFromFileOrEnv(configPath string) (Config, error) { } // UnmarshalJSON implements custom JSON unmarshaling with automatic duration parsing -// Uses reflection to find all time.Duration fields and parse them from string func (c *Config) UnmarshalJSON(data []byte) error { - // First unmarshal into a map to get raw values var raw map[string]interface{} if err := json.Unmarshal(data, &raw); err != nil { return err } - // Get reflection value and type val := reflect.ValueOf(c).Elem() typ := val.Type() - // Iterate through all fields for i := 0; i < val.NumField(); i++ { field := val.Field(i) fieldType := typ.Field(i) - // Get JSON tag name jsonTag := fieldType.Tag.Get("json") if jsonTag == "" || jsonTag == "-" { continue } - // Parse tag to get field name (handle omitempty, etc.) jsonFieldName := jsonTag if idx := len(jsonTag); idx > 0 { for j, c := range jsonTag { @@ -171,15 +163,12 @@ func (c *Config) UnmarshalJSON(data []byte) error { } } - // Get raw value from JSON rawValue, exists := raw[jsonFieldName] if !exists { continue } - // Check if this field is a time.Duration if field.Type() == reflect.TypeOf(time.Duration(0)) { - // Try to parse as string duration if strValue, ok := rawValue.(string); ok { duration, err := time.ParseDuration(strValue) if err != nil { @@ -190,13 +179,11 @@ func (c *Config) UnmarshalJSON(data []byte) error { return fmt.Errorf("field %s must be a duration string", jsonFieldName) } } else { - // For non-duration fields, unmarshal normally fieldData, err := json.Marshal(rawValue) if err != nil { return fmt.Errorf("failed to marshal field %s: %w", jsonFieldName, err) } - // Create a new instance of the field type if field.CanSet() { newVal := reflect.New(field.Type()) if err := json.Unmarshal(fieldData, newVal.Interface()); err != nil { diff --git a/proxy/pkg/proxy/server.go b/proxy/pkg/proxy/server.go index 32330da78..00f85ff34 100644 --- a/proxy/pkg/proxy/server.go +++ b/proxy/pkg/proxy/server.go @@ -95,13 +95,16 @@ func NewServer(config Config) (*Server, error) { // Set request data callback proxy.SetRequestCallback(func(data reverseproxy.RequestData) { log.WithFields(log.Fields{ - "service_id": data.ServiceID, - "host": data.Host, - "method": data.Method, - "path": data.Path, - "response_code": data.ResponseCode, - "duration_ms": data.DurationMs, - "source_ip": data.SourceIP, + "service_id": data.ServiceID, + "host": data.Host, + "method": data.Method, + "path": data.Path, + "response_code": data.ResponseCode, + "duration_ms": data.DurationMs, + "source_ip": data.SourceIP, + "auth_mechanism": data.AuthMechanism, + "user_id": data.UserID, + "auth_success": data.AuthSuccess, }).Info("Access log received") }) if err != nil { @@ -176,9 +179,9 @@ func (s *Server) Start() error { &reverseproxy.RouteConfig{ ID: "test", Domain: "test.netbird.io", - PathMappings: map[string]string{"/": "localhost:8181"}, + PathMappings: map[string]string{"/": "100.116.118.156:8181"}, AuthConfig: testAuthConfig, - SetupKey: "setup-key", + SetupKey: "88B2382A-93D2-47A9-A80F-D0055D741636", }); err != nil { log.Warn("Failed to add test route: ", err) }