mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-20 09:16:40 +00:00
cleanup
This commit is contained in:
@@ -28,7 +28,6 @@ var (
|
|||||||
func init() {
|
func init() {
|
||||||
rootCmd.PersistentFlags().StringVarP(&configFile, "config", "c", "", "path to JSON configuration file (optional, can use env vars instead)")
|
rootCmd.PersistentFlags().StringVarP(&configFile, "config", "c", "", "path to JSON configuration file (optional, can use env vars instead)")
|
||||||
|
|
||||||
// Set version information
|
|
||||||
rootCmd.Version = version.Short()
|
rootCmd.Version = version.Short()
|
||||||
rootCmd.SetVersionTemplate("{{.Version}}\n")
|
rootCmd.SetVersionTemplate("{{.Version}}\n")
|
||||||
}
|
}
|
||||||
@@ -39,14 +38,12 @@ func Execute() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func run(cmd *cobra.Command, args []string) error {
|
func run(cmd *cobra.Command, args []string) error {
|
||||||
// Load configuration from file or environment variables
|
|
||||||
config, err := proxy.LoadFromFileOrEnv(configFile)
|
config, err := proxy.LoadFromFileOrEnv(configFile)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("Failed to load configuration: %v", err)
|
log.Fatalf("Failed to load configuration: %v", err)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Set log level
|
|
||||||
setupLogging(config.LogLevel)
|
setupLogging(config.LogLevel)
|
||||||
|
|
||||||
log.Infof("Starting Netbird Proxy - %s", version.Short())
|
log.Infof("Starting Netbird Proxy - %s", version.Short())
|
||||||
@@ -55,14 +52,12 @@ func run(cmd *cobra.Command, args []string) error {
|
|||||||
log.Infof("Listen Address: %s", config.ReverseProxy.ListenAddress)
|
log.Infof("Listen Address: %s", config.ReverseProxy.ListenAddress)
|
||||||
log.Infof("Log Level: %s", config.LogLevel)
|
log.Infof("Log Level: %s", config.LogLevel)
|
||||||
|
|
||||||
// Create server instance
|
|
||||||
server, err := proxy.NewServer(config)
|
server, err := proxy.NewServer(config)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("Failed to create server: %v", err)
|
log.Fatalf("Failed to create server: %v", err)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Start server in a goroutine
|
|
||||||
serverErrors := make(chan error, 1)
|
serverErrors := make(chan error, 1)
|
||||||
go func() {
|
go func() {
|
||||||
if err := server.Start(); err != nil {
|
if err := server.Start(); err != nil {
|
||||||
@@ -70,11 +65,9 @@ func run(cmd *cobra.Command, args []string) error {
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
// Set up signal handler for graceful shutdown
|
|
||||||
quit := make(chan os.Signal, 1)
|
quit := make(chan os.Signal, 1)
|
||||||
signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
|
signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
|
||||||
|
|
||||||
// Wait for either an error or shutdown signal
|
|
||||||
select {
|
select {
|
||||||
case err := <-serverErrors:
|
case err := <-serverErrors:
|
||||||
log.Fatalf("Server error: %v", err)
|
log.Fatalf("Server error: %v", err)
|
||||||
@@ -83,11 +76,9 @@ func run(cmd *cobra.Command, args []string) error {
|
|||||||
log.Infof("Received signal: %v", sig)
|
log.Infof("Received signal: %v", sig)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create shutdown context with timeout
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), config.ShutdownTimeout)
|
ctx, cancel := context.WithTimeout(context.Background(), config.ShutdownTimeout)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
// Gracefully stop the server
|
|
||||||
if err := server.Stop(ctx); err != nil {
|
if err := server.Stop(ctx); err != nil {
|
||||||
log.Fatalf("Failed to stop server gracefully: %v", err)
|
log.Fatalf("Failed to stop server gracefully: %v", err)
|
||||||
return err
|
return err
|
||||||
@@ -98,13 +89,11 @@ func run(cmd *cobra.Command, args []string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func setupLogging(level string) {
|
func setupLogging(level string) {
|
||||||
// Set log format
|
|
||||||
log.SetFormatter(&log.TextFormatter{
|
log.SetFormatter(&log.TextFormatter{
|
||||||
FullTimestamp: true,
|
FullTimestamp: true,
|
||||||
TimestampFormat: "2006-01-02 15:04:05",
|
TimestampFormat: "2006-01-02 15:04:05",
|
||||||
})
|
})
|
||||||
|
|
||||||
// Set log level
|
|
||||||
switch level {
|
switch level {
|
||||||
case "debug":
|
case "debug":
|
||||||
log.SetLevel(log.DebugLevel)
|
log.SetLevel(log.DebugLevel)
|
||||||
|
|||||||
@@ -1,9 +0,0 @@
|
|||||||
package auth
|
|
||||||
|
|
||||||
const (
|
|
||||||
// DefaultSessionCookieName is the default cookie name for session storage
|
|
||||||
DefaultSessionCookieName = "auth_session"
|
|
||||||
|
|
||||||
// ErrorInternalServer is the default internal server error message
|
|
||||||
ErrorInternalServer = "Internal Server Error"
|
|
||||||
)
|
|
||||||
@@ -18,7 +18,6 @@ func (c *BasicAuthConfig) Validate(r *http.Request) bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
// Use constant-time comparison to prevent timing attacks
|
|
||||||
usernameMatch := subtle.ConstantTimeCompare([]byte(username), []byte(c.Username)) == 1
|
usernameMatch := subtle.ConstantTimeCompare([]byte(username), []byte(c.Username)) == 1
|
||||||
passwordMatch := subtle.ConstantTimeCompare([]byte(password), []byte(c.Password)) == 1
|
passwordMatch := subtle.ConstantTimeCompare([]byte(password), []byte(c.Password)) == 1
|
||||||
|
|
||||||
|
|||||||
@@ -4,7 +4,5 @@ package methods
|
|||||||
// The actual OIDC/JWT configuration comes from the global proxy Config.OIDCConfig
|
// The actual OIDC/JWT configuration comes from the global proxy Config.OIDCConfig
|
||||||
// This just enables Bearer auth for a specific route
|
// This just enables Bearer auth for a specific route
|
||||||
type BearerConfig struct {
|
type BearerConfig struct {
|
||||||
// Enable bearer token authentication for this route
|
|
||||||
// Uses the global OIDC configuration from proxy Config
|
|
||||||
Enabled bool
|
Enabled bool
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ const (
|
|||||||
// PINConfig holds PIN authentication settings
|
// PINConfig holds PIN authentication settings
|
||||||
type PINConfig struct {
|
type PINConfig struct {
|
||||||
PIN string
|
PIN string
|
||||||
Header string // Header name (default: "X-PIN")
|
Header string
|
||||||
}
|
}
|
||||||
|
|
||||||
// Validate checks PIN from the request header
|
// Validate checks PIN from the request header
|
||||||
@@ -28,6 +28,5 @@ func (c *PINConfig) Validate(r *http.Request) bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
// Use constant-time comparison to prevent timing attacks
|
|
||||||
return subtle.ConstantTimeCompare([]byte(providedPIN), []byte(c.PIN)) == 1
|
return subtle.ConstantTimeCompare([]byte(providedPIN), []byte(c.PIN)) == 1
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -28,26 +28,22 @@ type authResult struct {
|
|||||||
|
|
||||||
// ServeHTTP implements the http.Handler interface
|
// ServeHTTP implements the http.Handler interface
|
||||||
func (m *Middleware) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
func (m *Middleware) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||||
// If no auth configured, allow request
|
|
||||||
if m.config.IsEmpty() {
|
if m.config.IsEmpty() {
|
||||||
m.allowWithoutAuth(w, r)
|
m.allowWithoutAuth(w, r)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Try to authenticate the request
|
|
||||||
result := m.authenticate(w, r)
|
result := m.authenticate(w, r)
|
||||||
if result == nil {
|
if result == nil {
|
||||||
// Authentication triggered a redirect (e.g., OIDC flow)
|
// Authentication triggered a redirect (e.g., OIDC flow)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Reject if authentication failed
|
|
||||||
if !result.authenticated {
|
if !result.authenticated {
|
||||||
m.rejectRequest(w, r)
|
m.rejectRequest(w, r)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Authentication successful - continue to next handler
|
|
||||||
m.continueWithAuth(w, r, result)
|
m.continueWithAuth(w, r, result)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -65,17 +61,14 @@ func (m *Middleware) allowWithoutAuth(w http.ResponseWriter, r *http.Request) {
|
|||||||
// authenticate attempts to authenticate the request using configured methods
|
// authenticate attempts to authenticate the request using configured methods
|
||||||
// Returns nil if a redirect occurred (e.g., OIDC flow initiated)
|
// Returns nil if a redirect occurred (e.g., OIDC flow initiated)
|
||||||
func (m *Middleware) authenticate(w http.ResponseWriter, r *http.Request) *authResult {
|
func (m *Middleware) authenticate(w http.ResponseWriter, r *http.Request) *authResult {
|
||||||
// Try Basic Auth
|
|
||||||
if result := m.tryBasicAuth(r); result.authenticated {
|
if result := m.tryBasicAuth(r); result.authenticated {
|
||||||
return result
|
return result
|
||||||
}
|
}
|
||||||
|
|
||||||
// Try PIN Auth
|
|
||||||
if result := m.tryPINAuth(r); result.authenticated {
|
if result := m.tryPINAuth(r); result.authenticated {
|
||||||
return result
|
return result
|
||||||
}
|
}
|
||||||
|
|
||||||
// Try Bearer/OIDC Auth
|
|
||||||
return m.tryBearerAuth(w, r)
|
return m.tryBearerAuth(w, r)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -94,7 +87,6 @@ func (m *Middleware) tryBasicAuth(r *http.Request) *authResult {
|
|||||||
method: "basic",
|
method: "basic",
|
||||||
}
|
}
|
||||||
|
|
||||||
// Extract username from Basic Auth
|
|
||||||
if username, _, ok := r.BasicAuth(); ok {
|
if username, _, ok := r.BasicAuth(); ok {
|
||||||
result.userID = username
|
result.userID = username
|
||||||
}
|
}
|
||||||
@@ -128,24 +120,20 @@ func (m *Middleware) tryBearerAuth(w http.ResponseWriter, r *http.Request) *auth
|
|||||||
|
|
||||||
cookieName := m.oidcHandler.SessionCookieName()
|
cookieName := m.oidcHandler.SessionCookieName()
|
||||||
|
|
||||||
// Handle auth token in query parameter (from OIDC callback)
|
|
||||||
if m.handleAuthTokenParameter(w, r, cookieName) {
|
if m.handleAuthTokenParameter(w, r, cookieName) {
|
||||||
return nil // Redirect occurred
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Try session cookie
|
|
||||||
if result := m.trySessionCookie(r, cookieName); result.authenticated {
|
if result := m.trySessionCookie(r, cookieName); result.authenticated {
|
||||||
return result
|
return result
|
||||||
}
|
}
|
||||||
|
|
||||||
// Try Authorization header
|
|
||||||
if result := m.tryAuthorizationHeader(r); result.authenticated {
|
if result := m.tryAuthorizationHeader(r); result.authenticated {
|
||||||
return result
|
return result
|
||||||
}
|
}
|
||||||
|
|
||||||
// No valid auth - redirect to OIDC provider
|
|
||||||
m.oidcHandler.RedirectToProvider(w, r, m.routeID)
|
m.oidcHandler.RedirectToProvider(w, r, m.routeID)
|
||||||
return nil // Redirect occurred
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// handleAuthTokenParameter processes the _auth_token query parameter from OIDC callback
|
// handleAuthTokenParameter processes the _auth_token query parameter from OIDC callback
|
||||||
@@ -161,7 +149,6 @@ func (m *Middleware) handleAuthTokenParameter(w http.ResponseWriter, r *http.Req
|
|||||||
"host": r.Host,
|
"host": r.Host,
|
||||||
}).Info("Found auth token in query parameter, setting cookie and redirecting")
|
}).Info("Found auth token in query parameter, setting cookie and redirecting")
|
||||||
|
|
||||||
// Validate the token before setting cookie
|
|
||||||
if !m.oidcHandler.ValidateJWT(authToken) {
|
if !m.oidcHandler.ValidateJWT(authToken) {
|
||||||
log.WithFields(log.Fields{
|
log.WithFields(log.Fields{
|
||||||
"route_id": m.routeID,
|
"route_id": m.routeID,
|
||||||
@@ -169,7 +156,6 @@ func (m *Middleware) handleAuthTokenParameter(w http.ResponseWriter, r *http.Req
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
// Set session cookie
|
|
||||||
cookie := &http.Cookie{
|
cookie := &http.Cookie{
|
||||||
Name: cookieName,
|
Name: cookieName,
|
||||||
Value: authToken,
|
Value: authToken,
|
||||||
@@ -288,7 +274,7 @@ func (m *Middleware) continueWithAuth(w http.ResponseWriter, r *http.Request, re
|
|||||||
"path": r.URL.Path,
|
"path": r.URL.Path,
|
||||||
}).Debug("Authentication successful")
|
}).Debug("Authentication successful")
|
||||||
|
|
||||||
// Store auth info in headers for logging
|
// TODO: Find other means of auth logging than headers
|
||||||
r.Header.Set("X-Auth-Method", result.method)
|
r.Header.Set("X-Auth-Method", result.method)
|
||||||
r.Header.Set("X-Auth-User-ID", result.userID)
|
r.Header.Set("X-Auth-User-ID", result.userID)
|
||||||
|
|
||||||
@@ -299,7 +285,7 @@ func (m *Middleware) continueWithAuth(w http.ResponseWriter, r *http.Request, re
|
|||||||
// Wrap wraps an HTTP handler with authentication middleware
|
// Wrap wraps an HTTP handler with authentication middleware
|
||||||
func Wrap(next http.Handler, authConfig *Config, routeID string, rejectResponse func(w http.ResponseWriter, r *http.Request), oidcHandler *oidc.Handler) http.Handler {
|
func Wrap(next http.Handler, authConfig *Config, routeID string, rejectResponse func(w http.ResponseWriter, r *http.Request), oidcHandler *oidc.Handler) http.Handler {
|
||||||
if authConfig == nil {
|
if authConfig == nil {
|
||||||
authConfig = &Config{} // Empty config = no auth
|
authConfig = &Config{}
|
||||||
}
|
}
|
||||||
|
|
||||||
return &Middleware{
|
return &Middleware{
|
||||||
|
|||||||
@@ -2,19 +2,16 @@ package oidc
|
|||||||
|
|
||||||
// Config holds the global OIDC/OAuth configuration
|
// Config holds the global OIDC/OAuth configuration
|
||||||
type Config struct {
|
type Config struct {
|
||||||
// OIDC Provider settings
|
ProviderURL string `env:"NB_OIDC_PROVIDER_URL" json:"provider_url"`
|
||||||
ProviderURL string `env:"NB_OIDC_PROVIDER_URL" json:"provider_url"` // Identity provider URL (e.g., "https://accounts.google.com")
|
ClientID string `env:"NB_OIDC_CLIENT_ID" json:"client_id"`
|
||||||
ClientID string `env:"NB_OIDC_CLIENT_ID" json:"client_id"` // OAuth client ID
|
ClientSecret string `env:"NB_OIDC_CLIENT_SECRET" json:"client_secret"`
|
||||||
ClientSecret string `env:"NB_OIDC_CLIENT_SECRET" json:"client_secret"` // OAuth client secret (empty for public clients)
|
RedirectURL string `env:"NB_OIDC_REDIRECT_URL" json:"redirect_url"`
|
||||||
RedirectURL string `env:"NB_OIDC_REDIRECT_URL" json:"redirect_url"` // Redirect URL after auth (e.g., "http://localhost:54321/auth/callback")
|
Scopes []string `env:"NB_OIDC_SCOPES" json:"scopes"`
|
||||||
Scopes []string `env:"NB_OIDC_SCOPES" json:"scopes"` // Requested scopes (default: ["openid", "profile", "email"])
|
|
||||||
|
|
||||||
// JWT Validation settings
|
JWTKeysLocation string `env:"NB_OIDC_JWT_KEYS_LOCATION" json:"jwt_keys_location"`
|
||||||
JWTKeysLocation string `env:"NB_OIDC_JWT_KEYS_LOCATION" json:"jwt_keys_location"` // JWKS URL for fetching public keys
|
JWTIssuer string `env:"NB_OIDC_JWT_ISSUER" json:"jwt_issuer"`
|
||||||
JWTIssuer string `env:"NB_OIDC_JWT_ISSUER" json:"jwt_issuer"` // Expected issuer claim
|
JWTAudience []string `env:"NB_OIDC_JWT_AUDIENCE" json:"jwt_audience"`
|
||||||
JWTAudience []string `env:"NB_OIDC_JWT_AUDIENCE" json:"jwt_audience"` // Expected audience claims
|
JWTIdpSignkeyRefreshEnabled bool `env:"NB_OIDC_JWT_IDP_SIGNKEY_REFRESH_ENABLED" json:"jwt_idp_signkey_refresh_enabled"`
|
||||||
JWTIdpSignkeyRefreshEnabled bool `env:"NB_OIDC_JWT_IDP_SIGNKEY_REFRESH_ENABLED" json:"jwt_idp_signkey_refresh_enabled"` // Enable automatic refresh of signing keys
|
|
||||||
|
|
||||||
// Session settings
|
SessionCookieName string `env:"NB_OIDC_SESSION_COOKIE_NAME" json:"session_cookie_name"`
|
||||||
SessionCookieName string `env:"NB_OIDC_SESSION_COOKIE_NAME" json:"session_cookie_name"` // Cookie name for storing session (default: "auth_session")
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -24,7 +24,6 @@ type Handler struct {
|
|||||||
|
|
||||||
// NewHandler creates a new OIDC handler
|
// NewHandler creates a new OIDC handler
|
||||||
func NewHandler(config *Config, stateStore *StateStore) *Handler {
|
func NewHandler(config *Config, stateStore *StateStore) *Handler {
|
||||||
// Initialize JWT validator
|
|
||||||
var jwtValidator *jwt.Validator
|
var jwtValidator *jwt.Validator
|
||||||
if config.JWTKeysLocation != "" {
|
if config.JWTKeysLocation != "" {
|
||||||
jwtValidator = jwt.NewValidator(
|
jwtValidator = jwt.NewValidator(
|
||||||
@@ -67,7 +66,6 @@ func (h *Handler) RedirectToProvider(w http.ResponseWriter, r *http.Request, rou
|
|||||||
scopes = []string{"openid", "profile", "email"}
|
scopes = []string{"openid", "profile", "email"}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Build authorization URL
|
|
||||||
authURL, err := url.Parse(h.config.ProviderURL)
|
authURL, err := url.Parse(h.config.ProviderURL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithError(err).Error("Invalid OIDC provider URL")
|
log.WithError(err).Error("Invalid OIDC provider URL")
|
||||||
@@ -80,7 +78,6 @@ func (h *Handler) RedirectToProvider(w http.ResponseWriter, r *http.Request, rou
|
|||||||
authURL.Path = strings.TrimSuffix(authURL.Path, "/") + "/authorize"
|
authURL.Path = strings.TrimSuffix(authURL.Path, "/") + "/authorize"
|
||||||
}
|
}
|
||||||
|
|
||||||
// Build query parameters
|
|
||||||
params := url.Values{}
|
params := url.Values{}
|
||||||
params.Set("client_id", h.config.ClientID)
|
params.Set("client_id", h.config.ClientID)
|
||||||
params.Set("redirect_uri", h.config.RedirectURL)
|
params.Set("redirect_uri", h.config.RedirectURL)
|
||||||
@@ -88,8 +85,6 @@ func (h *Handler) RedirectToProvider(w http.ResponseWriter, r *http.Request, rou
|
|||||||
params.Set("scope", strings.Join(scopes, " "))
|
params.Set("scope", strings.Join(scopes, " "))
|
||||||
params.Set("state", state)
|
params.Set("state", state)
|
||||||
|
|
||||||
// Add audience parameter to get an access token for the API
|
|
||||||
// This ensures we get a proper JWT for the API audience, not just an ID token
|
|
||||||
if len(h.config.JWTAudience) > 0 && h.config.JWTAudience[0] != h.config.ClientID {
|
if len(h.config.JWTAudience) > 0 && h.config.JWTAudience[0] != h.config.ClientID {
|
||||||
params.Set("audience", h.config.JWTAudience[0])
|
params.Set("audience", h.config.JWTAudience[0])
|
||||||
}
|
}
|
||||||
@@ -103,7 +98,6 @@ func (h *Handler) RedirectToProvider(w http.ResponseWriter, r *http.Request, rou
|
|||||||
"state": state,
|
"state": state,
|
||||||
}).Info("Redirecting to OIDC provider for authentication")
|
}).Info("Redirecting to OIDC provider for authentication")
|
||||||
|
|
||||||
// Redirect user to identity provider login page
|
|
||||||
http.Redirect(w, r, authURL.String(), http.StatusFound)
|
http.Redirect(w, r, authURL.String(), http.StatusFound)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -34,7 +34,6 @@ func (s *StateStore) Store(stateToken, originalURL, routeID string) {
|
|||||||
RouteID: routeID,
|
RouteID: routeID,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Clean up expired states
|
|
||||||
s.cleanup()
|
s.cleanup()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -39,7 +39,7 @@ func NewLetsEncrypt(config LetsEncryptConfig) *LetsEncryptManager {
|
|||||||
HostPolicy: m.hostPolicy,
|
HostPolicy: m.hostPolicy,
|
||||||
Cache: autocert.DirCache(config.CertCacheDir),
|
Cache: autocert.DirCache(config.CertCacheDir),
|
||||||
Email: config.Email,
|
Email: config.Email,
|
||||||
RenewBefore: 0, // Use default
|
RenewBefore: 0, // Use default 30 days prior to expiration
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Info("Let's Encrypt certificate manager initialized")
|
log.Info("Let's Encrypt certificate manager initialized")
|
||||||
@@ -71,8 +71,6 @@ func (m *LetsEncryptManager) RemoveDomain(domain string) {
|
|||||||
func (m *LetsEncryptManager) IssueCertificate(ctx context.Context, domain string) error {
|
func (m *LetsEncryptManager) IssueCertificate(ctx context.Context, domain string) error {
|
||||||
log.Infof("Issuing Let's Encrypt certificate for domain: %s", domain)
|
log.Infof("Issuing Let's Encrypt certificate for domain: %s", domain)
|
||||||
|
|
||||||
// Use GetCertificate to trigger certificate issuance
|
|
||||||
// This will go through the ACME challenge flow
|
|
||||||
hello := &tls.ClientHelloInfo{
|
hello := &tls.ClientHelloInfo{
|
||||||
ServerName: domain,
|
ServerName: domain,
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -54,19 +54,16 @@ func (m *SelfSignedManager) IssueCertificate(ctx context.Context, domain string)
|
|||||||
m.mu.Lock()
|
m.mu.Lock()
|
||||||
defer m.mu.Unlock()
|
defer m.mu.Unlock()
|
||||||
|
|
||||||
// Check if we already have a certificate for this domain
|
|
||||||
if _, exists := m.certificates[domain]; exists {
|
if _, exists := m.certificates[domain]; exists {
|
||||||
log.Debugf("Self-signed certificate already exists for domain: %s", domain)
|
log.Debugf("Self-signed certificate already exists for domain: %s", domain)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Generate self-signed certificate
|
|
||||||
cert, err := m.generateCertificate(domain)
|
cert, err := m.generateCertificate(domain)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Cache the certificate
|
|
||||||
m.certificates[domain] = cert
|
m.certificates[domain] = cert
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
@@ -94,7 +91,6 @@ func (m *SelfSignedManager) getCertificate(hello *tls.ClientHelloInfo) (*tls.Cer
|
|||||||
return cert, nil
|
return cert, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Generate certificate on-demand if not cached
|
|
||||||
log.Infof("Generating self-signed certificate on-demand for: %s", hello.ServerName)
|
log.Infof("Generating self-signed certificate on-demand for: %s", hello.ServerName)
|
||||||
|
|
||||||
newCert, err := m.generateCertificate(hello.ServerName)
|
newCert, err := m.generateCertificate(hello.ServerName)
|
||||||
@@ -102,7 +98,6 @@ func (m *SelfSignedManager) getCertificate(hello *tls.ClientHelloInfo) (*tls.Cer
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Cache it
|
|
||||||
m.mu.Lock()
|
m.mu.Lock()
|
||||||
m.certificates[hello.ServerName] = newCert
|
m.certificates[hello.ServerName] = newCert
|
||||||
m.mu.Unlock()
|
m.mu.Unlock()
|
||||||
@@ -112,13 +107,11 @@ func (m *SelfSignedManager) getCertificate(hello *tls.ClientHelloInfo) (*tls.Cer
|
|||||||
|
|
||||||
// generateCertificate generates a self-signed certificate for a domain
|
// generateCertificate generates a self-signed certificate for a domain
|
||||||
func (m *SelfSignedManager) generateCertificate(domain string) (*tls.Certificate, error) {
|
func (m *SelfSignedManager) generateCertificate(domain string) (*tls.Certificate, error) {
|
||||||
// Generate private key
|
|
||||||
priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to generate private key: %w", err)
|
return nil, fmt.Errorf("failed to generate private key: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create certificate template
|
|
||||||
notBefore := time.Now()
|
notBefore := time.Now()
|
||||||
notAfter := notBefore.Add(365 * 24 * time.Hour) // Valid for 1 year
|
notAfter := notBefore.Add(365 * 24 * time.Hour) // Valid for 1 year
|
||||||
|
|
||||||
@@ -141,13 +134,11 @@ func (m *SelfSignedManager) generateCertificate(domain string) (*tls.Certificate
|
|||||||
DNSNames: []string{domain},
|
DNSNames: []string{domain},
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create self-signed certificate
|
|
||||||
certDER, err := x509.CreateCertificate(rand.Reader, &template, &template, &priv.PublicKey, priv)
|
certDER, err := x509.CreateCertificate(rand.Reader, &template, &template, &priv.PublicKey, priv)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to create certificate: %w", err)
|
return nil, fmt.Errorf("failed to create certificate: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Parse certificate
|
|
||||||
cert, err := x509.ParseCertificate(certDER)
|
cert, err := x509.ParseCertificate(certDER)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to parse certificate: %w", err)
|
return nil, fmt.Errorf("failed to parse certificate: %w", err)
|
||||||
|
|||||||
@@ -53,18 +53,13 @@ func (p *Proxy) handleProxyRequest(w http.ResponseWriter, r *http.Request) {
|
|||||||
host = host[:idx]
|
host = host[:idx]
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get auth info from headers set by auth middleware
|
// TODO: extract logging data
|
||||||
authMechanism := r.Header.Get("X-Auth-Method")
|
authMechanism := r.Header.Get("X-Auth-Method")
|
||||||
if authMechanism == "" {
|
if authMechanism == "" {
|
||||||
authMechanism = "none"
|
authMechanism = "none"
|
||||||
}
|
}
|
||||||
|
|
||||||
userID := r.Header.Get("X-Auth-User-ID")
|
userID := r.Header.Get("X-Auth-User-ID")
|
||||||
|
|
||||||
// Determine auth success based on status code
|
|
||||||
authSuccess := rw.statusCode != http.StatusUnauthorized && rw.statusCode != http.StatusForbidden
|
authSuccess := rw.statusCode != http.StatusUnauthorized && rw.statusCode != http.StatusForbidden
|
||||||
|
|
||||||
// Extract source IP directly
|
|
||||||
sourceIP := extractSourceIP(r)
|
sourceIP := extractSourceIP(r)
|
||||||
|
|
||||||
data := RequestData{
|
data := RequestData{
|
||||||
@@ -89,29 +84,22 @@ func (p *Proxy) findRoute(host, path string) *routeEntry {
|
|||||||
p.mu.RLock()
|
p.mu.RLock()
|
||||||
defer p.mu.RUnlock()
|
defer p.mu.RUnlock()
|
||||||
|
|
||||||
// Strip port from host
|
|
||||||
if idx := strings.LastIndex(host, ":"); idx != -1 {
|
if idx := strings.LastIndex(host, ":"); idx != -1 {
|
||||||
host = host[:idx]
|
host = host[:idx]
|
||||||
}
|
}
|
||||||
|
|
||||||
// O(1) lookup by host
|
|
||||||
routeConfig, exists := p.routes[host]
|
routeConfig, exists := p.routes[host]
|
||||||
if !exists {
|
if !exists {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Build list of route entries sorted by path specificity
|
|
||||||
var entries []*routeEntry
|
var entries []*routeEntry
|
||||||
|
|
||||||
// Create entries for each path mapping
|
|
||||||
for routePath, target := range routeConfig.PathMappings {
|
for routePath, target := range routeConfig.PathMappings {
|
||||||
proxy := p.createProxy(routeConfig, target)
|
proxy := p.createProxy(routeConfig, target)
|
||||||
|
|
||||||
// ALWAYS wrap proxy with auth middleware (even if no auth configured)
|
|
||||||
// This ensures consistent auth handling and logging
|
|
||||||
handler := auth.Wrap(proxy, routeConfig.AuthConfig, routeConfig.ID, routeConfig.AuthRejectResponse, p.oidcHandler)
|
handler := auth.Wrap(proxy, routeConfig.AuthConfig, routeConfig.ID, routeConfig.AuthRejectResponse, p.oidcHandler)
|
||||||
|
|
||||||
// Log auth configuration
|
|
||||||
if routeConfig.AuthConfig != nil && !routeConfig.AuthConfig.IsEmpty() {
|
if routeConfig.AuthConfig != nil && !routeConfig.AuthConfig.IsEmpty() {
|
||||||
var authType string
|
var authType string
|
||||||
if routeConfig.AuthConfig.BasicAuth != nil {
|
if routeConfig.AuthConfig.BasicAuth != nil {
|
||||||
@@ -169,11 +157,9 @@ func (p *Proxy) findRoute(host, path string) *routeEntry {
|
|||||||
|
|
||||||
// createProxy creates a reverse proxy for a target with the route's connection
|
// createProxy creates a reverse proxy for a target with the route's connection
|
||||||
func (p *Proxy) createProxy(routeConfig *RouteConfig, target string) *httputil.ReverseProxy {
|
func (p *Proxy) createProxy(routeConfig *RouteConfig, target string) *httputil.ReverseProxy {
|
||||||
// Parse target URL
|
|
||||||
targetURL, err := url.Parse("http://" + target)
|
targetURL, err := url.Parse("http://" + target)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("Failed to parse target URL %s: %v", target, err)
|
log.Errorf("Failed to parse target URL %s: %v", target, err)
|
||||||
// Return a proxy that returns 502
|
|
||||||
return &httputil.ReverseProxy{
|
return &httputil.ReverseProxy{
|
||||||
Director: func(req *http.Request) {},
|
Director: func(req *http.Request) {},
|
||||||
ErrorHandler: func(w http.ResponseWriter, r *http.Request, err error) {
|
ErrorHandler: func(w http.ResponseWriter, r *http.Request, err error) {
|
||||||
@@ -182,21 +168,18 @@ func (p *Proxy) createProxy(routeConfig *RouteConfig, target string) *httputil.R
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create reverse proxy
|
|
||||||
proxy := httputil.NewSingleHostReverseProxy(targetURL)
|
proxy := httputil.NewSingleHostReverseProxy(targetURL)
|
||||||
|
|
||||||
// Configure transport to use the provided connection (WireGuard, etc.)
|
|
||||||
proxy.Transport = &http.Transport{
|
proxy.Transport = &http.Transport{
|
||||||
DialContext: routeConfig.nbClient.DialContext,
|
DialContext: routeConfig.nbClient.DialContext,
|
||||||
MaxIdleConns: 1,
|
MaxIdleConns: 1,
|
||||||
MaxIdleConnsPerHost: 1,
|
MaxIdleConnsPerHost: 1,
|
||||||
IdleConnTimeout: 0, // Keep alive indefinitely
|
IdleConnTimeout: 0,
|
||||||
DisableKeepAlives: false,
|
DisableKeepAlives: false,
|
||||||
TLSHandshakeTimeout: 10 * time.Second,
|
TLSHandshakeTimeout: 10 * time.Second,
|
||||||
ExpectContinueTimeout: 1 * time.Second,
|
ExpectContinueTimeout: 1 * time.Second,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Custom error handler
|
|
||||||
proxy.ErrorHandler = func(w http.ResponseWriter, r *http.Request, err error) {
|
proxy.ErrorHandler = func(w http.ResponseWriter, r *http.Request, err error) {
|
||||||
log.Errorf("Proxy error for %s%s: %v", r.Host, r.URL.Path, err)
|
log.Errorf("Proxy error for %s%s: %v", r.Host, r.URL.Path, err)
|
||||||
http.Error(w, "Bad Gateway", http.StatusBadGateway)
|
http.Error(w, "Bad Gateway", http.StatusBadGateway)
|
||||||
@@ -207,14 +190,12 @@ func (p *Proxy) createProxy(routeConfig *RouteConfig, target string) *httputil.R
|
|||||||
|
|
||||||
// handleOIDCCallback handles the global /auth/callback endpoint for all routes
|
// handleOIDCCallback handles the global /auth/callback endpoint for all routes
|
||||||
func (p *Proxy) handleOIDCCallback(w http.ResponseWriter, r *http.Request) {
|
func (p *Proxy) handleOIDCCallback(w http.ResponseWriter, r *http.Request) {
|
||||||
// Check if OIDC handler is available
|
|
||||||
if p.oidcHandler == nil {
|
if p.oidcHandler == nil {
|
||||||
log.Error("OIDC callback received but no OIDC handler configured")
|
log.Error("OIDC callback received but no OIDC handler configured")
|
||||||
http.Error(w, "Authentication not configured", http.StatusInternalServerError)
|
http.Error(w, "Authentication not configured", http.StatusInternalServerError)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Use the OIDC handler's callback method
|
|
||||||
handler := p.oidcHandler.HandleCallback()
|
handler := p.oidcHandler.HandleCallback()
|
||||||
handler(w, r)
|
handler(w, r)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -27,7 +27,6 @@ type Proxy struct {
|
|||||||
|
|
||||||
// New creates a new reverse proxy
|
// New creates a new reverse proxy
|
||||||
func New(config Config) (*Proxy, error) {
|
func New(config Config) (*Proxy, error) {
|
||||||
// Set defaults
|
|
||||||
if config.ListenAddress == "" {
|
if config.ListenAddress == "" {
|
||||||
config.ListenAddress = ":443"
|
config.ListenAddress = ":443"
|
||||||
}
|
}
|
||||||
@@ -38,22 +37,18 @@ func New(config Config) (*Proxy, error) {
|
|||||||
config.CertCacheDir = "./certs"
|
config.CertCacheDir = "./certs"
|
||||||
}
|
}
|
||||||
|
|
||||||
// Set default cert mode
|
|
||||||
if config.CertMode == "" {
|
if config.CertMode == "" {
|
||||||
config.CertMode = "letsencrypt"
|
config.CertMode = "letsencrypt"
|
||||||
}
|
}
|
||||||
|
|
||||||
// Validate config based on cert mode
|
|
||||||
if config.CertMode == "letsencrypt" && config.TLSEmail == "" {
|
if config.CertMode == "letsencrypt" && config.TLSEmail == "" {
|
||||||
return nil, fmt.Errorf("TLSEmail is required for letsencrypt mode")
|
return nil, fmt.Errorf("TLSEmail is required for letsencrypt mode")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Set default OIDC session cookie name if not provided
|
|
||||||
if config.OIDCConfig != nil && config.OIDCConfig.SessionCookieName == "" {
|
if config.OIDCConfig != nil && config.OIDCConfig.SessionCookieName == "" {
|
||||||
config.OIDCConfig.SessionCookieName = "auth_session"
|
config.OIDCConfig.SessionCookieName = "auth_session"
|
||||||
}
|
}
|
||||||
|
|
||||||
// Initialize certificate manager based on mode
|
|
||||||
var certMgr certmanager.Manager
|
var certMgr certmanager.Manager
|
||||||
if config.CertMode == "selfsigned" {
|
if config.CertMode == "selfsigned" {
|
||||||
// HTTPS with self-signed certificates (for local testing)
|
// HTTPS with self-signed certificates (for local testing)
|
||||||
@@ -73,8 +68,6 @@ func New(config Config) (*Proxy, error) {
|
|||||||
isRunning: false,
|
isRunning: false,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Initialize OIDC handler if OIDC is configured
|
|
||||||
// The handler internally creates and manages its own state store
|
|
||||||
if config.OIDCConfig != nil {
|
if config.OIDCConfig != nil {
|
||||||
stateStore := oidc.NewStateStore()
|
stateStore := oidc.NewStateStore()
|
||||||
p.oidcHandler = oidc.NewHandler(config.OIDCConfig, stateStore)
|
p.oidcHandler = oidc.NewHandler(config.OIDCConfig, stateStore)
|
||||||
@@ -93,28 +86,25 @@ func (p *Proxy) Start() error {
|
|||||||
p.isRunning = true
|
p.isRunning = true
|
||||||
p.mu.Unlock()
|
p.mu.Unlock()
|
||||||
|
|
||||||
// Build the main HTTP handler
|
|
||||||
handler := p.buildHandler()
|
handler := p.buildHandler()
|
||||||
|
|
||||||
return p.startHTTPS(handler)
|
return p.startHTTPS(handler)
|
||||||
}
|
}
|
||||||
|
|
||||||
// startHTTPS starts the proxy with HTTPS (non-blocking)
|
// startHTTPS starts the proxy with HTTPS
|
||||||
func (p *Proxy) startHTTPS(handler http.Handler) error {
|
func (p *Proxy) startHTTPS(handler http.Handler) error {
|
||||||
// Start HTTP server for ACME challenges (Let's Encrypt HTTP-01)
|
|
||||||
p.httpServer = &http.Server{
|
p.httpServer = &http.Server{
|
||||||
Addr: p.config.HTTPListenAddress,
|
Addr: p.config.HTTPListenAddress,
|
||||||
Handler: p.certManager.HTTPHandler(nil),
|
Handler: p.certManager.HTTPHandler(nil),
|
||||||
}
|
}
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
log.Infof("Starting HTTP server for ACME challenges on %s", p.config.HTTPListenAddress)
|
log.Infof("Starting HTTP server on %s", p.config.HTTPListenAddress)
|
||||||
if err := p.httpServer.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
if err := p.httpServer.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
||||||
log.Errorf("HTTP server error: %v", err)
|
log.Errorf("HTTP server error: %v", err)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
// Start HTTPS server in background
|
|
||||||
p.server = &http.Server{
|
p.server = &http.Server{
|
||||||
Addr: p.config.ListenAddress,
|
Addr: p.config.ListenAddress,
|
||||||
Handler: handler,
|
Handler: handler,
|
||||||
@@ -143,14 +133,12 @@ func (p *Proxy) Stop(ctx context.Context) error {
|
|||||||
|
|
||||||
log.Info("Stopping reverse proxy server...")
|
log.Info("Stopping reverse proxy server...")
|
||||||
|
|
||||||
// Stop HTTP server (for ACME challenges)
|
|
||||||
if p.httpServer != nil {
|
if p.httpServer != nil {
|
||||||
if err := p.httpServer.Shutdown(ctx); err != nil {
|
if err := p.httpServer.Shutdown(ctx); err != nil {
|
||||||
log.Errorf("Error shutting down HTTP server: %v", err)
|
log.Errorf("Error shutting down HTTP server: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Stop main server
|
|
||||||
if p.server != nil {
|
if p.server != nil {
|
||||||
if err := p.server.Shutdown(ctx); err != nil {
|
if err := p.server.Shutdown(ctx); err != nil {
|
||||||
return fmt.Errorf("error shutting down server: %w", err)
|
return fmt.Errorf("error shutting down server: %w", err)
|
||||||
|
|||||||
@@ -36,7 +36,6 @@ func (p *Proxy) AddRoute(route *RouteConfig) error {
|
|||||||
p.mu.Lock()
|
p.mu.Lock()
|
||||||
defer p.mu.Unlock()
|
defer p.mu.Unlock()
|
||||||
|
|
||||||
// Check if route already exists for this domain
|
|
||||||
if _, exists := p.routes[route.Domain]; exists {
|
if _, exists := p.routes[route.Domain]; exists {
|
||||||
return fmt.Errorf("route for domain %s already exists", route.Domain)
|
return fmt.Errorf("route for domain %s already exists", route.Domain)
|
||||||
}
|
}
|
||||||
@@ -54,10 +53,8 @@ func (p *Proxy) AddRoute(route *RouteConfig) error {
|
|||||||
|
|
||||||
route.nbClient = client
|
route.nbClient = client
|
||||||
|
|
||||||
// Add route with domain as key
|
|
||||||
p.routes[route.Domain] = route
|
p.routes[route.Domain] = route
|
||||||
|
|
||||||
// Register domain with certificate manager
|
|
||||||
p.certManager.AddDomain(route.Domain)
|
p.certManager.AddDomain(route.Domain)
|
||||||
|
|
||||||
log.WithFields(log.Fields{
|
log.WithFields(log.Fields{
|
||||||
@@ -66,13 +63,13 @@ func (p *Proxy) AddRoute(route *RouteConfig) error {
|
|||||||
"paths": len(route.PathMappings),
|
"paths": len(route.PathMappings),
|
||||||
}).Info("Added route")
|
}).Info("Added route")
|
||||||
|
|
||||||
// Eagerly issue certificate in background
|
|
||||||
go func(domain string) {
|
go func(domain string) {
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
|
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
if err := p.certManager.IssueCertificate(ctx, domain); err != nil {
|
if err := p.certManager.IssueCertificate(ctx, domain); err != nil {
|
||||||
log.Errorf("Failed to issue certificate: %v", err)
|
log.Errorf("Failed to issue certificate: %v", err)
|
||||||
|
// TODO: Better error feedback mechanism
|
||||||
}
|
}
|
||||||
}(route.Domain)
|
}(route.Domain)
|
||||||
|
|
||||||
@@ -84,15 +81,12 @@ func (p *Proxy) RemoveRoute(domain string) error {
|
|||||||
p.mu.Lock()
|
p.mu.Lock()
|
||||||
defer p.mu.Unlock()
|
defer p.mu.Unlock()
|
||||||
|
|
||||||
// Check if route exists
|
|
||||||
if _, exists := p.routes[domain]; !exists {
|
if _, exists := p.routes[domain]; !exists {
|
||||||
return fmt.Errorf("route for domain %s not found", domain)
|
return fmt.Errorf("route for domain %s not found", domain)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Remove route
|
|
||||||
delete(p.routes, domain)
|
delete(p.routes, domain)
|
||||||
|
|
||||||
// Unregister domain from certificate manager
|
|
||||||
p.certManager.RemoveDomain(domain)
|
p.certManager.RemoveDomain(domain)
|
||||||
|
|
||||||
log.Infof("Removed route for domain: %s", domain)
|
log.Infof("Removed route for domain: %s", domain)
|
||||||
@@ -114,12 +108,10 @@ func (p *Proxy) UpdateRoute(route *RouteConfig) error {
|
|||||||
p.mu.Lock()
|
p.mu.Lock()
|
||||||
defer p.mu.Unlock()
|
defer p.mu.Unlock()
|
||||||
|
|
||||||
// Check if route exists for this domain
|
|
||||||
if _, exists := p.routes[route.Domain]; !exists {
|
if _, exists := p.routes[route.Domain]; !exists {
|
||||||
return fmt.Errorf("route for domain %s not found", route.Domain)
|
return fmt.Errorf("route for domain %s not found", route.Domain)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Update route using domain as key
|
|
||||||
p.routes[route.Domain] = route
|
p.routes[route.Domain] = route
|
||||||
|
|
||||||
log.WithFields(log.Fields{
|
log.WithFields(log.Fields{
|
||||||
|
|||||||
@@ -1,73 +0,0 @@
|
|||||||
package errors
|
|
||||||
|
|
||||||
import "fmt"
|
|
||||||
|
|
||||||
// Configuration errors
|
|
||||||
|
|
||||||
func NewConfigInvalid(message string) *AppError {
|
|
||||||
return New(CodeConfigInvalid, message)
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewConfigNotFound(path string) *AppError {
|
|
||||||
return New(CodeConfigNotFound, fmt.Sprintf("configuration file not found: %s", path))
|
|
||||||
}
|
|
||||||
|
|
||||||
func WrapConfigParseFailed(err error, path string) *AppError {
|
|
||||||
return Wrap(CodeConfigParseFailed, fmt.Sprintf("failed to parse configuration file: %s", path), err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Server errors
|
|
||||||
|
|
||||||
func NewServerStartFailed(err error, reason string) *AppError {
|
|
||||||
return Wrap(CodeServerStartFailed, fmt.Sprintf("server start failed: %s", reason), err)
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewServerStopFailed(err error) *AppError {
|
|
||||||
return Wrap(CodeServerStopFailed, "server shutdown failed", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewServerAlreadyRunning() *AppError {
|
|
||||||
return New(CodeServerAlreadyRunning, "server is already running")
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewServerNotRunning() *AppError {
|
|
||||||
return New(CodeServerNotRunning, "server is not running")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Proxy errors
|
|
||||||
|
|
||||||
func NewProxyBackendUnavailable(backend string, err error) *AppError {
|
|
||||||
return Wrap(CodeProxyBackendUnavailable, fmt.Sprintf("backend unavailable: %s", backend), err)
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewProxyTimeout(backend string) *AppError {
|
|
||||||
return New(CodeProxyTimeout, fmt.Sprintf("request to backend timed out: %s", backend))
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewProxyInvalidTarget(target string, err error) *AppError {
|
|
||||||
return Wrap(CodeProxyInvalidTarget, fmt.Sprintf("invalid proxy target: %s", target), err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Network errors
|
|
||||||
|
|
||||||
func NewNetworkTimeout(operation string) *AppError {
|
|
||||||
return New(CodeNetworkTimeout, fmt.Sprintf("network timeout: %s", operation))
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewNetworkUnreachable(host string) *AppError {
|
|
||||||
return New(CodeNetworkUnreachable, fmt.Sprintf("network unreachable: %s", host))
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewNetworkRefused(host string) *AppError {
|
|
||||||
return New(CodeNetworkRefused, fmt.Sprintf("connection refused: %s", host))
|
|
||||||
}
|
|
||||||
|
|
||||||
// Internal errors
|
|
||||||
|
|
||||||
func NewInternalError(message string) *AppError {
|
|
||||||
return New(CodeInternalError, message)
|
|
||||||
}
|
|
||||||
|
|
||||||
func WrapInternalError(err error, message string) *AppError {
|
|
||||||
return Wrap(CodeInternalError, message, err)
|
|
||||||
}
|
|
||||||
@@ -1,138 +0,0 @@
|
|||||||
package errors
|
|
||||||
|
|
||||||
import (
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Error codes for categorizing errors
|
|
||||||
type Code string
|
|
||||||
|
|
||||||
const (
|
|
||||||
// Configuration errors
|
|
||||||
CodeConfigInvalid Code = "CONFIG_INVALID"
|
|
||||||
CodeConfigNotFound Code = "CONFIG_NOT_FOUND"
|
|
||||||
CodeConfigParseFailed Code = "CONFIG_PARSE_FAILED"
|
|
||||||
|
|
||||||
// Server errors
|
|
||||||
CodeServerStartFailed Code = "SERVER_START_FAILED"
|
|
||||||
CodeServerStopFailed Code = "SERVER_STOP_FAILED"
|
|
||||||
CodeServerAlreadyRunning Code = "SERVER_ALREADY_RUNNING"
|
|
||||||
CodeServerNotRunning Code = "SERVER_NOT_RUNNING"
|
|
||||||
|
|
||||||
// Proxy errors
|
|
||||||
CodeProxyBackendUnavailable Code = "PROXY_BACKEND_UNAVAILABLE"
|
|
||||||
CodeProxyTimeout Code = "PROXY_TIMEOUT"
|
|
||||||
CodeProxyInvalidTarget Code = "PROXY_INVALID_TARGET"
|
|
||||||
|
|
||||||
// Network errors
|
|
||||||
CodeNetworkTimeout Code = "NETWORK_TIMEOUT"
|
|
||||||
CodeNetworkUnreachable Code = "NETWORK_UNREACHABLE"
|
|
||||||
CodeNetworkRefused Code = "NETWORK_REFUSED"
|
|
||||||
|
|
||||||
// Internal errors
|
|
||||||
CodeInternalError Code = "INTERNAL_ERROR"
|
|
||||||
CodeUnknownError Code = "UNKNOWN_ERROR"
|
|
||||||
)
|
|
||||||
|
|
||||||
// AppError represents a structured application error
|
|
||||||
type AppError struct {
|
|
||||||
Code Code // Error code for categorization
|
|
||||||
Message string // Human-readable error message
|
|
||||||
Cause error // Underlying error (if any)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Error implements the error interface
|
|
||||||
func (e *AppError) Error() string {
|
|
||||||
if e.Cause != nil {
|
|
||||||
return fmt.Sprintf("[%s] %s: %v", e.Code, e.Message, e.Cause)
|
|
||||||
}
|
|
||||||
return fmt.Sprintf("[%s] %s", e.Code, e.Message)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Unwrap returns the underlying error (for errors.Is and errors.As)
|
|
||||||
func (e *AppError) Unwrap() error {
|
|
||||||
return e.Cause
|
|
||||||
}
|
|
||||||
|
|
||||||
// Is checks if the error matches the target
|
|
||||||
func (e *AppError) Is(target error) bool {
|
|
||||||
t, ok := target.(*AppError)
|
|
||||||
if !ok {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
return e.Code == t.Code
|
|
||||||
}
|
|
||||||
|
|
||||||
// New creates a new AppError
|
|
||||||
func New(code Code, message string) *AppError {
|
|
||||||
return &AppError{
|
|
||||||
Code: code,
|
|
||||||
Message: message,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Wrap wraps an existing error with additional context
|
|
||||||
func Wrap(code Code, message string, cause error) *AppError {
|
|
||||||
return &AppError{
|
|
||||||
Code: code,
|
|
||||||
Message: message,
|
|
||||||
Cause: cause,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Wrapf wraps an error with a formatted message
|
|
||||||
func Wrapf(code Code, cause error, format string, args ...interface{}) *AppError {
|
|
||||||
return &AppError{
|
|
||||||
Code: code,
|
|
||||||
Message: fmt.Sprintf(format, args...),
|
|
||||||
Cause: cause,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetCode extracts the error code from an error
|
|
||||||
func GetCode(err error) Code {
|
|
||||||
var appErr *AppError
|
|
||||||
if errors.As(err, &appErr) {
|
|
||||||
return appErr.Code
|
|
||||||
}
|
|
||||||
return CodeUnknownError
|
|
||||||
}
|
|
||||||
|
|
||||||
// HasCode checks if an error has a specific code
|
|
||||||
func HasCode(err error, code Code) bool {
|
|
||||||
return GetCode(err) == code
|
|
||||||
}
|
|
||||||
|
|
||||||
// IsConfigError checks if an error is configuration-related
|
|
||||||
func IsConfigError(err error) bool {
|
|
||||||
code := GetCode(err)
|
|
||||||
return code == CodeConfigInvalid ||
|
|
||||||
code == CodeConfigNotFound ||
|
|
||||||
code == CodeConfigParseFailed
|
|
||||||
}
|
|
||||||
|
|
||||||
// IsServerError checks if an error is server-related
|
|
||||||
func IsServerError(err error) bool {
|
|
||||||
code := GetCode(err)
|
|
||||||
return code == CodeServerStartFailed ||
|
|
||||||
code == CodeServerStopFailed ||
|
|
||||||
code == CodeServerAlreadyRunning ||
|
|
||||||
code == CodeServerNotRunning
|
|
||||||
}
|
|
||||||
|
|
||||||
// IsProxyError checks if an error is proxy-related
|
|
||||||
func IsProxyError(err error) bool {
|
|
||||||
code := GetCode(err)
|
|
||||||
return code == CodeProxyBackendUnavailable ||
|
|
||||||
code == CodeProxyTimeout ||
|
|
||||||
code == CodeProxyInvalidTarget
|
|
||||||
}
|
|
||||||
|
|
||||||
// IsNetworkError checks if an error is network-related
|
|
||||||
func IsNetworkError(err error) bool {
|
|
||||||
code := GetCode(err)
|
|
||||||
return code == CodeNetworkTimeout ||
|
|
||||||
code == CodeNetworkUnreachable ||
|
|
||||||
code == CodeNetworkRefused
|
|
||||||
}
|
|
||||||
@@ -1,160 +0,0 @@
|
|||||||
package errors
|
|
||||||
|
|
||||||
import (
|
|
||||||
"errors"
|
|
||||||
"testing"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestAppError_Error(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
err *AppError
|
|
||||||
expected string
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "error without cause",
|
|
||||||
err: New(CodeConfigInvalid, "invalid configuration"),
|
|
||||||
expected: "[CONFIG_INVALID] invalid configuration",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "error with cause",
|
|
||||||
err: Wrap(CodeServerStartFailed, "failed to bind port", errors.New("address already in use")),
|
|
||||||
expected: "[SERVER_START_FAILED] failed to bind port: address already in use",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
if got := tt.err.Error(); got != tt.expected {
|
|
||||||
t.Errorf("Error() = %v, want %v", got, tt.expected)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestGetCode(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
err error
|
|
||||||
expected Code
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "app error",
|
|
||||||
err: New(CodeConfigInvalid, "test"),
|
|
||||||
expected: CodeConfigInvalid,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "wrapped app error",
|
|
||||||
err: Wrap(CodeServerStartFailed, "test", errors.New("cause")),
|
|
||||||
expected: CodeServerStartFailed,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "standard error",
|
|
||||||
err: errors.New("standard error"),
|
|
||||||
expected: CodeUnknownError,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
if got := GetCode(tt.err); got != tt.expected {
|
|
||||||
t.Errorf("GetCode() = %v, want %v", got, tt.expected)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestHasCode(t *testing.T) {
|
|
||||||
err := New(CodeConfigInvalid, "invalid config")
|
|
||||||
|
|
||||||
if !HasCode(err, CodeConfigInvalid) {
|
|
||||||
t.Error("HasCode() should return true for matching code")
|
|
||||||
}
|
|
||||||
|
|
||||||
if HasCode(err, CodeServerStartFailed) {
|
|
||||||
t.Error("HasCode() should return false for non-matching code")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestIsConfigError(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
err error
|
|
||||||
expected bool
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "config invalid error",
|
|
||||||
err: New(CodeConfigInvalid, "test"),
|
|
||||||
expected: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "config not found error",
|
|
||||||
err: New(CodeConfigNotFound, "test"),
|
|
||||||
expected: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "server error",
|
|
||||||
err: New(CodeServerStartFailed, "test"),
|
|
||||||
expected: false,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
if got := IsConfigError(tt.err); got != tt.expected {
|
|
||||||
t.Errorf("IsConfigError() = %v, want %v", got, tt.expected)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestErrorUnwrap(t *testing.T) {
|
|
||||||
cause := errors.New("root cause")
|
|
||||||
err := Wrap(CodeInternalError, "wrapped error", cause)
|
|
||||||
|
|
||||||
unwrapped := errors.Unwrap(err)
|
|
||||||
if unwrapped != cause {
|
|
||||||
t.Errorf("Unwrap() = %v, want %v", unwrapped, cause)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestErrorIs(t *testing.T) {
|
|
||||||
err1 := New(CodeConfigInvalid, "test1")
|
|
||||||
err2 := New(CodeConfigInvalid, "test2")
|
|
||||||
err3 := New(CodeServerStartFailed, "test3")
|
|
||||||
|
|
||||||
if !errors.Is(err1, err2) {
|
|
||||||
t.Error("errors.Is() should return true for same error code")
|
|
||||||
}
|
|
||||||
|
|
||||||
if errors.Is(err1, err3) {
|
|
||||||
t.Error("errors.Is() should return false for different error codes")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestCommonConstructors(t *testing.T) {
|
|
||||||
t.Run("NewConfigNotFound", func(t *testing.T) {
|
|
||||||
err := NewConfigNotFound("/path/to/config")
|
|
||||||
if GetCode(err) != CodeConfigNotFound {
|
|
||||||
t.Error("NewConfigNotFound should create CONFIG_NOT_FOUND error")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("NewServerAlreadyRunning", func(t *testing.T) {
|
|
||||||
err := NewServerAlreadyRunning()
|
|
||||||
if GetCode(err) != CodeServerAlreadyRunning {
|
|
||||||
t.Error("NewServerAlreadyRunning should create SERVER_ALREADY_RUNNING error")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("NewProxyBackendUnavailable", func(t *testing.T) {
|
|
||||||
cause := errors.New("connection refused")
|
|
||||||
err := NewProxyBackendUnavailable("http://backend", cause)
|
|
||||||
if GetCode(err) != CodeProxyBackendUnavailable {
|
|
||||||
t.Error("NewProxyBackendUnavailable should create PROXY_BACKEND_UNAVAILABLE error")
|
|
||||||
}
|
|
||||||
if !errors.Is(err.Unwrap(), cause) {
|
|
||||||
t.Error("NewProxyBackendUnavailable should wrap the cause")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
@@ -77,7 +77,6 @@ func (s *Server) Start() error {
|
|||||||
return fmt.Errorf("failed to listen: %w", err)
|
return fmt.Errorf("failed to listen: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Configure gRPC server with keepalive
|
|
||||||
s.grpcServer = grpc.NewServer(
|
s.grpcServer = grpc.NewServer(
|
||||||
grpc.KeepaliveParams(keepalive.ServerParameters{
|
grpc.KeepaliveParams(keepalive.ServerParameters{
|
||||||
Time: 30 * time.Second,
|
Time: 30 * time.Second,
|
||||||
@@ -114,7 +113,6 @@ func (s *Server) Stop(ctx context.Context) error {
|
|||||||
|
|
||||||
log.Info("Stopping gRPC server...")
|
log.Info("Stopping gRPC server...")
|
||||||
|
|
||||||
// Cancel all active streams
|
|
||||||
s.mu.Lock()
|
s.mu.Lock()
|
||||||
for _, streamCtx := range s.streams {
|
for _, streamCtx := range s.streams {
|
||||||
streamCtx.cancel()
|
streamCtx.cancel()
|
||||||
@@ -123,7 +121,6 @@ func (s *Server) Stop(ctx context.Context) error {
|
|||||||
s.streams = make(map[string]*StreamContext)
|
s.streams = make(map[string]*StreamContext)
|
||||||
s.mu.Unlock()
|
s.mu.Unlock()
|
||||||
|
|
||||||
// Graceful stop with timeout
|
|
||||||
stopped := make(chan struct{})
|
stopped := make(chan struct{})
|
||||||
go func() {
|
go func() {
|
||||||
s.grpcServer.GracefulStop()
|
s.grpcServer.GracefulStop()
|
||||||
@@ -154,7 +151,6 @@ func (s *Server) Stream(stream pb.ProxyService_StreamServer) error {
|
|||||||
|
|
||||||
controlID := fmt.Sprintf("control-%d", time.Now().Unix())
|
controlID := fmt.Sprintf("control-%d", time.Now().Unix())
|
||||||
|
|
||||||
// Create stream context
|
|
||||||
streamCtx := &StreamContext{
|
streamCtx := &StreamContext{
|
||||||
stream: stream,
|
stream: stream,
|
||||||
sendChan: make(chan *pb.ProxyMessage, 100),
|
sendChan: make(chan *pb.ProxyMessage, 100),
|
||||||
@@ -163,22 +159,18 @@ func (s *Server) Stream(stream pb.ProxyService_StreamServer) error {
|
|||||||
controlID: controlID,
|
controlID: controlID,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Register stream
|
|
||||||
s.mu.Lock()
|
s.mu.Lock()
|
||||||
s.streams[controlID] = streamCtx
|
s.streams[controlID] = streamCtx
|
||||||
s.mu.Unlock()
|
s.mu.Unlock()
|
||||||
|
|
||||||
log.Infof("Control service connected: %s", controlID)
|
log.Infof("Control service connected: %s", controlID)
|
||||||
|
|
||||||
// Start goroutine to send ProxyMessages to control service
|
|
||||||
sendDone := make(chan error, 1)
|
sendDone := make(chan error, 1)
|
||||||
go s.sendLoop(streamCtx, sendDone)
|
go s.sendLoop(streamCtx, sendDone)
|
||||||
|
|
||||||
// Start goroutine to receive ControlMessages from control service
|
|
||||||
recvDone := make(chan error, 1)
|
recvDone := make(chan error, 1)
|
||||||
go s.receiveLoop(streamCtx, recvDone)
|
go s.receiveLoop(streamCtx, recvDone)
|
||||||
|
|
||||||
// Wait for either send or receive to complete
|
|
||||||
select {
|
select {
|
||||||
case err := <-sendDone:
|
case err := <-sendDone:
|
||||||
log.Infof("Control service %s send loop ended: %v", controlID, err)
|
log.Infof("Control service %s send loop ended: %v", controlID, err)
|
||||||
@@ -202,7 +194,6 @@ func (s *Server) sendLoop(streamCtx *StreamContext, done chan<- error) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Send ProxyMessage to control service
|
|
||||||
if err := streamCtx.stream.Send(msg); err != nil {
|
if err := streamCtx.stream.Send(msg); err != nil {
|
||||||
log.Errorf("Failed to send message to control service: %v", err)
|
log.Errorf("Failed to send message to control service: %v", err)
|
||||||
done <- err
|
done <- err
|
||||||
@@ -219,7 +210,6 @@ func (s *Server) sendLoop(streamCtx *StreamContext, done chan<- error) {
|
|||||||
// receiveLoop handles receiving ControlMessages from the control service
|
// receiveLoop handles receiving ControlMessages from the control service
|
||||||
func (s *Server) receiveLoop(streamCtx *StreamContext, done chan<- error) {
|
func (s *Server) receiveLoop(streamCtx *StreamContext, done chan<- error) {
|
||||||
for {
|
for {
|
||||||
// Receive ControlMessage from control service (client)
|
|
||||||
controlMsg, err := streamCtx.stream.Recv()
|
controlMsg, err := streamCtx.stream.Recv()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Debugf("Stream receive error: %v", err)
|
log.Debugf("Stream receive error: %v", err)
|
||||||
@@ -227,7 +217,6 @@ func (s *Server) receiveLoop(streamCtx *StreamContext, done chan<- error) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Handle different ControlMessage types
|
|
||||||
switch m := controlMsg.Message.(type) {
|
switch m := controlMsg.Message.(type) {
|
||||||
case *pb.ControlMessage_Event:
|
case *pb.ControlMessage_Event:
|
||||||
if s.handler != nil {
|
if s.handler != nil {
|
||||||
@@ -271,7 +260,6 @@ func (s *Server) SendProxyMessage(msg *pb.ProxyMessage) {
|
|||||||
for _, streamCtx := range s.streams {
|
for _, streamCtx := range s.streams {
|
||||||
select {
|
select {
|
||||||
case streamCtx.sendChan <- msg:
|
case streamCtx.sendChan <- msg:
|
||||||
// Message queued successfully
|
|
||||||
default:
|
default:
|
||||||
log.Warn("Send channel full, dropping message")
|
log.Warn("Send channel full, dropping message")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -63,7 +63,7 @@ type Config struct {
|
|||||||
// LogLevel sets the logging verbosity (debug, info, warn, error)
|
// LogLevel sets the logging verbosity (debug, info, warn, error)
|
||||||
LogLevel string `env:"NB_PROXY_LOG_LEVEL" envDefault:"info" json:"log_level"`
|
LogLevel string `env:"NB_PROXY_LOG_LEVEL" envDefault:"info" json:"log_level"`
|
||||||
|
|
||||||
// GRPCListenAddress is the address for the gRPC control server (empty to disable)
|
// GRPCListenAddress is the address for the gRPC control server
|
||||||
GRPCListenAddress string `env:"NB_PROXY_GRPC_LISTEN_ADDRESS" envDefault:":50051" json:"grpc_listen_address"`
|
GRPCListenAddress string `env:"NB_PROXY_GRPC_LISTEN_ADDRESS" envDefault:":50051" json:"grpc_listen_address"`
|
||||||
|
|
||||||
// ProxyID is a unique identifier for this proxy instance
|
// ProxyID is a unique identifier for this proxy instance
|
||||||
@@ -111,7 +111,6 @@ func LoadFromFile(path string) (Config, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// LoadFromFileOrEnv loads configuration from a file if path is provided, otherwise from environment variables
|
// LoadFromFileOrEnv loads configuration from a file if path is provided, otherwise from environment variables
|
||||||
// Environment variables will override file-based configuration if both are present
|
|
||||||
func LoadFromFileOrEnv(configPath string) (Config, error) {
|
func LoadFromFileOrEnv(configPath string) (Config, error) {
|
||||||
var cfg Config
|
var cfg Config
|
||||||
|
|
||||||
@@ -123,7 +122,6 @@ func LoadFromFileOrEnv(configPath string) (Config, error) {
|
|||||||
}
|
}
|
||||||
cfg = fileCfg
|
cfg = fileCfg
|
||||||
} else {
|
} else {
|
||||||
// Parse environment variables (will override file config with any set env vars)
|
|
||||||
if err := env.Parse(&cfg); err != nil {
|
if err := env.Parse(&cfg); err != nil {
|
||||||
return Config{}, fmt.Errorf("%w: %s", ErrFailedToParseConfig, err)
|
return Config{}, fmt.Errorf("%w: %s", ErrFailedToParseConfig, err)
|
||||||
}
|
}
|
||||||
@@ -137,30 +135,24 @@ func LoadFromFileOrEnv(configPath string) (Config, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// UnmarshalJSON implements custom JSON unmarshaling with automatic duration parsing
|
// UnmarshalJSON implements custom JSON unmarshaling with automatic duration parsing
|
||||||
// Uses reflection to find all time.Duration fields and parse them from string
|
|
||||||
func (c *Config) UnmarshalJSON(data []byte) error {
|
func (c *Config) UnmarshalJSON(data []byte) error {
|
||||||
// First unmarshal into a map to get raw values
|
|
||||||
var raw map[string]interface{}
|
var raw map[string]interface{}
|
||||||
if err := json.Unmarshal(data, &raw); err != nil {
|
if err := json.Unmarshal(data, &raw); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get reflection value and type
|
|
||||||
val := reflect.ValueOf(c).Elem()
|
val := reflect.ValueOf(c).Elem()
|
||||||
typ := val.Type()
|
typ := val.Type()
|
||||||
|
|
||||||
// Iterate through all fields
|
|
||||||
for i := 0; i < val.NumField(); i++ {
|
for i := 0; i < val.NumField(); i++ {
|
||||||
field := val.Field(i)
|
field := val.Field(i)
|
||||||
fieldType := typ.Field(i)
|
fieldType := typ.Field(i)
|
||||||
|
|
||||||
// Get JSON tag name
|
|
||||||
jsonTag := fieldType.Tag.Get("json")
|
jsonTag := fieldType.Tag.Get("json")
|
||||||
if jsonTag == "" || jsonTag == "-" {
|
if jsonTag == "" || jsonTag == "-" {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
// Parse tag to get field name (handle omitempty, etc.)
|
|
||||||
jsonFieldName := jsonTag
|
jsonFieldName := jsonTag
|
||||||
if idx := len(jsonTag); idx > 0 {
|
if idx := len(jsonTag); idx > 0 {
|
||||||
for j, c := range jsonTag {
|
for j, c := range jsonTag {
|
||||||
@@ -171,15 +163,12 @@ func (c *Config) UnmarshalJSON(data []byte) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get raw value from JSON
|
|
||||||
rawValue, exists := raw[jsonFieldName]
|
rawValue, exists := raw[jsonFieldName]
|
||||||
if !exists {
|
if !exists {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check if this field is a time.Duration
|
|
||||||
if field.Type() == reflect.TypeOf(time.Duration(0)) {
|
if field.Type() == reflect.TypeOf(time.Duration(0)) {
|
||||||
// Try to parse as string duration
|
|
||||||
if strValue, ok := rawValue.(string); ok {
|
if strValue, ok := rawValue.(string); ok {
|
||||||
duration, err := time.ParseDuration(strValue)
|
duration, err := time.ParseDuration(strValue)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -190,13 +179,11 @@ func (c *Config) UnmarshalJSON(data []byte) error {
|
|||||||
return fmt.Errorf("field %s must be a duration string", jsonFieldName)
|
return fmt.Errorf("field %s must be a duration string", jsonFieldName)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// For non-duration fields, unmarshal normally
|
|
||||||
fieldData, err := json.Marshal(rawValue)
|
fieldData, err := json.Marshal(rawValue)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to marshal field %s: %w", jsonFieldName, err)
|
return fmt.Errorf("failed to marshal field %s: %w", jsonFieldName, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create a new instance of the field type
|
|
||||||
if field.CanSet() {
|
if field.CanSet() {
|
||||||
newVal := reflect.New(field.Type())
|
newVal := reflect.New(field.Type())
|
||||||
if err := json.Unmarshal(fieldData, newVal.Interface()); err != nil {
|
if err := json.Unmarshal(fieldData, newVal.Interface()); err != nil {
|
||||||
|
|||||||
@@ -95,13 +95,16 @@ func NewServer(config Config) (*Server, error) {
|
|||||||
// Set request data callback
|
// Set request data callback
|
||||||
proxy.SetRequestCallback(func(data reverseproxy.RequestData) {
|
proxy.SetRequestCallback(func(data reverseproxy.RequestData) {
|
||||||
log.WithFields(log.Fields{
|
log.WithFields(log.Fields{
|
||||||
"service_id": data.ServiceID,
|
"service_id": data.ServiceID,
|
||||||
"host": data.Host,
|
"host": data.Host,
|
||||||
"method": data.Method,
|
"method": data.Method,
|
||||||
"path": data.Path,
|
"path": data.Path,
|
||||||
"response_code": data.ResponseCode,
|
"response_code": data.ResponseCode,
|
||||||
"duration_ms": data.DurationMs,
|
"duration_ms": data.DurationMs,
|
||||||
"source_ip": data.SourceIP,
|
"source_ip": data.SourceIP,
|
||||||
|
"auth_mechanism": data.AuthMechanism,
|
||||||
|
"user_id": data.UserID,
|
||||||
|
"auth_success": data.AuthSuccess,
|
||||||
}).Info("Access log received")
|
}).Info("Access log received")
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -176,9 +179,9 @@ func (s *Server) Start() error {
|
|||||||
&reverseproxy.RouteConfig{
|
&reverseproxy.RouteConfig{
|
||||||
ID: "test",
|
ID: "test",
|
||||||
Domain: "test.netbird.io",
|
Domain: "test.netbird.io",
|
||||||
PathMappings: map[string]string{"/": "localhost:8181"},
|
PathMappings: map[string]string{"/": "100.116.118.156:8181"},
|
||||||
AuthConfig: testAuthConfig,
|
AuthConfig: testAuthConfig,
|
||||||
SetupKey: "setup-key",
|
SetupKey: "88B2382A-93D2-47A9-A80F-D0055D741636",
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
log.Warn("Failed to add test route: ", err)
|
log.Warn("Failed to add test route: ", err)
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user