mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-16 15:26:40 +00:00
Merge branch 'fix/log-formatter' into fix/http-redirect
This commit is contained in:
@@ -39,10 +39,10 @@ var (
|
||||
addr string
|
||||
proxyDomain string
|
||||
certDir string
|
||||
acmeCerts bool
|
||||
acmeAddr string
|
||||
acmeDir string
|
||||
acmeChallengeType string
|
||||
acmeCerts bool
|
||||
acmeAddr string
|
||||
acmeDir string
|
||||
acmeChallengeType string
|
||||
debugEndpoint bool
|
||||
debugEndpointAddr string
|
||||
healthAddr string
|
||||
@@ -123,7 +123,7 @@ func runServer(cmd *cobra.Command, args []string) error {
|
||||
|
||||
_ = util.InitLogger(logger, level, util.LogConsole)
|
||||
|
||||
log.Infof("configured log level: %s", level)
|
||||
logger.Infof("configured log level: %s", level)
|
||||
|
||||
switch forwardedProto {
|
||||
case "auto", "http", "https":
|
||||
@@ -171,7 +171,7 @@ func runServer(cmd *cobra.Command, args []string) error {
|
||||
defer stop()
|
||||
|
||||
if err := srv.ListenAndServe(ctx, addr); err != nil {
|
||||
log.Error(err)
|
||||
logger.Error(err)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
|
||||
@@ -84,7 +84,7 @@ func (l *flockLocker) Lock(ctx context.Context, domain string) (func(), error) {
|
||||
|
||||
// nil lockFile means locking is not supported (non-unix).
|
||||
if lockFile == nil {
|
||||
return func() {}, nil
|
||||
return func() { /* no-op: locking unsupported on this platform */ }, nil
|
||||
}
|
||||
|
||||
return func() {
|
||||
@@ -98,5 +98,5 @@ type noopLocker struct{}
|
||||
|
||||
// Lock is a no-op that always succeeds immediately.
|
||||
func (noopLocker) Lock(context.Context, string) (func(), error) {
|
||||
return func() {}, nil
|
||||
return func() { /* no-op: locker disabled */ }, nil
|
||||
}
|
||||
|
||||
@@ -90,10 +90,8 @@ func (mw *Middleware) Protect(next http.Handler) http.Handler {
|
||||
if err != nil {
|
||||
host = r.Host
|
||||
}
|
||||
mw.domainsMux.RLock()
|
||||
config, exists := mw.domains[host]
|
||||
mw.domainsMux.RUnlock()
|
||||
|
||||
config, exists := mw.getDomainConfig(host)
|
||||
mw.logger.Debugf("checking authentication for host: %s, exists: %t", host, exists)
|
||||
|
||||
// Domains that are not configured here or have no authentication schemes applied should simply pass through.
|
||||
@@ -103,115 +101,160 @@ func (mw *Middleware) Protect(next http.Handler) http.Handler {
|
||||
}
|
||||
|
||||
// Set account and service IDs in captured data for access logging.
|
||||
if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil {
|
||||
cd.SetAccountId(types.AccountID(config.AccountID))
|
||||
cd.SetServiceId(config.ServiceID)
|
||||
}
|
||||
setCapturedIDs(r, config)
|
||||
|
||||
// Check for error from OAuth callback (e.g., access denied)
|
||||
if errCode := r.URL.Query().Get("error"); errCode != "" {
|
||||
var requestID string
|
||||
if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil {
|
||||
cd.SetOrigin(proxy.OriginAuth)
|
||||
cd.SetAuthMethod(auth.MethodOIDC.String())
|
||||
requestID = cd.GetRequestID()
|
||||
}
|
||||
errDesc := r.URL.Query().Get("error_description")
|
||||
if errDesc == "" {
|
||||
errDesc = "An error occurred during authentication"
|
||||
}
|
||||
web.ServeAccessDeniedPage(w, r, http.StatusForbidden, "Access Denied", errDesc, requestID)
|
||||
if mw.handleOAuthCallbackError(w, r) {
|
||||
return
|
||||
}
|
||||
|
||||
// Check for an existing session cookie (contains JWT)
|
||||
if cookie, err := r.Cookie(auth.SessionCookieName); err == nil {
|
||||
if userID, method, err := auth.ValidateSessionJWT(cookie.Value, host, config.SessionPublicKey); err == nil {
|
||||
if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil {
|
||||
cd.SetUserID(userID)
|
||||
cd.SetAuthMethod(method)
|
||||
}
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
if mw.forwardWithSessionCookie(w, r, host, config, next) {
|
||||
return
|
||||
}
|
||||
|
||||
// Try to authenticate with each scheme.
|
||||
methods := make(map[string]string)
|
||||
var attemptedMethod string
|
||||
for _, scheme := range config.Schemes {
|
||||
token, promptData, err := scheme.Authenticate(r)
|
||||
if err != nil {
|
||||
mw.logger.WithField("scheme", scheme.Type().String()).Warnf("authentication infrastructure error: %v", err)
|
||||
if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil {
|
||||
cd.SetOrigin(proxy.OriginAuth)
|
||||
}
|
||||
http.Error(w, "authentication service unavailable", http.StatusBadGateway)
|
||||
return
|
||||
mw.authenticateWithSchemes(w, r, host, config)
|
||||
})
|
||||
}
|
||||
|
||||
func (mw *Middleware) getDomainConfig(host string) (DomainConfig, bool) {
|
||||
mw.domainsMux.RLock()
|
||||
defer mw.domainsMux.RUnlock()
|
||||
config, exists := mw.domains[host]
|
||||
return config, exists
|
||||
}
|
||||
|
||||
func setCapturedIDs(r *http.Request, config DomainConfig) {
|
||||
if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil {
|
||||
cd.SetAccountId(types.AccountID(config.AccountID))
|
||||
cd.SetServiceId(config.ServiceID)
|
||||
}
|
||||
}
|
||||
|
||||
// handleOAuthCallbackError checks for error query parameters from an OAuth
|
||||
// callback and renders the access denied page if present.
|
||||
func (mw *Middleware) handleOAuthCallbackError(w http.ResponseWriter, r *http.Request) bool {
|
||||
errCode := r.URL.Query().Get("error")
|
||||
if errCode == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
var requestID string
|
||||
if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil {
|
||||
cd.SetOrigin(proxy.OriginAuth)
|
||||
cd.SetAuthMethod(auth.MethodOIDC.String())
|
||||
requestID = cd.GetRequestID()
|
||||
}
|
||||
errDesc := r.URL.Query().Get("error_description")
|
||||
if errDesc == "" {
|
||||
errDesc = "An error occurred during authentication"
|
||||
}
|
||||
web.ServeAccessDeniedPage(w, r, http.StatusForbidden, "Access Denied", errDesc, requestID)
|
||||
return true
|
||||
}
|
||||
|
||||
// forwardWithSessionCookie checks for a valid session cookie and, if found,
|
||||
// sets the user identity on the request context and forwards to the next handler.
|
||||
func (mw *Middleware) forwardWithSessionCookie(w http.ResponseWriter, r *http.Request, host string, config DomainConfig, next http.Handler) bool {
|
||||
cookie, err := r.Cookie(auth.SessionCookieName)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
userID, method, err := auth.ValidateSessionJWT(cookie.Value, host, config.SessionPublicKey)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil {
|
||||
cd.SetUserID(userID)
|
||||
cd.SetAuthMethod(method)
|
||||
}
|
||||
next.ServeHTTP(w, r)
|
||||
return true
|
||||
}
|
||||
|
||||
// authenticateWithSchemes tries each configured auth scheme in order.
|
||||
// On success it sets a session cookie and redirects; on failure it renders the login page.
|
||||
func (mw *Middleware) authenticateWithSchemes(w http.ResponseWriter, r *http.Request, host string, config DomainConfig) {
|
||||
methods := make(map[string]string)
|
||||
var attemptedMethod string
|
||||
|
||||
for _, scheme := range config.Schemes {
|
||||
token, promptData, err := scheme.Authenticate(r)
|
||||
if err != nil {
|
||||
mw.logger.WithField("scheme", scheme.Type().String()).Warnf("authentication infrastructure error: %v", err)
|
||||
if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil {
|
||||
cd.SetOrigin(proxy.OriginAuth)
|
||||
}
|
||||
|
||||
// Track if credentials were submitted but auth failed
|
||||
if token == "" && wasCredentialSubmitted(r, scheme.Type()) {
|
||||
attemptedMethod = scheme.Type().String()
|
||||
}
|
||||
|
||||
if token != "" {
|
||||
result, err := mw.validateSessionToken(r.Context(), host, token, config.SessionPublicKey, scheme.Type())
|
||||
if err != nil {
|
||||
if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil {
|
||||
cd.SetOrigin(proxy.OriginAuth)
|
||||
cd.SetAuthMethod(scheme.Type().String())
|
||||
}
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if !result.Valid {
|
||||
var requestID string
|
||||
if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil {
|
||||
cd.SetOrigin(proxy.OriginAuth)
|
||||
cd.SetUserID(result.UserID)
|
||||
cd.SetAuthMethod(scheme.Type().String())
|
||||
requestID = cd.GetRequestID()
|
||||
}
|
||||
web.ServeAccessDeniedPage(w, r, http.StatusForbidden, "Access Denied", "You are not authorized to access this service", requestID)
|
||||
return
|
||||
}
|
||||
|
||||
expiration := config.SessionExpiration
|
||||
if expiration == 0 {
|
||||
expiration = auth.DefaultSessionExpiry
|
||||
}
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: auth.SessionCookieName,
|
||||
Value: token,
|
||||
HttpOnly: true,
|
||||
Secure: true,
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
MaxAge: int(expiration.Seconds()),
|
||||
})
|
||||
|
||||
// Redirect instead of forwarding the auth POST to the backend.
|
||||
// The browser will follow with a GET carrying the new session cookie.
|
||||
if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil {
|
||||
cd.SetOrigin(proxy.OriginAuth)
|
||||
cd.SetUserID(result.UserID)
|
||||
cd.SetAuthMethod(scheme.Type().String())
|
||||
}
|
||||
redirectURL := stripSessionTokenParam(r.URL)
|
||||
http.Redirect(w, r, redirectURL, http.StatusSeeOther)
|
||||
return
|
||||
}
|
||||
methods[scheme.Type().String()] = promptData
|
||||
http.Error(w, "authentication service unavailable", http.StatusBadGateway)
|
||||
return
|
||||
}
|
||||
|
||||
// Track if credentials were submitted but auth failed
|
||||
if token == "" && wasCredentialSubmitted(r, scheme.Type()) {
|
||||
attemptedMethod = scheme.Type().String()
|
||||
}
|
||||
|
||||
if token != "" {
|
||||
mw.handleAuthenticatedToken(w, r, host, token, config, scheme)
|
||||
return
|
||||
}
|
||||
methods[scheme.Type().String()] = promptData
|
||||
}
|
||||
|
||||
if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil {
|
||||
cd.SetOrigin(proxy.OriginAuth)
|
||||
if attemptedMethod != "" {
|
||||
cd.SetAuthMethod(attemptedMethod)
|
||||
}
|
||||
}
|
||||
web.ServeHTTP(w, r, map[string]any{"methods": methods}, http.StatusUnauthorized)
|
||||
}
|
||||
|
||||
// handleAuthenticatedToken validates the token, handles denied access, and on
|
||||
// success sets a session cookie and redirects to the original URL.
|
||||
func (mw *Middleware) handleAuthenticatedToken(w http.ResponseWriter, r *http.Request, host, token string, config DomainConfig, scheme Scheme) {
|
||||
result, err := mw.validateSessionToken(r.Context(), host, token, config.SessionPublicKey, scheme.Type())
|
||||
if err != nil {
|
||||
if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil {
|
||||
cd.SetOrigin(proxy.OriginAuth)
|
||||
if attemptedMethod != "" {
|
||||
cd.SetAuthMethod(attemptedMethod)
|
||||
}
|
||||
cd.SetAuthMethod(scheme.Type().String())
|
||||
}
|
||||
web.ServeHTTP(w, r, map[string]any{"methods": methods}, http.StatusUnauthorized)
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
if !result.Valid {
|
||||
var requestID string
|
||||
if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil {
|
||||
cd.SetOrigin(proxy.OriginAuth)
|
||||
cd.SetUserID(result.UserID)
|
||||
cd.SetAuthMethod(scheme.Type().String())
|
||||
requestID = cd.GetRequestID()
|
||||
}
|
||||
web.ServeAccessDeniedPage(w, r, http.StatusForbidden, "Access Denied", "You are not authorized to access this service", requestID)
|
||||
return
|
||||
}
|
||||
|
||||
expiration := config.SessionExpiration
|
||||
if expiration == 0 {
|
||||
expiration = auth.DefaultSessionExpiry
|
||||
}
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: auth.SessionCookieName,
|
||||
Value: token,
|
||||
HttpOnly: true,
|
||||
Secure: true,
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
MaxAge: int(expiration.Seconds()),
|
||||
})
|
||||
|
||||
// Redirect instead of forwarding the auth POST to the backend.
|
||||
// The browser will follow with a GET carrying the new session cookie.
|
||||
if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil {
|
||||
cd.SetOrigin(proxy.OriginAuth)
|
||||
cd.SetUserID(result.UserID)
|
||||
cd.SetAuthMethod(scheme.Type().String())
|
||||
}
|
||||
redirectURL := stripSessionTokenParam(r.URL)
|
||||
http.Redirect(w, r, redirectURL, http.StatusSeeOther)
|
||||
}
|
||||
|
||||
// wasCredentialSubmitted checks if credentials were submitted for the given auth method.
|
||||
|
||||
@@ -83,6 +83,10 @@ func (c *Client) printHealth(data map[string]any) {
|
||||
}
|
||||
}
|
||||
|
||||
c.printHealthClients(data)
|
||||
}
|
||||
|
||||
func (c *Client) printHealthClients(data map[string]any) {
|
||||
clients, ok := data["clients"].(map[string]any)
|
||||
if !ok || len(clients) == 0 {
|
||||
return
|
||||
|
||||
71
proxy/internal/debug/client_test.go
Normal file
71
proxy/internal/debug/client_test.go
Normal file
@@ -0,0 +1,71 @@
|
||||
package debug
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestPrintHealth_WithCertsAndClients(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
c := NewClient("localhost:8444", false, &buf)
|
||||
|
||||
data := map[string]any{
|
||||
"status": "ok",
|
||||
"uptime": "1h30m",
|
||||
"management_connected": true,
|
||||
"all_clients_healthy": true,
|
||||
"certs_total": float64(3),
|
||||
"certs_ready": float64(2),
|
||||
"certs_pending": float64(1),
|
||||
"certs_failed": float64(0),
|
||||
"certs_ready_domains": []any{"a.example.com", "b.example.com"},
|
||||
"certs_pending_domains": []any{"c.example.com"},
|
||||
"clients": map[string]any{
|
||||
"acc-1": map[string]any{
|
||||
"healthy": true,
|
||||
"management_connected": true,
|
||||
"signal_connected": true,
|
||||
"relays_connected": float64(1),
|
||||
"relays_total": float64(2),
|
||||
"peers_connected": float64(3),
|
||||
"peers_total": float64(5),
|
||||
"peers_p2p": float64(2),
|
||||
"peers_relayed": float64(1),
|
||||
"peers_degraded": float64(0),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
c.printHealth(data)
|
||||
out := buf.String()
|
||||
|
||||
assert.Contains(t, out, "Status: ok")
|
||||
assert.Contains(t, out, "Uptime: 1h30m")
|
||||
assert.Contains(t, out, "yes") // management_connected
|
||||
assert.Contains(t, out, "2 ready, 1 pending, 0 failed (3 total)")
|
||||
assert.Contains(t, out, "a.example.com")
|
||||
assert.Contains(t, out, "c.example.com")
|
||||
assert.Contains(t, out, "acc-1")
|
||||
}
|
||||
|
||||
func TestPrintHealth_Minimal(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
c := NewClient("localhost:8444", false, &buf)
|
||||
|
||||
data := map[string]any{
|
||||
"status": "ok",
|
||||
"uptime": "5m",
|
||||
"management_connected": false,
|
||||
"all_clients_healthy": false,
|
||||
}
|
||||
|
||||
c.printHealth(data)
|
||||
out := buf.String()
|
||||
|
||||
assert.Contains(t, out, "Status: ok")
|
||||
assert.Contains(t, out, "Uptime: 5m")
|
||||
assert.NotContains(t, out, "Certificates")
|
||||
assert.NotContains(t, out, "ACCOUNT ID")
|
||||
}
|
||||
@@ -323,7 +323,7 @@ func NewServer(addr string, checker *Checker, logger *log.Logger, metricsHandler
|
||||
if metricsHandler != nil {
|
||||
mux := http.NewServeMux()
|
||||
mux.Handle("/metrics", metricsHandler)
|
||||
mux.Handle("/", checker.Handler())
|
||||
mux.Handle("/", handler)
|
||||
handler = mux
|
||||
}
|
||||
|
||||
|
||||
@@ -404,3 +404,70 @@ func TestChecker_Handler_Full(t *testing.T) {
|
||||
// Clients may be empty map when no clients exist.
|
||||
assert.Empty(t, resp.Clients)
|
||||
}
|
||||
|
||||
func TestChecker_SetShuttingDown(t *testing.T) {
|
||||
checker := NewChecker(nil, &mockClientProvider{})
|
||||
checker.SetManagementConnected(true)
|
||||
|
||||
assert.True(t, checker.ReadinessProbe(), "should be ready before shutdown")
|
||||
|
||||
checker.SetShuttingDown()
|
||||
|
||||
assert.False(t, checker.ReadinessProbe(), "should not be ready after shutdown")
|
||||
}
|
||||
|
||||
func TestChecker_Handler_Readiness_ShuttingDown(t *testing.T) {
|
||||
checker := NewChecker(nil, &mockClientProvider{})
|
||||
checker.SetManagementConnected(true)
|
||||
checker.SetShuttingDown()
|
||||
handler := checker.Handler()
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/healthz/ready", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusServiceUnavailable, rec.Code)
|
||||
|
||||
var resp ProbeResponse
|
||||
require.NoError(t, json.NewDecoder(rec.Body).Decode(&resp))
|
||||
assert.Equal(t, "fail", resp.Status)
|
||||
}
|
||||
|
||||
func TestNewServer_WithMetricsHandler(t *testing.T) {
|
||||
checker := NewChecker(nil, &mockClientProvider{})
|
||||
checker.SetManagementConnected(true)
|
||||
|
||||
metricsHandler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte("metrics"))
|
||||
})
|
||||
|
||||
srv := NewServer(":0", checker, nil, metricsHandler)
|
||||
require.NotNil(t, srv)
|
||||
|
||||
// Verify health endpoint still works through the mux.
|
||||
req := httptest.NewRequest(http.MethodGet, "/healthz/live", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
srv.server.Handler.ServeHTTP(rec, req)
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
// Verify metrics endpoint is mounted.
|
||||
req = httptest.NewRequest(http.MethodGet, "/metrics", nil)
|
||||
rec = httptest.NewRecorder()
|
||||
srv.server.Handler.ServeHTTP(rec, req)
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
assert.Equal(t, "metrics", rec.Body.String())
|
||||
}
|
||||
|
||||
func TestNewServer_WithoutMetricsHandler(t *testing.T) {
|
||||
checker := NewChecker(nil, &mockClientProvider{})
|
||||
checker.SetManagementConnected(true)
|
||||
|
||||
srv := NewServer(":0", checker, nil, nil)
|
||||
require.NotNil(t, srv)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/healthz/live", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
srv.server.Handler.ServeHTTP(rec, req)
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
}
|
||||
|
||||
@@ -320,7 +320,8 @@ func getRequestID(r *http.Request) string {
|
||||
// status code, and component status based on the error type.
|
||||
func classifyProxyError(err error) (title, message string, code int, status web.ErrorStatus) {
|
||||
switch {
|
||||
case errors.Is(err, context.DeadlineExceeded):
|
||||
case errors.Is(err, context.DeadlineExceeded),
|
||||
isNetTimeout(err):
|
||||
return "Request Timeout",
|
||||
"The request timed out while trying to reach the service. Please refresh the page and try again.",
|
||||
http.StatusGatewayTimeout,
|
||||
@@ -356,12 +357,6 @@ func classifyProxyError(err error) (title, message string, code int, status web.
|
||||
"The connection to the peer could not be established. Please ensure the peer is running and connected to the NetBird network.",
|
||||
http.StatusBadGateway,
|
||||
web.ErrorStatus{Proxy: true, Destination: false}
|
||||
|
||||
case isNetTimeout(err):
|
||||
return "Request Timeout",
|
||||
"The request timed out while trying to reach the service. Please refresh the page and try again.",
|
||||
http.StatusGatewayTimeout,
|
||||
web.ErrorStatus{Proxy: true, Destination: false}
|
||||
}
|
||||
|
||||
return "Connection Error",
|
||||
|
||||
@@ -38,6 +38,11 @@ type domainInfo struct {
|
||||
serviceID string
|
||||
}
|
||||
|
||||
type domainNotification struct {
|
||||
domain domain.Domain
|
||||
serviceID string
|
||||
}
|
||||
|
||||
// clientEntry holds an embedded NetBird client and tracks which domains use it.
|
||||
type clientEntry struct {
|
||||
client *embed.Client
|
||||
@@ -114,6 +119,30 @@ func (n *NetBird) AddPeer(ctx context.Context, accountID types.AccountID, d doma
|
||||
return nil
|
||||
}
|
||||
|
||||
entry, err := n.createClientEntry(ctx, accountID, d, authToken, serviceID)
|
||||
if err != nil {
|
||||
n.clientsMux.Unlock()
|
||||
return err
|
||||
}
|
||||
|
||||
n.clients[accountID] = entry
|
||||
n.clientsMux.Unlock()
|
||||
|
||||
n.logger.WithFields(log.Fields{
|
||||
"account_id": accountID,
|
||||
"domain": d,
|
||||
}).Info("created new client for account")
|
||||
|
||||
// Attempt to start the client in the background; if this fails we will
|
||||
// retry on the first request via RoundTrip.
|
||||
go n.runClientStartup(ctx, accountID, entry.client)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// createClientEntry generates a WireGuard keypair, authenticates with management,
|
||||
// and creates an embedded NetBird client. Must be called with clientsMux held.
|
||||
func (n *NetBird) createClientEntry(ctx context.Context, accountID types.AccountID, d domain.Domain, authToken, serviceID string) (*clientEntry, error) {
|
||||
n.logger.WithFields(log.Fields{
|
||||
"account_id": accountID,
|
||||
"service_id": serviceID,
|
||||
@@ -121,8 +150,7 @@ func (n *NetBird) AddPeer(ctx context.Context, accountID types.AccountID, d doma
|
||||
|
||||
privateKey, err := wgtypes.GeneratePrivateKey()
|
||||
if err != nil {
|
||||
n.clientsMux.Unlock()
|
||||
return fmt.Errorf("generate wireguard private key: %w", err)
|
||||
return nil, fmt.Errorf("generate wireguard private key: %w", err)
|
||||
}
|
||||
publicKey := privateKey.PublicKey()
|
||||
|
||||
@@ -132,7 +160,6 @@ func (n *NetBird) AddPeer(ctx context.Context, accountID types.AccountID, d doma
|
||||
"public_key": publicKey.String(),
|
||||
}).Debug("authenticating new proxy peer with management")
|
||||
|
||||
// Authenticate with management using the one-time token and send public key
|
||||
resp, err := n.mgmtClient.CreateProxyPeer(ctx, &proto.CreateProxyPeerRequest{
|
||||
ServiceId: serviceID,
|
||||
AccountId: string(accountID),
|
||||
@@ -141,16 +168,14 @@ func (n *NetBird) AddPeer(ctx context.Context, accountID types.AccountID, d doma
|
||||
Cluster: n.proxyAddr,
|
||||
})
|
||||
if err != nil {
|
||||
n.clientsMux.Unlock()
|
||||
return fmt.Errorf("authenticate proxy peer with management: %w", err)
|
||||
return nil, fmt.Errorf("authenticate proxy peer with management: %w", err)
|
||||
}
|
||||
if resp != nil && !resp.GetSuccess() {
|
||||
n.clientsMux.Unlock()
|
||||
errMsg := "unknown error"
|
||||
if resp.ErrorMessage != nil {
|
||||
errMsg = *resp.ErrorMessage
|
||||
}
|
||||
return fmt.Errorf("proxy peer authentication failed: %s", errMsg)
|
||||
return nil, fmt.Errorf("proxy peer authentication failed: %s", errMsg)
|
||||
}
|
||||
|
||||
n.logger.WithFields(log.Fields{
|
||||
@@ -176,14 +201,13 @@ func (n *NetBird) AddPeer(ctx context.Context, accountID types.AccountID, d doma
|
||||
WireguardPort: &n.wgPort,
|
||||
})
|
||||
if err != nil {
|
||||
n.clientsMux.Unlock()
|
||||
return fmt.Errorf("create netbird client: %w", err)
|
||||
return nil, fmt.Errorf("create netbird client: %w", err)
|
||||
}
|
||||
|
||||
// Create a transport using the client dialer. We do this instead of using
|
||||
// the client's HTTPClient to avoid issues with request validation that do
|
||||
// not work with reverse proxied requests.
|
||||
entry = &clientEntry{
|
||||
return &clientEntry{
|
||||
client: client,
|
||||
domains: map[domain.Domain]domainInfo{d: {serviceID: serviceID}},
|
||||
transport: &http.Transport{
|
||||
@@ -196,75 +220,53 @@ func (n *NetBird) AddPeer(ctx context.Context, accountID types.AccountID, d doma
|
||||
},
|
||||
createdAt: time.Now(),
|
||||
started: false,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// runClientStartup starts the client and notifies registered domains on success.
|
||||
func (n *NetBird) runClientStartup(ctx context.Context, accountID types.AccountID, client *embed.Client) {
|
||||
startCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if err := client.Start(startCtx); err != nil {
|
||||
if errors.Is(err, context.DeadlineExceeded) {
|
||||
n.logger.WithField("account_id", accountID).Warn("netbird client start timed out, will retry on first request")
|
||||
} else {
|
||||
n.logger.WithField("account_id", accountID).WithError(err).Error("failed to start netbird client")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Mark client as started and collect domains to notify outside the lock.
|
||||
n.clientsMux.Lock()
|
||||
entry, exists := n.clients[accountID]
|
||||
if exists {
|
||||
entry.started = true
|
||||
}
|
||||
var domainsToNotify []domainNotification
|
||||
if exists {
|
||||
for dom, info := range entry.domains {
|
||||
domainsToNotify = append(domainsToNotify, domainNotification{domain: dom, serviceID: info.serviceID})
|
||||
}
|
||||
}
|
||||
n.clients[accountID] = entry
|
||||
n.clientsMux.Unlock()
|
||||
|
||||
n.logger.WithFields(log.Fields{
|
||||
"account_id": accountID,
|
||||
"domain": d,
|
||||
}).Info("created new client for account")
|
||||
|
||||
// Attempt to start the client in the background, if this fails
|
||||
// then it is not ideal, but it isn't the end of the world because
|
||||
// we will try to start the client again before we use it.
|
||||
go func() {
|
||||
startCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if err := client.Start(startCtx); err != nil {
|
||||
if errors.Is(err, context.DeadlineExceeded) {
|
||||
n.logger.WithFields(log.Fields{
|
||||
"account_id": accountID,
|
||||
}).Warn("netbird client start timed out, will retry on first request")
|
||||
} else {
|
||||
n.logger.WithFields(log.Fields{
|
||||
"account_id": accountID,
|
||||
}).WithError(err).Error("failed to start netbird client")
|
||||
}
|
||||
return
|
||||
if n.statusNotifier == nil {
|
||||
return
|
||||
}
|
||||
for _, dn := range domainsToNotify {
|
||||
if err := n.statusNotifier.NotifyStatus(ctx, string(accountID), dn.serviceID, string(dn.domain), true); err != nil {
|
||||
n.logger.WithFields(log.Fields{
|
||||
"account_id": accountID,
|
||||
"domain": dn.domain,
|
||||
}).WithError(err).Warn("failed to notify tunnel connection status")
|
||||
} else {
|
||||
n.logger.WithFields(log.Fields{
|
||||
"account_id": accountID,
|
||||
"domain": dn.domain,
|
||||
}).Info("notified management about tunnel connection")
|
||||
}
|
||||
|
||||
// Mark client as started and notify all registered domains
|
||||
n.clientsMux.Lock()
|
||||
entry, exists := n.clients[accountID]
|
||||
if exists {
|
||||
entry.started = true
|
||||
}
|
||||
// Copy domain info while holding lock
|
||||
var domainsToNotify []struct {
|
||||
domain domain.Domain
|
||||
serviceID string
|
||||
}
|
||||
if exists {
|
||||
for dom, info := range entry.domains {
|
||||
domainsToNotify = append(domainsToNotify, struct {
|
||||
domain domain.Domain
|
||||
serviceID string
|
||||
}{domain: dom, serviceID: info.serviceID})
|
||||
}
|
||||
}
|
||||
n.clientsMux.Unlock()
|
||||
|
||||
// Notify all domains that they're connected
|
||||
if n.statusNotifier != nil {
|
||||
for _, domInfo := range domainsToNotify {
|
||||
if err := n.statusNotifier.NotifyStatus(ctx, string(accountID), domInfo.serviceID, string(domInfo.domain), true); err != nil {
|
||||
n.logger.WithFields(log.Fields{
|
||||
"account_id": accountID,
|
||||
"domain": domInfo.domain,
|
||||
}).WithError(err).Warn("failed to notify tunnel connection status")
|
||||
} else {
|
||||
n.logger.WithFields(log.Fields{
|
||||
"account_id": accountID,
|
||||
"domain": domInfo.domain,
|
||||
}).Info("notified management about tunnel connection")
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// RemovePeer unregisters a domain from an account. The client is only stopped
|
||||
|
||||
@@ -3,6 +3,7 @@ package roundtrip
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
@@ -20,6 +21,31 @@ func (m *mockMgmtClient) CreateProxyPeer(_ context.Context, _ *proto.CreateProxy
|
||||
return &proto.CreateProxyPeerResponse{Success: true}, nil
|
||||
}
|
||||
|
||||
type mockStatusNotifier struct {
|
||||
mu sync.Mutex
|
||||
statuses []statusCall
|
||||
}
|
||||
|
||||
type statusCall struct {
|
||||
accountID string
|
||||
serviceID string
|
||||
domain string
|
||||
connected bool
|
||||
}
|
||||
|
||||
func (m *mockStatusNotifier) NotifyStatus(_ context.Context, accountID, serviceID, domain string, connected bool) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.statuses = append(m.statuses, statusCall{accountID, serviceID, domain, connected})
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockStatusNotifier) calls() []statusCall {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
return append([]statusCall{}, m.statuses...)
|
||||
}
|
||||
|
||||
// mockNetBird creates a NetBird instance for testing without actually connecting.
|
||||
// It uses an invalid management URL to prevent real connections.
|
||||
func mockNetBird() *NetBird {
|
||||
@@ -253,3 +279,50 @@ func TestNetBird_RoundTrip_RequiresExistingClient(t *testing.T) {
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "no peer connection found for account")
|
||||
}
|
||||
|
||||
func TestNetBird_AddPeer_ExistingStartedClient_NotifiesStatus(t *testing.T) {
|
||||
notifier := &mockStatusNotifier{}
|
||||
nb := NewNetBird("http://invalid.test:9999", "test-proxy", "invalid.test", 0, nil, notifier, &mockMgmtClient{})
|
||||
accountID := types.AccountID("account-1")
|
||||
|
||||
// Add first domain — creates a new client entry.
|
||||
err := nb.AddPeer(context.Background(), accountID, domain.Domain("domain1.test"), "key-1", "svc-1")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Manually mark client as started to simulate background startup completing.
|
||||
nb.clientsMux.Lock()
|
||||
nb.clients[accountID].started = true
|
||||
nb.clientsMux.Unlock()
|
||||
|
||||
// Add second domain — should notify immediately since client is already started.
|
||||
err = nb.AddPeer(context.Background(), accountID, domain.Domain("domain2.test"), "key-1", "svc-2")
|
||||
require.NoError(t, err)
|
||||
|
||||
calls := notifier.calls()
|
||||
require.Len(t, calls, 1)
|
||||
assert.Equal(t, string(accountID), calls[0].accountID)
|
||||
assert.Equal(t, "svc-2", calls[0].serviceID)
|
||||
assert.Equal(t, "domain2.test", calls[0].domain)
|
||||
assert.True(t, calls[0].connected)
|
||||
}
|
||||
|
||||
func TestNetBird_RemovePeer_NotifiesDisconnection(t *testing.T) {
|
||||
notifier := &mockStatusNotifier{}
|
||||
nb := NewNetBird("http://invalid.test:9999", "test-proxy", "invalid.test", 0, nil, notifier, &mockMgmtClient{})
|
||||
accountID := types.AccountID("account-1")
|
||||
|
||||
err := nb.AddPeer(context.Background(), accountID, domain.Domain("domain1.test"), "key-1", "svc-1")
|
||||
require.NoError(t, err)
|
||||
err = nb.AddPeer(context.Background(), accountID, domain.Domain("domain2.test"), "key-1", "svc-2")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Remove one domain — client stays, but disconnection notification fires.
|
||||
err = nb.RemovePeer(context.Background(), accountID, "domain1.test")
|
||||
require.NoError(t, err)
|
||||
assert.True(t, nb.HasClient(accountID))
|
||||
|
||||
calls := notifier.calls()
|
||||
require.Len(t, calls, 1)
|
||||
assert.Equal(t, "domain1.test", calls[0].domain)
|
||||
assert.False(t, calls[0].connected)
|
||||
}
|
||||
|
||||
235
proxy/server.go
235
proxy/server.go
@@ -157,41 +157,9 @@ func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) {
|
||||
reg := prometheus.NewRegistry()
|
||||
s.meter = metrics.New(reg)
|
||||
|
||||
// The very first thing to do should be to connect to the Management server.
|
||||
// Without this connection, the Proxy cannot do anything.
|
||||
mgmtURL, err := url.Parse(s.ManagementAddress)
|
||||
mgmtConn, err := s.dialManagement()
|
||||
if err != nil {
|
||||
return fmt.Errorf("parse management address: %w", err)
|
||||
}
|
||||
creds := insecure.NewCredentials()
|
||||
// Simple TLS check using management URL.
|
||||
// Assume management TLS is enabled for gRPC as well if using HTTPS for the API.
|
||||
if mgmtURL.Scheme == "https" {
|
||||
certPool, err := x509.SystemCertPool()
|
||||
if err != nil || certPool == nil {
|
||||
// Fall back to embedded CAs if no OS-provided ones are available.
|
||||
certPool = embeddedroots.Get()
|
||||
}
|
||||
|
||||
creds = credentials.NewTLS(&tls.Config{
|
||||
RootCAs: certPool,
|
||||
})
|
||||
}
|
||||
s.Logger.WithFields(log.Fields{
|
||||
"gRPC_address": mgmtURL.Host,
|
||||
"TLS_enabled": mgmtURL.Scheme == "https",
|
||||
}).Debug("starting management gRPC client")
|
||||
mgmtConn, err := grpc.NewClient(mgmtURL.Host,
|
||||
grpc.WithTransportCredentials(creds),
|
||||
grpc.WithKeepaliveParams(keepalive.ClientParameters{
|
||||
Time: 20 * time.Second,
|
||||
Timeout: 10 * time.Second,
|
||||
PermitWithoutStream: true,
|
||||
}),
|
||||
proxygrpc.WithProxyToken(s.ProxyToken),
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not create management connection: %w", err)
|
||||
return err
|
||||
}
|
||||
defer func() {
|
||||
if err := mgmtConn.Close(); err != nil {
|
||||
@@ -205,54 +173,9 @@ func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) {
|
||||
// to proxy over.
|
||||
s.netbird = roundtrip.NewNetBird(s.ManagementAddress, s.ID, s.ProxyURL, s.WireguardPort, s.Logger, s, s.mgmtClient)
|
||||
|
||||
// When generating ACME certificates, start a challenge server.
|
||||
tlsConfig := &tls.Config{}
|
||||
if s.GenerateACMECertificates {
|
||||
// Default to TLS-ALPN-01 challenge if not specified
|
||||
if s.ACMEChallengeType == "" {
|
||||
s.ACMEChallengeType = "tls-alpn-01"
|
||||
}
|
||||
s.Logger.WithFields(log.Fields{
|
||||
"acme_server": s.ACMEDirectory,
|
||||
"challenge_type": s.ACMEChallengeType,
|
||||
}).Debug("ACME certificates enabled, configuring certificate manager")
|
||||
s.acme = acme.NewManager(s.CertificateDirectory, s.ACMEDirectory, s, s.Logger, s.CertLockMethod)
|
||||
|
||||
// Only start HTTP server for HTTP-01 challenge type
|
||||
if s.ACMEChallengeType == "http-01" {
|
||||
s.http = &http.Server{
|
||||
Addr: s.ACMEChallengeAddress,
|
||||
Handler: s.acme.HTTPHandler(nil),
|
||||
ErrorLog: newHTTPServerLogger(s.Logger, logtagValueACME),
|
||||
}
|
||||
go func() {
|
||||
if err := s.http.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
|
||||
s.Logger.WithError(err).Error("ACME HTTP-01 challenge server failed")
|
||||
}
|
||||
}()
|
||||
}
|
||||
tlsConfig = s.acme.TLSConfig()
|
||||
|
||||
// ServerName needs to be set to allow for ACME to work correctly
|
||||
// when using CNAME URLs to access the proxy.
|
||||
tlsConfig.ServerName = s.ProxyURL
|
||||
|
||||
s.Logger.WithFields(log.Fields{
|
||||
"ServerName": s.ProxyURL,
|
||||
"challenge_type": s.ACMEChallengeType,
|
||||
}).Debug("ACME certificate manager configured")
|
||||
} else {
|
||||
s.Logger.Debug("ACME certificates disabled, using static certificates with file watching")
|
||||
certPath := filepath.Join(s.CertificateDirectory, s.CertificateFile)
|
||||
keyPath := filepath.Join(s.CertificateDirectory, s.CertificateKeyFile)
|
||||
|
||||
certWatcher, err := certwatch.NewWatcher(certPath, keyPath, s.Logger)
|
||||
if err != nil {
|
||||
return fmt.Errorf("initialize certificate watcher: %w", err)
|
||||
}
|
||||
go certWatcher.Watch(ctx)
|
||||
|
||||
tlsConfig.GetCertificate = certWatcher.GetCertificate
|
||||
tlsConfig, err := s.configureTLS(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Configure the reverse proxy using NetBird's HTTP Client Transport for proxying.
|
||||
@@ -356,6 +279,92 @@ const (
|
||||
shutdownServiceTimeout = 5 * time.Second
|
||||
)
|
||||
|
||||
func (s *Server) dialManagement() (*grpc.ClientConn, error) {
|
||||
mgmtURL, err := url.Parse(s.ManagementAddress)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse management address: %w", err)
|
||||
}
|
||||
creds := insecure.NewCredentials()
|
||||
// Assume management TLS is enabled for gRPC as well if using HTTPS for the API.
|
||||
if mgmtURL.Scheme == "https" {
|
||||
certPool, err := x509.SystemCertPool()
|
||||
if err != nil || certPool == nil {
|
||||
// Fall back to embedded CAs if no OS-provided ones are available.
|
||||
certPool = embeddedroots.Get()
|
||||
}
|
||||
creds = credentials.NewTLS(&tls.Config{
|
||||
RootCAs: certPool,
|
||||
})
|
||||
}
|
||||
s.Logger.WithFields(log.Fields{
|
||||
"gRPC_address": mgmtURL.Host,
|
||||
"TLS_enabled": mgmtURL.Scheme == "https",
|
||||
}).Debug("starting management gRPC client")
|
||||
conn, err := grpc.NewClient(mgmtURL.Host,
|
||||
grpc.WithTransportCredentials(creds),
|
||||
grpc.WithKeepaliveParams(keepalive.ClientParameters{
|
||||
Time: 20 * time.Second,
|
||||
Timeout: 10 * time.Second,
|
||||
PermitWithoutStream: true,
|
||||
}),
|
||||
proxygrpc.WithProxyToken(s.ProxyToken),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create management connection: %w", err)
|
||||
}
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
func (s *Server) configureTLS(ctx context.Context) (*tls.Config, error) {
|
||||
tlsConfig := &tls.Config{}
|
||||
if !s.GenerateACMECertificates {
|
||||
s.Logger.Debug("ACME certificates disabled, using static certificates with file watching")
|
||||
certPath := filepath.Join(s.CertificateDirectory, s.CertificateFile)
|
||||
keyPath := filepath.Join(s.CertificateDirectory, s.CertificateKeyFile)
|
||||
|
||||
certWatcher, err := certwatch.NewWatcher(certPath, keyPath, s.Logger)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("initialize certificate watcher: %w", err)
|
||||
}
|
||||
go certWatcher.Watch(ctx)
|
||||
tlsConfig.GetCertificate = certWatcher.GetCertificate
|
||||
return tlsConfig, nil
|
||||
}
|
||||
|
||||
if s.ACMEChallengeType == "" {
|
||||
s.ACMEChallengeType = "tls-alpn-01"
|
||||
}
|
||||
s.Logger.WithFields(log.Fields{
|
||||
"acme_server": s.ACMEDirectory,
|
||||
"challenge_type": s.ACMEChallengeType,
|
||||
}).Debug("ACME certificates enabled, configuring certificate manager")
|
||||
s.acme = acme.NewManager(s.CertificateDirectory, s.ACMEDirectory, s, s.Logger, s.CertLockMethod)
|
||||
|
||||
if s.ACMEChallengeType == "http-01" {
|
||||
s.http = &http.Server{
|
||||
Addr: s.ACMEChallengeAddress,
|
||||
Handler: s.acme.HTTPHandler(nil),
|
||||
ErrorLog: newHTTPServerLogger(s.Logger, logtagValueACME),
|
||||
}
|
||||
go func() {
|
||||
if err := s.http.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
|
||||
s.Logger.WithError(err).Error("ACME HTTP-01 challenge server failed")
|
||||
}
|
||||
}()
|
||||
}
|
||||
tlsConfig = s.acme.TLSConfig()
|
||||
|
||||
// ServerName needs to be set to allow for ACME to work correctly
|
||||
// when using CNAME URLs to access the proxy.
|
||||
tlsConfig.ServerName = s.ProxyURL
|
||||
|
||||
s.Logger.WithFields(log.Fields{
|
||||
"ServerName": s.ProxyURL,
|
||||
"challenge_type": s.ACMEChallengeType,
|
||||
}).Debug("ACME certificate manager configured")
|
||||
return tlsConfig, nil
|
||||
}
|
||||
|
||||
// gracefulShutdown performs a zero-downtime shutdown sequence. It marks the
|
||||
// readiness probe as failing, waits for load balancer propagation, drains
|
||||
// in-flight connections, and then stops all background services.
|
||||
@@ -508,36 +517,7 @@ func (s *Server) handleMappingStream(ctx context.Context, mappingClient proto.Pr
|
||||
return fmt.Errorf("receive msg: %w", err)
|
||||
}
|
||||
s.Logger.Debug("Received mapping update, starting processing")
|
||||
// Process msg updates sequentially to avoid conflict, so block
|
||||
// additional receiving until this processing is completed.
|
||||
for _, mapping := range msg.GetMapping() {
|
||||
s.Logger.WithFields(log.Fields{
|
||||
"type": mapping.GetType(),
|
||||
"domain": mapping.GetDomain(),
|
||||
"path": mapping.GetPath(),
|
||||
"id": mapping.GetId(),
|
||||
}).Debug("Processing mapping update")
|
||||
switch mapping.GetType() {
|
||||
case proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED:
|
||||
if err := s.addMapping(ctx, mapping); err != nil {
|
||||
// TODO: Retry this? Or maybe notify the management server that this mapping has failed?
|
||||
s.Logger.WithFields(log.Fields{
|
||||
"service_id": mapping.GetId(),
|
||||
"domain": mapping.GetDomain(),
|
||||
"error": err,
|
||||
}).Error("Error adding new mapping, ignoring this mapping and continuing processing")
|
||||
}
|
||||
case proto.ProxyMappingUpdateType_UPDATE_TYPE_MODIFIED:
|
||||
if err := s.updateMapping(ctx, mapping); err != nil {
|
||||
s.Logger.WithFields(log.Fields{
|
||||
"service_id": mapping.GetId(),
|
||||
"domain": mapping.GetDomain(),
|
||||
}).Errorf("failed to update mapping: %v", err)
|
||||
}
|
||||
case proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED:
|
||||
s.removeMapping(ctx, mapping)
|
||||
}
|
||||
}
|
||||
s.processMappings(ctx, msg.GetMapping())
|
||||
s.Logger.Debug("Processing mapping update completed")
|
||||
|
||||
if !*initialSyncDone && msg.GetInitialSyncComplete() {
|
||||
@@ -551,6 +531,37 @@ func (s *Server) handleMappingStream(ctx context.Context, mappingClient proto.Pr
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) processMappings(ctx context.Context, mappings []*proto.ProxyMapping) {
|
||||
for _, mapping := range mappings {
|
||||
s.Logger.WithFields(log.Fields{
|
||||
"type": mapping.GetType(),
|
||||
"domain": mapping.GetDomain(),
|
||||
"path": mapping.GetPath(),
|
||||
"id": mapping.GetId(),
|
||||
}).Debug("Processing mapping update")
|
||||
switch mapping.GetType() {
|
||||
case proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED:
|
||||
if err := s.addMapping(ctx, mapping); err != nil {
|
||||
// TODO: Retry this? Or maybe notify the management server that this mapping has failed?
|
||||
s.Logger.WithFields(log.Fields{
|
||||
"service_id": mapping.GetId(),
|
||||
"domain": mapping.GetDomain(),
|
||||
"error": err,
|
||||
}).Error("Error adding new mapping, ignoring this mapping and continuing processing")
|
||||
}
|
||||
case proto.ProxyMappingUpdateType_UPDATE_TYPE_MODIFIED:
|
||||
if err := s.updateMapping(ctx, mapping); err != nil {
|
||||
s.Logger.WithFields(log.Fields{
|
||||
"service_id": mapping.GetId(),
|
||||
"domain": mapping.GetDomain(),
|
||||
}).Errorf("failed to update mapping: %v", err)
|
||||
}
|
||||
case proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED:
|
||||
s.removeMapping(ctx, mapping)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) addMapping(ctx context.Context, mapping *proto.ProxyMapping) error {
|
||||
d := domain.Domain(mapping.GetDomain())
|
||||
accountID := types.AccountID(mapping.GetAccountId())
|
||||
|
||||
Reference in New Issue
Block a user