[relay] Use instanceURL instead of Exposed address. (#4905)

Replaces string-based exposed address handling with URL-based InstanceURL() (type url.URL) across relay/server and relay/healthcheck; adds SchemeREL/SchemeRELS constants; updates getInstanceURL to return *url.URL with scheme and TLS validation; adjusts WS dialing and health-check logic to use URL fields.
This commit is contained in:
Zoltan Papp
2025-12-03 18:42:53 +01:00
committed by GitHub
parent 27dd97c9c4
commit d2e48d4f5e
7 changed files with 47 additions and 45 deletions

View File

@@ -160,7 +160,8 @@ func execute(cmd *cobra.Command, args []string) error {
log.Debugf("failed to create relay server: %v", err)
return fmt.Errorf("failed to create relay server: %v", err)
}
log.Infof("server will be available on: %s", srv.InstanceURL())
instanceURL := srv.InstanceURL()
log.Infof("server will be available on: %s", instanceURL.String())
wg.Add(1)
go func() {
defer wg.Done()

View File

@@ -6,13 +6,14 @@ import (
"errors"
"net"
"net/http"
"strings"
"net/url"
"sync"
"time"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/relay/protocol"
"github.com/netbirdio/netbird/relay/server"
)
const (
@@ -26,7 +27,7 @@ const (
type ServiceChecker interface {
ListenerProtocols() []protocol.Protocol
ExposedAddress() string
InstanceURL() url.URL
}
type HealthStatus struct {
@@ -134,7 +135,7 @@ func (s *Server) getHealthStatus(ctx context.Context) (*HealthStatus, bool) {
}
status.Listeners = listeners
if !strings.HasPrefix(s.config.ServiceChecker.ExposedAddress(), "rels") {
if s.config.ServiceChecker.InstanceURL().Scheme != server.SchemeRELS {
status.CertificateValid = false
}
@@ -156,14 +157,9 @@ func (s *Server) validateListeners() ([]protocol.Protocol, bool) {
}
func (s *Server) validateConnection(ctx context.Context) bool {
exposedAddress := s.config.ServiceChecker.ExposedAddress()
if exposedAddress == "" {
log.Error("exposed address is empty, cannot validate certificate")
return false
}
if err := dialWS(ctx, exposedAddress); err != nil {
log.Errorf("failed to dial WebSocket listener at %s: %v", exposedAddress, err)
addr := s.config.ServiceChecker.InstanceURL()
if err := dialWS(ctx, addr); err != nil {
log.Errorf("failed to dial WebSocket listener at %s: %v", addr.String(), err)
return false
}

View File

@@ -3,22 +3,22 @@ package healthcheck
import (
"context"
"fmt"
"strings"
"net/url"
"github.com/coder/websocket"
"github.com/netbirdio/netbird/relay/server"
"github.com/netbirdio/netbird/shared/relay"
)
func dialWS(ctx context.Context, address string) error {
addressSplit := strings.Split(address, "/")
func dialWS(ctx context.Context, address url.URL) error {
scheme := "ws"
if addressSplit[0] == "rels:" {
if address.Scheme == server.SchemeRELS {
scheme = "wss"
}
url := fmt.Sprintf("%s://%s%s", scheme, addressSplit[2], relay.WebSocketURLPath)
wsURL := fmt.Sprintf("%s://%s%s", scheme, address.Host, relay.WebSocketURLPath)
conn, resp, err := websocket.Dial(ctx, url, nil)
conn, resp, err := websocket.Dial(ctx, wsURL, nil)
if resp != nil {
defer func() {
if resp.Body != nil {

View File

@@ -4,6 +4,7 @@ import (
"context"
"fmt"
"net"
"net/url"
"sync"
"time"
@@ -22,7 +23,7 @@ type Config struct {
TLSSupport bool
AuthValidator Validator
instanceURL string
instanceURL url.URL
}
func (c *Config) validate() error {
@@ -37,7 +38,7 @@ func (c *Config) validate() error {
if err != nil {
return fmt.Errorf("invalid url: %v", err)
}
c.instanceURL = instanceURL
c.instanceURL = *instanceURL
if c.AuthValidator == nil {
return fmt.Errorf("auth validator is required")
@@ -53,7 +54,7 @@ type Relay struct {
store *store.Store
notifier *store.PeerNotifier
instanceURL string
instanceURL url.URL
exposedAddress string
preparedMsg *preparedMsg
@@ -97,7 +98,7 @@ func NewRelay(config Config) (*Relay, error) {
notifier: store.NewPeerNotifier(),
}
r.preparedMsg, err = newPreparedMsg(r.instanceURL)
r.preparedMsg, err = newPreparedMsg(r.instanceURL.String())
if err != nil {
metricsCancel()
return nil, fmt.Errorf("prepare message: %v", err)
@@ -177,11 +178,6 @@ func (r *Relay) Shutdown(ctx context.Context) {
}
// InstanceURL returns the instance URL of the relay server
func (r *Relay) InstanceURL() string {
func (r *Relay) InstanceURL() url.URL {
return r.instanceURL
}
// ExposedAddress returns the exposed address (domain:port) where clients connect
func (r *Relay) ExposedAddress() string {
return r.exposedAddress
}

View File

@@ -3,6 +3,7 @@ package server
import (
"context"
"crypto/tls"
"net/url"
"sync"
"github.com/hashicorp/go-multierror"
@@ -39,7 +40,7 @@ type Server struct {
//
// config: A Config struct containing the necessary configuration:
// - Meter: An OpenTelemetry metric.Meter used for recording metrics. If nil, a default no-op meter is used.
// - ExposedAddress: The public address (in domain:port format) used as the server's instance URL. Required.
// - InstanceURL: The public address (in domain:port format) used as the server's instance URL. Required.
// - TLSSupport: A boolean indicating whether TLS is enabled for the server.
// - AuthValidator: A Validator used to authenticate peers. Required.
//
@@ -119,11 +120,6 @@ func (r *Server) Shutdown(ctx context.Context) error {
return nberrors.FormatErrorOrNil(multiErr)
}
// InstanceURL returns the instance URL of the relay server.
func (r *Server) InstanceURL() string {
return r.relay.instanceURL
}
func (r *Server) ListenerProtocols() []protocol.Protocol {
result := make([]protocol.Protocol, 0)
@@ -135,6 +131,6 @@ func (r *Server) ListenerProtocols() []protocol.Protocol {
return result
}
func (r *Server) ExposedAddress() string {
return r.relay.ExposedAddress()
func (r *Server) InstanceURL() url.URL {
return r.relay.InstanceURL()
}

View File

@@ -6,9 +6,14 @@ import (
"strings"
)
const (
SchemeREL = "rel"
SchemeRELS = "rels"
)
// getInstanceURL checks if user supplied a URL scheme otherwise adds to the
// provided address according to TLS definition and parses the address before returning it
func getInstanceURL(exposedAddress string, tlsSupported bool) (string, error) {
func getInstanceURL(exposedAddress string, tlsSupported bool) (*url.URL, error) {
addr := exposedAddress
split := strings.Split(exposedAddress, "://")
switch {
@@ -17,17 +22,22 @@ func getInstanceURL(exposedAddress string, tlsSupported bool) (string, error) {
case len(split) == 1 && !tlsSupported:
addr = "rel://" + exposedAddress
case len(split) > 2:
return "", fmt.Errorf("invalid exposed address: %s", exposedAddress)
return nil, fmt.Errorf("invalid exposed address: %s", exposedAddress)
}
parsedURL, err := url.ParseRequestURI(addr)
if err != nil {
return "", fmt.Errorf("invalid exposed address: %v", err)
return nil, fmt.Errorf("invalid exposed address: %v", err)
}
if parsedURL.Scheme != "rel" && parsedURL.Scheme != "rels" {
return "", fmt.Errorf("invalid scheme: %s", parsedURL.Scheme)
if parsedURL.Scheme != SchemeREL && parsedURL.Scheme != SchemeRELS {
return nil, fmt.Errorf("invalid scheme: %s", parsedURL.Scheme)
}
return parsedURL.String(), nil
// Validate scheme matches TLS configuration
if tlsSupported && parsedURL.Scheme == SchemeREL {
return nil, fmt.Errorf("non-TLS scheme '%s' provided but TLS is supported", SchemeREL)
}
return parsedURL, nil
}

View File

@@ -13,7 +13,7 @@ func TestGetInstanceURL(t *testing.T) {
{"Valid address with TLS", "example.com", true, "rels://example.com", false},
{"Valid address without TLS", "example.com", false, "rel://example.com", false},
{"Valid address with scheme", "rel://example.com", false, "rel://example.com", false},
{"Valid address with non TLS scheme and TLS true", "rel://example.com", true, "rel://example.com", false},
{"Invalid address with non TLS scheme and TLS true", "rel://example.com", true, "", true},
{"Valid address with TLS scheme", "rels://example.com", true, "rels://example.com", false},
{"Valid address with TLS scheme and TLS false", "rels://example.com", false, "rels://example.com", false},
{"Valid address with TLS scheme and custom port", "rels://example.com:9300", true, "rels://example.com:9300", false},
@@ -28,8 +28,11 @@ func TestGetInstanceURL(t *testing.T) {
if (err != nil) != tt.expectError {
t.Errorf("expected error: %v, got: %v", tt.expectError, err)
}
if url != tt.expectedURL {
t.Errorf("expected URL: %s, got: %s", tt.expectedURL, url)
if !tt.expectError && url != nil && url.String() != tt.expectedURL {
t.Errorf("expected URL: %s, got: %s", tt.expectedURL, url.String())
}
if tt.expectError && url != nil {
t.Errorf("expected nil URL on error, got: %s", url.String())
}
})
}