diff --git a/relay/auth/hmac/store.go b/relay/auth/hmac/store.go index c9e8cc278..6c0541efc 100644 --- a/relay/auth/hmac/store.go +++ b/relay/auth/hmac/store.go @@ -4,20 +4,20 @@ import ( "sync" ) -// Store is a simple in-memory store for token +// TokenStore is a simple in-memory store for token // With this can update the token in thread safe way -type Store struct { +type TokenStore struct { mu sync.Mutex token Token } -func (a *Store) UpdateToken(token Token) { +func (a *TokenStore) UpdateToken(token Token) { a.mu.Lock() defer a.mu.Unlock() a.token = token } -func (a *Store) Token() ([]byte, error) { +func (a *TokenStore) Token() ([]byte, error) { a.mu.Lock() defer a.mu.Unlock() return marshalToken(a.token) diff --git a/relay/client/client.go b/relay/client/client.go index 1b2bf27c7..9b0539b92 100644 --- a/relay/client/client.go +++ b/relay/client/client.go @@ -96,11 +96,11 @@ func (cc *connContainer) close() { // the client can be reused by calling Connect again. When the client is closed, all connections are closed too. // While the Connect is in progress, the OpenConn function will block until the connection is established. type Client struct { - log *log.Entry - parentCtx context.Context - connectionURL string - authStore *auth.Store - hashedID []byte + log *log.Entry + parentCtx context.Context + connectionURL string + authTokenStore *auth.TokenStore + hashedID []byte bufPool *sync.Pool @@ -117,14 +117,14 @@ type Client struct { } // NewClient creates a new client for the relay server. The client is not connected to the server until the Connect -func NewClient(ctx context.Context, serverURL string, authStore *auth.Store, peerID string) *Client { +func NewClient(ctx context.Context, serverURL string, authTokenStore *auth.TokenStore, peerID string) *Client { hashedID, hashedStringId := messages.HashID(peerID) return &Client{ - log: log.WithField("client_id", hashedStringId), - parentCtx: ctx, - connectionURL: serverURL, - authStore: authStore, - hashedID: hashedID, + log: log.WithField("client_id", hashedStringId), + parentCtx: ctx, + connectionURL: serverURL, + authTokenStore: authTokenStore, + hashedID: hashedID, bufPool: &sync.Pool{ New: func() any { buf := make([]byte, bufferSize) @@ -237,7 +237,7 @@ func (c *Client) connect() error { } func (c *Client) handShake() error { - t, err := c.authStore.Token() + t, err := c.authTokenStore.Token() if err != nil { return err } diff --git a/relay/client/client_test.go b/relay/client/client_test.go index 7b1ee5c62..739c745fb 100644 --- a/relay/client/client_test.go +++ b/relay/client/client_test.go @@ -9,11 +9,20 @@ import ( log "github.com/sirupsen/logrus" + "github.com/netbirdio/netbird/relay/auth" + "github.com/netbirdio/netbird/relay/auth/hmac" "github.com/netbirdio/netbird/util" "github.com/netbirdio/netbird/relay/server" ) +var ( + av = &auth.AllowAllAuth{} + hmacTokenStore = &hmac.TokenStore{} + serverListenAddr = "localhost:1234" + serverURL = "rel://localhost:1234" +) + func TestMain(m *testing.M) { _ = util.InitLog("trace", "console") code := m.Run() @@ -23,8 +32,8 @@ func TestMain(m *testing.M) { func TestClient(t *testing.T) { ctx := context.Background() - srvCfg := server.ListenerConfig{Address: "localhost:1234"} - srv := server.NewServer(srvCfg.Address, false) + srvCfg := server.ListenerConfig{Address: serverListenAddr} + srv := server.NewServer(serverURL, false, av) go func() { err := srv.Listen(srvCfg) if err != nil { @@ -39,21 +48,21 @@ func TestClient(t *testing.T) { } }() - clientAlice := NewClient(ctx, srvCfg.Address, "alice") + clientAlice := NewClient(ctx, serverURL, hmacTokenStore, "alice") err := clientAlice.Connect() if err != nil { t.Fatalf("failed to connect to server: %s", err) } defer clientAlice.Close() - clientPlaceHolder := NewClient(ctx, srvCfg.Address, "clientPlaceHolder") + clientPlaceHolder := NewClient(ctx, serverURL, hmacTokenStore, "clientPlaceHolder") err = clientPlaceHolder.Connect() if err != nil { t.Fatalf("failed to connect to server: %s", err) } defer clientPlaceHolder.Close() - clientBob := NewClient(ctx, srvCfg.Address, "bob") + clientBob := NewClient(ctx, serverURL, hmacTokenStore, "bob") err = clientBob.Connect() if err != nil { t.Fatalf("failed to connect to server: %s", err) @@ -91,8 +100,8 @@ func TestClient(t *testing.T) { func TestRegistration(t *testing.T) { ctx := context.Background() - srvCfg := server.ListenerConfig{Address: "localhost:1234"} - srv := server.NewServer(srvCfg.Address, false) + srvCfg := server.ListenerConfig{Address: serverListenAddr} + srv := server.NewServer(serverURL, false, av) go func() { err := srv.Listen(srvCfg) if err != nil { @@ -100,7 +109,7 @@ func TestRegistration(t *testing.T) { } }() - clientAlice := NewClient(ctx, srvCfg.Address, "alice") + clientAlice := NewClient(ctx, serverURL, hmacTokenStore, "alice") err := clientAlice.Connect() if err != nil { _ = srv.Close() @@ -140,7 +149,7 @@ func TestRegistrationTimeout(t *testing.T) { _ = fakeTCPListener.Close() }(fakeTCPListener) - clientAlice := NewClient(ctx, "127.0.0.1:1234", "alice") + clientAlice := NewClient(ctx, "127.0.0.1:1234", hmacTokenStore, "alice") err = clientAlice.Connect() if err == nil { t.Errorf("failed to connect to server: %s", err) @@ -156,8 +165,8 @@ func TestEcho(t *testing.T) { ctx := context.Background() idAlice := "alice" idBob := "bob" - srvCfg := server.ListenerConfig{Address: "localhost:1234"} - srv := server.NewServer(srvCfg.Address, false) + srvCfg := server.ListenerConfig{Address: serverListenAddr} + srv := server.NewServer(serverURL, false, av) go func() { err := srv.Listen(srvCfg) if err != nil { @@ -172,7 +181,7 @@ func TestEcho(t *testing.T) { } }() - clientAlice := NewClient(ctx, srvCfg.Address, idAlice) + clientAlice := NewClient(ctx, serverURL, hmacTokenStore, idAlice) err := clientAlice.Connect() if err != nil { t.Fatalf("failed to connect to server: %s", err) @@ -184,7 +193,7 @@ func TestEcho(t *testing.T) { } }() - clientBob := NewClient(ctx, srvCfg.Address, idBob) + clientBob := NewClient(ctx, serverURL, hmacTokenStore, idBob) err = clientBob.Connect() if err != nil { t.Fatalf("failed to connect to server: %s", err) @@ -236,8 +245,8 @@ func TestEcho(t *testing.T) { func TestBindToUnavailabePeer(t *testing.T) { ctx := context.Background() - srvCfg := server.ListenerConfig{Address: "localhost:1234"} - srv := server.NewServer(srvCfg.Address, false) + srvCfg := server.ListenerConfig{Address: serverListenAddr} + srv := server.NewServer(serverURL, false, av) go func() { err := srv.Listen(srvCfg) if err != nil { @@ -253,7 +262,7 @@ func TestBindToUnavailabePeer(t *testing.T) { } }() - clientAlice := NewClient(ctx, srvCfg.Address, "alice") + clientAlice := NewClient(ctx, serverURL, hmacTokenStore, "alice") err := clientAlice.Connect() if err != nil { t.Errorf("failed to connect to server: %s", err) @@ -273,8 +282,8 @@ func TestBindToUnavailabePeer(t *testing.T) { func TestBindReconnect(t *testing.T) { ctx := context.Background() - srvCfg := server.ListenerConfig{Address: "localhost:1234"} - srv := server.NewServer(srvCfg.Address, false) + srvCfg := server.ListenerConfig{Address: serverListenAddr} + srv := server.NewServer(serverURL, false, av) go func() { err := srv.Listen(srvCfg) if err != nil { @@ -290,7 +299,7 @@ func TestBindReconnect(t *testing.T) { } }() - clientAlice := NewClient(ctx, srvCfg.Address, "alice") + clientAlice := NewClient(ctx, serverURL, hmacTokenStore, "alice") err := clientAlice.Connect() if err != nil { t.Errorf("failed to connect to server: %s", err) @@ -301,7 +310,7 @@ func TestBindReconnect(t *testing.T) { t.Errorf("failed to bind channel: %s", err) } - clientBob := NewClient(ctx, srvCfg.Address, "bob") + clientBob := NewClient(ctx, serverURL, hmacTokenStore, "bob") err = clientBob.Connect() if err != nil { t.Errorf("failed to connect to server: %s", err) @@ -318,7 +327,7 @@ func TestBindReconnect(t *testing.T) { t.Errorf("failed to close client: %s", err) } - clientAlice = NewClient(ctx, srvCfg.Address, "alice") + clientAlice = NewClient(ctx, serverURL, hmacTokenStore, "alice") err = clientAlice.Connect() if err != nil { t.Errorf("failed to connect to server: %s", err) @@ -355,8 +364,8 @@ func TestBindReconnect(t *testing.T) { func TestCloseConn(t *testing.T) { ctx := context.Background() - srvCfg := server.ListenerConfig{Address: "localhost:1234"} - srv := server.NewServer(srvCfg.Address, false) + srvCfg := server.ListenerConfig{Address: serverListenAddr} + srv := server.NewServer(serverURL, false, av) go func() { err := srv.Listen(srvCfg) if err != nil { @@ -372,7 +381,7 @@ func TestCloseConn(t *testing.T) { } }() - clientAlice := NewClient(ctx, srvCfg.Address, "alice") + clientAlice := NewClient(ctx, serverURL, hmacTokenStore, "alice") err := clientAlice.Connect() if err != nil { t.Errorf("failed to connect to server: %s", err) @@ -403,8 +412,8 @@ func TestCloseConn(t *testing.T) { func TestCloseRelayConn(t *testing.T) { ctx := context.Background() - srvCfg := server.ListenerConfig{Address: "localhost:1234"} - srv := server.NewServer(srvCfg.Address, false) + srvCfg := server.ListenerConfig{Address: serverListenAddr} + srv := server.NewServer(serverURL, false, av) go func() { err := srv.Listen(srvCfg) if err != nil { @@ -419,7 +428,7 @@ func TestCloseRelayConn(t *testing.T) { } }() - clientAlice := NewClient(ctx, srvCfg.Address, "alice") + clientAlice := NewClient(ctx, serverURL, hmacTokenStore, "alice") err := clientAlice.Connect() if err != nil { t.Fatalf("failed to connect to server: %s", err) @@ -446,8 +455,8 @@ func TestCloseRelayConn(t *testing.T) { func TestCloseByServer(t *testing.T) { ctx := context.Background() - srvCfg := server.ListenerConfig{Address: "localhost:1234"} - srv1 := server.NewServer(srvCfg.Address, false) + srvCfg := server.ListenerConfig{Address: serverListenAddr} + srv1 := server.NewServer(serverURL, false, av) go func() { err := srv1.Listen(srvCfg) if err != nil { @@ -457,7 +466,7 @@ func TestCloseByServer(t *testing.T) { idAlice := "alice" log.Debugf("connect by alice") - relayClient := NewClient(ctx, srvCfg.Address, idAlice) + relayClient := NewClient(ctx, serverURL, hmacTokenStore, idAlice) err := relayClient.Connect() if err != nil { log.Fatalf("failed to connect to server: %s", err) @@ -489,8 +498,8 @@ func TestCloseByServer(t *testing.T) { func TestCloseByClient(t *testing.T) { ctx := context.Background() - srvCfg := server.ListenerConfig{Address: "localhost:1234"} - srv := server.NewServer(srvCfg.Address, false) + srvCfg := server.ListenerConfig{Address: serverListenAddr} + srv := server.NewServer(serverURL, false, av) go func() { err := srv.Listen(srvCfg) if err != nil { @@ -500,7 +509,7 @@ func TestCloseByClient(t *testing.T) { idAlice := "alice" log.Debugf("connect by alice") - relayClient := NewClient(ctx, srvCfg.Address, idAlice) + relayClient := NewClient(ctx, serverURL, hmacTokenStore, idAlice) err := relayClient.Connect() if err != nil { log.Fatalf("failed to connect to server: %s", err) diff --git a/relay/client/manager.go b/relay/client/manager.go index 0e552754e..d1a26d705 100644 --- a/relay/client/manager.go +++ b/relay/client/manager.go @@ -40,7 +40,7 @@ type Manager struct { ctx context.Context serverURL string peerID string - tokenStore *relayAuth.Store + tokenStore *relayAuth.TokenStore relayClient *Client reconnectGuard *Guard @@ -57,7 +57,7 @@ func NewManager(ctx context.Context, serverURL string, peerID string) *Manager { ctx: ctx, serverURL: serverURL, peerID: peerID, - tokenStore: &relayAuth.Store{}, + tokenStore: &relayAuth.TokenStore{}, relayClients: make(map[string]*RelayTrack), onDisconnectedListeners: make(map[string]map[*func()]struct{}), } diff --git a/relay/client/manager_test.go b/relay/client/manager_test.go index 3192c1a09..be0b9ad99 100644 --- a/relay/client/manager_test.go +++ b/relay/client/manager_test.go @@ -16,7 +16,7 @@ func TestForeignConn(t *testing.T) { srvCfg1 := server.ListenerConfig{ Address: "localhost:1234", } - srv1 := server.NewServer(srvCfg1.Address, false) + srv1 := server.NewServer(srvCfg1.Address, false, av) go func() { err := srv1.Listen(srvCfg1) if err != nil { @@ -34,7 +34,7 @@ func TestForeignConn(t *testing.T) { srvCfg2 := server.ListenerConfig{ Address: "localhost:2234", } - srv2 := server.NewServer(srvCfg2.Address, false) + srv2 := server.NewServer(srvCfg2.Address, false, av) go func() { err := srv2.Listen(srvCfg2) if err != nil { @@ -107,7 +107,7 @@ func TestForeginConnClose(t *testing.T) { srvCfg1 := server.ListenerConfig{ Address: "localhost:1234", } - srv1 := server.NewServer(srvCfg1.Address, false) + srv1 := server.NewServer(srvCfg1.Address, false, av) go func() { err := srv1.Listen(srvCfg1) if err != nil { @@ -125,7 +125,7 @@ func TestForeginConnClose(t *testing.T) { srvCfg2 := server.ListenerConfig{ Address: "localhost:2234", } - srv2 := server.NewServer(srvCfg2.Address, false) + srv2 := server.NewServer(srvCfg2.Address, false, av) go func() { err := srv2.Listen(srvCfg2) if err != nil { @@ -164,7 +164,7 @@ func TestForeginAutoClose(t *testing.T) { srvCfg1 := server.ListenerConfig{ Address: "localhost:1234", } - srv1 := server.NewServer(srvCfg1.Address, false) + srv1 := server.NewServer(srvCfg1.Address, false, av) go func() { t.Log("binding server 1.") err := srv1.Listen(srvCfg1) @@ -185,7 +185,7 @@ func TestForeginAutoClose(t *testing.T) { srvCfg2 := server.ListenerConfig{ Address: "localhost:2234", } - srv2 := server.NewServer(srvCfg2.Address, false) + srv2 := server.NewServer(srvCfg2.Address, false, av) go func() { t.Log("binding server 2.") err := srv2.Listen(srvCfg2) @@ -237,7 +237,7 @@ func TestAutoReconnect(t *testing.T) { srvCfg := server.ListenerConfig{ Address: "localhost:1234", } - srv := server.NewServer(srvCfg.Address, false) + srv := server.NewServer(srvCfg.Address, false, av) go func() { err := srv.Listen(srvCfg) if err != nil { diff --git a/relay/cmd/main.go b/relay/cmd/main.go index 7315bd005..5e454b000 100644 --- a/relay/cmd/main.go +++ b/relay/cmd/main.go @@ -6,11 +6,13 @@ import ( "os" "os/signal" "syscall" + "time" log "github.com/sirupsen/logrus" "github.com/spf13/cobra" "github.com/netbirdio/netbird/encryption" + auth "github.com/netbirdio/netbird/relay/auth/hmac" "github.com/netbirdio/netbird/relay/server" "github.com/netbirdio/netbird/util" ) @@ -82,7 +84,9 @@ func execute(cmd *cobra.Command, args []string) { } tlsSupport := srvListenerCfg.TLSConfig != nil - srv := server.NewServer(exposedAddress, tlsSupport, authSecret) + + authenticator := auth.NewTimedHMACValidator(authSecret, 24*time.Hour) + srv := server.NewServer(exposedAddress, tlsSupport, authenticator) log.Infof("server will be available on: %s", srv.InstanceURL()) err := srv.Listen(srvListenerCfg) if err != nil { diff --git a/relay/server/server.go b/relay/server/server.go index 6ab994a6a..cac884372 100644 --- a/relay/server/server.go +++ b/relay/server/server.go @@ -9,7 +9,7 @@ import ( log "github.com/sirupsen/logrus" - auth "github.com/netbirdio/netbird/relay/auth/hmac" + "github.com/netbirdio/netbird/relay/auth" "github.com/netbirdio/netbird/relay/server/listener" "github.com/netbirdio/netbird/relay/server/listener/udp" "github.com/netbirdio/netbird/relay/server/listener/ws" @@ -26,12 +26,13 @@ type Server struct { wSListener listener.Listener } -func NewServer(exposedAddress string, tlsSupport bool, authSecret string) *Server { +func NewServer(exposedAddress string, tlsSupport bool, authValidator auth.Validator) *Server { return &Server{ relay: NewRelay( exposedAddress, tlsSupport, - auth.NewTimedHMACValidator(authSecret, 24*time.Hour)), + authValidator, + ), } }