diff --git a/management/server/token_mgr.go b/management/server/token_mgr.go index 8a6648a3a..ef8276b59 100644 --- a/management/server/token_mgr.go +++ b/management/server/token_mgr.go @@ -4,6 +4,7 @@ import ( "context" "crypto/sha1" "crypto/sha256" + "encoding/base64" "fmt" "sync" "time" @@ -12,6 +13,7 @@ import ( "github.com/netbirdio/netbird/management/proto" auth "github.com/netbirdio/netbird/relay/auth/hmac" + authv2 "github.com/netbirdio/netbird/relay/auth/hmac/v2" ) const defaultDuration = 12 * time.Hour @@ -30,7 +32,7 @@ type TimeBasedAuthSecretsManager struct { turnCfg *TURNConfig relayCfg *Relay turnHmacToken *auth.TimedHMAC - relayHmacToken *auth.TimedHMAC + relayHmacToken *authv2.Generator updateManager *PeersUpdateManager turnCancelMap map[string]chan struct{} relayCancelMap map[string]chan struct{} @@ -63,7 +65,11 @@ func NewTimeBasedAuthSecretsManager(updateManager *PeersUpdateManager, turnCfg * duration = defaultDuration } - mgr.relayHmacToken = auth.NewTimedHMAC(relayCfg.Secret, duration) + hashedSecret := sha256.Sum256([]byte(relayCfg.Secret)) + var err error + if mgr.relayHmacToken, err = authv2.NewGenerator(authv2.AuthAlgoHMACSHA256, hashedSecret[:], duration); err != nil { + log.Errorf("failed to create relay token generator: %s", err) + } } return mgr @@ -76,7 +82,7 @@ func (m *TimeBasedAuthSecretsManager) GenerateTurnToken() (*Token, error) { } turnToken, err := m.turnHmacToken.GenerateToken(sha1.New) if err != nil { - return nil, fmt.Errorf("failed to generate TURN token: %s", err) + return nil, fmt.Errorf("generate TURN token: %s", err) } return (*Token)(turnToken), nil } @@ -86,11 +92,15 @@ func (m *TimeBasedAuthSecretsManager) GenerateRelayToken() (*Token, error) { if m.relayHmacToken == nil { return nil, fmt.Errorf("relay configuration is not set") } - relayToken, err := m.relayHmacToken.GenerateToken(sha256.New) + relayToken, err := m.relayHmacToken.GenerateToken() if err != nil { - return nil, fmt.Errorf("failed to generate relay token: %s", err) + return nil, fmt.Errorf("generate relay token: %s", err) } - return (*Token)(relayToken), nil + + return &Token{ + Payload: string(relayToken.Payload), + Signature: base64.StdEncoding.EncodeToString(relayToken.Signature), + }, nil } func (m *TimeBasedAuthSecretsManager) cancelTURN(peerID string) { @@ -200,7 +210,7 @@ func (m *TimeBasedAuthSecretsManager) pushNewTURNTokens(ctx context.Context, pee } func (m *TimeBasedAuthSecretsManager) pushNewRelayTokens(ctx context.Context, peerID string) { - relayToken, err := m.relayHmacToken.GenerateToken(sha256.New) + relayToken, err := m.relayHmacToken.GenerateToken() if err != nil { log.Errorf("failed to generate relay token for peer '%s': %s", peerID, err) return @@ -210,8 +220,8 @@ func (m *TimeBasedAuthSecretsManager) pushNewRelayTokens(ctx context.Context, pe WiretrusteeConfig: &proto.WiretrusteeConfig{ Relay: &proto.RelayConfig{ Urls: m.relayCfg.Addresses, - TokenPayload: relayToken.Payload, - TokenSignature: relayToken.Signature, + TokenPayload: string(relayToken.Payload), + TokenSignature: base64.StdEncoding.EncodeToString(relayToken.Signature), }, // omit Turns to avoid updates there }, diff --git a/management/server/token_mgr_test.go b/management/server/token_mgr_test.go index d59fd3a3f..3e63346c2 100644 --- a/management/server/token_mgr_test.go +++ b/management/server/token_mgr_test.go @@ -63,7 +63,8 @@ func TestTimeBasedAuthSecretsManager_GenerateCredentials(t *testing.T) { t.Errorf("expected generated relay signature not to be empty, got empty") } - validateMAC(t, sha256.New, relayCredentials.Payload, relayCredentials.Signature, []byte(secret)) + hashedSecret := sha256.Sum256([]byte(secret)) + validateMAC(t, sha256.New, relayCredentials.Payload, relayCredentials.Signature, hashedSecret[:]) } func TestTimeBasedAuthSecretsManager_SetupRefresh(t *testing.T) { diff --git a/relay/auth/allow/allow_all.go b/relay/auth/allow/allow_all.go index 92845818b..2d30c59c9 100644 --- a/relay/auth/allow/allow_all.go +++ b/relay/auth/allow/allow_all.go @@ -1,12 +1,14 @@ package allow -import "hash" - // Auth is a Validator that allows all connections. // Used this for testing purposes only. type Auth struct { } -func (a *Auth) Validate(func() hash.Hash, any) error { +func (a *Auth) Validate(any) error { + return nil +} + +func (a *Auth) ValidateHelloMsgType(any) error { return nil } diff --git a/relay/auth/hmac/store.go b/relay/auth/hmac/store.go index 36c195a7b..169b8d6b0 100644 --- a/relay/auth/hmac/store.go +++ b/relay/auth/hmac/store.go @@ -1,9 +1,11 @@ package hmac import ( + "encoding/base64" + "fmt" "sync" - log "github.com/sirupsen/logrus" + v2 "github.com/netbirdio/netbird/relay/auth/hmac/v2" ) // TokenStore is a simple in-memory store for token @@ -20,12 +22,18 @@ func (a *TokenStore) UpdateToken(token *Token) error { return nil } - t, err := marshalToken(*token) + sig, err := base64.StdEncoding.DecodeString(token.Signature) if err != nil { - log.Debugf("failed to marshal token: %s", err) - return err + return fmt.Errorf("decode signature: %w", err) } - a.token = t + + tok := v2.Token{ + AuthAlgo: v2.AuthAlgoHMACSHA256, + Signature: sig, + Payload: []byte(token.Payload), + } + + a.token = tok.Marshal() return nil } diff --git a/relay/auth/hmac/token.go b/relay/auth/hmac/token.go index e2e62b84e..581b1d6fd 100644 --- a/relay/auth/hmac/token.go +++ b/relay/auth/hmac/token.go @@ -18,17 +18,6 @@ type Token struct { Signature string } -func marshalToken(token Token) ([]byte, error) { - var buffer bytes.Buffer - encoder := gob.NewEncoder(&buffer) - err := encoder.Encode(token) - if err != nil { - log.Debugf("failed to marshal token: %s", err) - return nil, fmt.Errorf("failed to marshal token: %w", err) - } - return buffer.Bytes(), nil -} - func unmarshalToken(payload []byte) (Token, error) { var creds Token buffer := bytes.NewBuffer(payload) diff --git a/relay/auth/hmac/v2/algo.go b/relay/auth/hmac/v2/algo.go new file mode 100644 index 000000000..c379c2bd7 --- /dev/null +++ b/relay/auth/hmac/v2/algo.go @@ -0,0 +1,40 @@ +package v2 + +import ( + "crypto/sha256" + "hash" +) + +const ( + AuthAlgoUnknown AuthAlgo = iota + AuthAlgoHMACSHA256 +) + +type AuthAlgo uint8 + +func (a AuthAlgo) String() string { + switch a { + case AuthAlgoHMACSHA256: + return "HMAC-SHA256" + default: + return "Unknown" + } +} + +func (a AuthAlgo) New() func() hash.Hash { + switch a { + case AuthAlgoHMACSHA256: + return sha256.New + default: + return nil + } +} + +func (a AuthAlgo) Size() int { + switch a { + case AuthAlgoHMACSHA256: + return sha256.Size + default: + return 0 + } +} diff --git a/relay/auth/hmac/v2/generator.go b/relay/auth/hmac/v2/generator.go new file mode 100644 index 000000000..827532730 --- /dev/null +++ b/relay/auth/hmac/v2/generator.go @@ -0,0 +1,45 @@ +package v2 + +import ( + "crypto/hmac" + "fmt" + "hash" + "strconv" + "time" +) + +type Generator struct { + algo func() hash.Hash + algoType AuthAlgo + secret []byte + timeToLive time.Duration +} + +func NewGenerator(algo AuthAlgo, secret []byte, timeToLive time.Duration) (*Generator, error) { + algoFunc := algo.New() + if algoFunc == nil { + return nil, fmt.Errorf("unsupported auth algorithm: %s", algo) + } + return &Generator{ + algo: algoFunc, + algoType: algo, + secret: secret, + timeToLive: timeToLive, + }, nil +} + +func (g *Generator) GenerateToken() (*Token, error) { + expirationTime := time.Now().Add(g.timeToLive).Unix() + + payload := []byte(strconv.FormatInt(expirationTime, 10)) + + h := hmac.New(g.algo, g.secret) + h.Write(payload) + signature := h.Sum(nil) + + return &Token{ + AuthAlgo: g.algoType, + Signature: signature, + Payload: payload, + }, nil +} diff --git a/relay/auth/hmac/v2/hmac_test.go b/relay/auth/hmac/v2/hmac_test.go new file mode 100644 index 000000000..40336363f --- /dev/null +++ b/relay/auth/hmac/v2/hmac_test.go @@ -0,0 +1,110 @@ +package v2 + +import ( + "strconv" + "testing" + "time" +) + +func TestGenerateCredentials(t *testing.T) { + secret := "supersecret" + timeToLive := 1 * time.Hour + g, err := NewGenerator(AuthAlgoHMACSHA256, []byte(secret), timeToLive) + if err != nil { + t.Fatalf("failed to create generator: %v", err) + } + + token, err := g.GenerateToken() + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + + if len(token.Payload) == 0 { + t.Fatalf("expected non-empty payload") + } + + _, err = strconv.ParseInt(string(token.Payload), 10, 64) + if err != nil { + t.Fatalf("expected payload to be a valid unix timestamp, got %v", err) + } +} + +func TestValidateCredentials(t *testing.T) { + secret := "supersecret" + timeToLive := 1 * time.Hour + g, err := NewGenerator(AuthAlgoHMACSHA256, []byte(secret), timeToLive) + if err != nil { + t.Fatalf("failed to create generator: %v", err) + } + + token, err := g.GenerateToken() + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + + v := NewValidator([]byte(secret)) + if err := v.Validate(token.Marshal()); err != nil { + t.Fatalf("expected valid token: %s", err) + } +} + +func TestInvalidSignature(t *testing.T) { + secret := "supersecret" + timeToLive := 1 * time.Hour + g, err := NewGenerator(AuthAlgoHMACSHA256, []byte(secret), timeToLive) + if err != nil { + t.Fatalf("failed to create generator: %v", err) + } + + token, err := g.GenerateToken() + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + + token.Signature = []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9} + + v := NewValidator([]byte(secret)) + if err := v.Validate(token.Marshal()); err == nil { + t.Fatalf("expected valid token: %s", err) + } +} + +func TestExpired(t *testing.T) { + secret := "supersecret" + timeToLive := -1 * time.Hour + g, err := NewGenerator(AuthAlgoHMACSHA256, []byte(secret), timeToLive) + if err != nil { + t.Fatalf("failed to create generator: %v", err) + } + + token, err := g.GenerateToken() + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + + v := NewValidator([]byte(secret)) + if err := v.Validate(token.Marshal()); err == nil { + t.Fatalf("expected valid token: %s", err) + } +} + +func TestInvalidPayload(t *testing.T) { + secret := "supersecret" + timeToLive := 1 * time.Hour + g, err := NewGenerator(AuthAlgoHMACSHA256, []byte(secret), timeToLive) + if err != nil { + t.Fatalf("failed to create generator: %v", err) + } + + token, err := g.GenerateToken() + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + + token.Payload = []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9} + + v := NewValidator([]byte(secret)) + if err := v.Validate(token.Marshal()); err == nil { + t.Fatalf("expected invalid token due to invalid payload") + } +} diff --git a/relay/auth/hmac/v2/token.go b/relay/auth/hmac/v2/token.go new file mode 100644 index 000000000..553ac01b9 --- /dev/null +++ b/relay/auth/hmac/v2/token.go @@ -0,0 +1,39 @@ +package v2 + +import "errors" + +type Token struct { + AuthAlgo AuthAlgo + Signature []byte + Payload []byte +} + +func (t *Token) Marshal() []byte { + size := 1 + len(t.Signature) + len(t.Payload) + + buf := make([]byte, size) + + buf[0] = byte(t.AuthAlgo) + copy(buf[1:], t.Signature) + copy(buf[1+len(t.Signature):], t.Payload) + + return buf +} + +func UnmarshalToken(data []byte) (*Token, error) { + if len(data) == 0 { + return nil, errors.New("invalid token data") + } + + algo := AuthAlgo(data[0]) + sigSize := algo.Size() + if len(data) < 1+sigSize { + return nil, errors.New("invalid token data: insufficient length") + } + + return &Token{ + AuthAlgo: algo, + Signature: data[1 : 1+sigSize], + Payload: data[1+sigSize:], + }, nil +} diff --git a/relay/auth/hmac/v2/validator.go b/relay/auth/hmac/v2/validator.go new file mode 100644 index 000000000..7f448dd5f --- /dev/null +++ b/relay/auth/hmac/v2/validator.go @@ -0,0 +1,59 @@ +package v2 + +import ( + "crypto/hmac" + "errors" + "fmt" + "strconv" + "time" +) + +const minLengthUnixTimestamp = 10 + +type Validator struct { + secret []byte +} + +func NewValidator(secret []byte) *Validator { + return &Validator{secret: secret} +} + +func (v *Validator) Validate(data any) error { + d, ok := data.([]byte) + if !ok { + return fmt.Errorf("invalid data type") + } + + token, err := UnmarshalToken(d) + if err != nil { + return fmt.Errorf("unmarshal token: %w", err) + } + + if len(token.Payload) < minLengthUnixTimestamp { + return errors.New("invalid payload: insufficient length") + } + + hashFunc := token.AuthAlgo.New() + if hashFunc == nil { + return fmt.Errorf("unsupported auth algorithm: %s", token.AuthAlgo) + } + + h := hmac.New(hashFunc, v.secret) + h.Write(token.Payload) + expectedMAC := h.Sum(nil) + + if !hmac.Equal(token.Signature, expectedMAC) { + return errors.New("invalid signature") + } + + timestamp, err := strconv.ParseInt(string(token.Payload), 10, 64) + if err != nil { + return fmt.Errorf("invalid payload: %w", err) + } + + if time.Now().Unix() > timestamp { + return fmt.Errorf("expired token") + } + + return nil +} diff --git a/relay/auth/hmac/validator.go b/relay/auth/hmac/validator.go index 6ddd89c19..b0b7542be 100644 --- a/relay/auth/hmac/validator.go +++ b/relay/auth/hmac/validator.go @@ -1,8 +1,8 @@ package hmac import ( + "crypto/sha256" "fmt" - "hash" "time" log "github.com/sirupsen/logrus" @@ -19,7 +19,7 @@ func NewTimedHMACValidator(secret string, duration time.Duration) *TimedHMACVali } } -func (a *TimedHMACValidator) Validate(algo func() hash.Hash, credentials any) error { +func (a *TimedHMACValidator) Validate(credentials any) error { b, ok := credentials.([]byte) if !ok { return fmt.Errorf("invalid credentials type") @@ -29,5 +29,5 @@ func (a *TimedHMACValidator) Validate(algo func() hash.Hash, credentials any) er log.Debugf("failed to unmarshal token: %s", err) return err } - return a.TimedHMAC.Validate(algo, c) + return a.TimedHMAC.Validate(sha256.New, c) } diff --git a/relay/auth/validator.go b/relay/auth/validator.go index 078811f3d..854efd5bb 100644 --- a/relay/auth/validator.go +++ b/relay/auth/validator.go @@ -1,8 +1,35 @@ package auth -import "hash" +import ( + "time" + + auth "github.com/netbirdio/netbird/relay/auth/hmac" + authv2 "github.com/netbirdio/netbird/relay/auth/hmac/v2" +) // Validator is an interface that defines the Validate method. type Validator interface { - Validate(func() hash.Hash, any) error + Validate(any) error + // Deprecated: Use Validate instead. + ValidateHelloMsgType(any) error +} + +type TimedHMACValidator struct { + authenticatorV2 *authv2.Validator + authenticator *auth.TimedHMACValidator +} + +func NewTimedHMACValidator(secret []byte, duration time.Duration) *TimedHMACValidator { + return &TimedHMACValidator{ + authenticatorV2: authv2.NewValidator(secret), + authenticator: auth.NewTimedHMACValidator(string(secret), duration), + } +} + +func (a *TimedHMACValidator) Validate(credentials any) error { + return a.authenticatorV2.Validate(credentials) +} + +func (a *TimedHMACValidator) ValidateHelloMsgType(credentials any) error { + return a.authenticator.Validate(credentials) } diff --git a/relay/client/client.go b/relay/client/client.go index 1160d1c9e..6560c81e1 100644 --- a/relay/client/client.go +++ b/relay/client/client.go @@ -14,8 +14,6 @@ import ( "github.com/netbirdio/netbird/relay/client/dialer/ws" "github.com/netbirdio/netbird/relay/healthcheck" "github.com/netbirdio/netbird/relay/messages" - "github.com/netbirdio/netbird/relay/messages/address" - auth2 "github.com/netbirdio/netbird/relay/messages/auth" ) const ( @@ -240,31 +238,21 @@ func (c *Client) connect() error { } func (c *Client) handShake() error { - authMsg := &auth2.Msg{ - AuthAlgorithm: auth2.AlgoHMACSHA256, - AdditionalData: c.authTokenStore.TokenBinary(), - } - - authData, err := authMsg.Marshal() + msg, err := messages.MarshalAuthMsg(c.hashedID, c.authTokenStore.TokenBinary()) if err != nil { - return fmt.Errorf("marshal auth message: %w", err) - } - - msg, err := messages.MarshalHelloMsg(c.hashedID, authData) - if err != nil { - log.Errorf("failed to marshal hello message: %s", err) + log.Errorf("failed to marshal auth message: %s", err) return err } _, err = c.relayConn.Write(msg) if err != nil { - log.Errorf("failed to send hello message: %s", err) + log.Errorf("failed to send auth message: %s", err) return err } - buf := make([]byte, messages.MaxHandshakeSize) + buf := make([]byte, messages.MaxHandshakeRespSize) n, err := c.readWithTimeout(buf) if err != nil { - log.Errorf("failed to read hello response: %s", err) + log.Errorf("failed to read auth response: %s", err) return err } @@ -279,23 +267,18 @@ func (c *Client) handShake() error { return err } - if msgType != messages.MsgTypeHelloResponse { + if msgType != messages.MsgTypeAuthResponse { log.Errorf("unexpected message type: %s", msgType) return fmt.Errorf("unexpected message type") } - additionalData, err := messages.UnmarshalHelloResponse(buf[messages.SizeOfProtoHeader:n]) + addr, err := messages.UnmarshalAuthResponse(buf[messages.SizeOfProtoHeader:n]) if err != nil { return err } - addr, err := address.Unmarshal(additionalData) - if err != nil { - return fmt.Errorf("unmarshal address: %w", err) - } - c.muInstanceURL.Lock() - c.instanceURL = &RelayAddr{addr: addr.URL} + c.instanceURL = &RelayAddr{addr: addr} c.muInstanceURL.Unlock() return nil } diff --git a/relay/cmd/root.go b/relay/cmd/root.go index 784b42c1a..dcc1465d0 100644 --- a/relay/cmd/root.go +++ b/relay/cmd/root.go @@ -2,6 +2,7 @@ package cmd import ( "context" + "crypto/sha256" "crypto/tls" "errors" "fmt" @@ -16,7 +17,7 @@ import ( "github.com/spf13/cobra" "github.com/netbirdio/netbird/encryption" - auth "github.com/netbirdio/netbird/relay/auth/hmac" + "github.com/netbirdio/netbird/relay/auth" "github.com/netbirdio/netbird/relay/server" "github.com/netbirdio/netbird/signal/metrics" "github.com/netbirdio/netbird/util" @@ -139,7 +140,9 @@ func execute(cmd *cobra.Command, args []string) error { } srvListenerCfg.TLSConfig = tlsConfig - authenticator := auth.NewTimedHMACValidator(cobraConfig.AuthSecret, 24*time.Hour) + hashedSecret := sha256.Sum256([]byte(cobraConfig.AuthSecret)) + authenticator := auth.NewTimedHMACValidator(hashedSecret[:], 24*time.Hour) + srv, err := server.NewServer(metricsServer.Meter, cobraConfig.ExposedAddress, tlsSupport, authenticator) if err != nil { log.Debugf("failed to create relay server: %v", err) diff --git a/relay/messages/address/address.go b/relay/messages/address/address.go index 829206294..707e73e55 100644 --- a/relay/messages/address/address.go +++ b/relay/messages/address/address.go @@ -1,3 +1,4 @@ +// Deprecated: This package is deprecated and will be removed in a future release. package address import ( @@ -18,13 +19,3 @@ func (addr *Address) Marshal() ([]byte, error) { } return buf.Bytes(), nil } - -func Unmarshal(data []byte) (*Address, error) { - var addr Address - buf := bytes.NewBuffer(data) - dec := gob.NewDecoder(buf) - if err := dec.Decode(&addr); err != nil { - return nil, fmt.Errorf("decode Address: %w", err) - } - return &addr, nil -} diff --git a/relay/messages/auth/auth.go b/relay/messages/auth/auth.go index 8230bccf2..9c2511f2f 100644 --- a/relay/messages/auth/auth.go +++ b/relay/messages/auth/auth.go @@ -1,3 +1,4 @@ +// Deprecated: This package is deprecated and will be removed in a future release. package auth import ( @@ -30,15 +31,6 @@ type Msg struct { AdditionalData []byte } -func (msg *Msg) Marshal() ([]byte, error) { - var buf bytes.Buffer - enc := gob.NewEncoder(&buf) - if err := enc.Encode(msg); err != nil { - return nil, fmt.Errorf("encode Msg: %w", err) - } - return buf.Bytes(), nil -} - func UnmarshalMsg(data []byte) (*Msg, error) { var msg *Msg diff --git a/relay/messages/message.go b/relay/messages/message.go index cfcac3f72..39ca0aa90 100644 --- a/relay/messages/message.go +++ b/relay/messages/message.go @@ -7,12 +7,21 @@ import ( ) const ( - MsgTypeUnknown MsgType = 0 - MsgTypeHello MsgType = 1 + MaxHandshakeSize = 212 + MaxHandshakeRespSize = 8192 + + CurrentProtocolVersion = 1 + + MsgTypeUnknown MsgType = 0 + // Deprecated: Use MsgTypeAuth instead. + MsgTypeHello MsgType = 1 + // Deprecated: Use MsgTypeAuthResponse instead. MsgTypeHelloResponse MsgType = 2 MsgTypeTransport MsgType = 3 MsgTypeClose MsgType = 4 MsgTypeHealthCheck MsgType = 5 + MsgTypeAuth = 6 + MsgTypeAuthResponse = 7 SizeOfVersionByte = 1 SizeOfMsgType = 1 @@ -22,12 +31,12 @@ const ( sizeOfMagicByte = 4 headerSizeTransport = IDSize + headerSizeHello = sizeOfMagicByte + IDSize headerSizeHelloResp = 0 - MaxHandshakeSize = 8192 - - CurrentProtocolVersion = 1 + headerSizeAuth = sizeOfMagicByte + IDSize + headerSizeAuthResp = 0 ) var ( @@ -47,6 +56,10 @@ func (m MsgType) String() string { return "hello" case MsgTypeHelloResponse: return "hello response" + case MsgTypeAuth: + return "auth" + case MsgTypeAuthResponse: + return "auth response" case MsgTypeTransport: return "transport" case MsgTypeClose: @@ -58,10 +71,6 @@ func (m MsgType) String() string { } } -type HelloResponse struct { - InstanceAddress string -} - // ValidateVersion checks if the given version is supported by the protocol func ValidateVersion(msg []byte) (int, error) { if len(msg) < SizeOfVersionByte { @@ -84,6 +93,7 @@ func DetermineClientMessageType(msg []byte) (MsgType, error) { switch msgType { case MsgTypeHello, + MsgTypeAuth, MsgTypeTransport, MsgTypeClose, MsgTypeHealthCheck: @@ -103,6 +113,7 @@ func DetermineServerMessageType(msg []byte) (MsgType, error) { switch msgType { case MsgTypeHelloResponse, + MsgTypeAuthResponse, MsgTypeTransport, MsgTypeClose, MsgTypeHealthCheck: @@ -112,6 +123,7 @@ func DetermineServerMessageType(msg []byte) (MsgType, error) { } } +// Deprecated: Use MarshalAuthMsg instead. // MarshalHelloMsg initial hello message // The Hello message is the first message sent by a client after establishing a connection with the Relay server. This // message is used to authenticate the client with the server. The authentication is done using an HMAC method. @@ -135,6 +147,7 @@ func MarshalHelloMsg(peerID []byte, additions []byte) ([]byte, error) { return msg, nil } +// Deprecated: Use UnmarshalAuthMsg instead. // UnmarshalHelloMsg extracts peerID and the additional data from the hello message. The Additional data is used to // authenticate the client with the server. func UnmarshalHelloMsg(msg []byte) ([]byte, []byte, error) { @@ -148,6 +161,7 @@ func UnmarshalHelloMsg(msg []byte) ([]byte, []byte, error) { return msg[sizeOfMagicByte:headerSizeHello], msg[headerSizeHello:], nil } +// Deprecated: Use MarshalAuthResponse instead. // MarshalHelloResponse creates a response message to the hello message. // In case of success connection the server response with a Hello Response message. This message contains the server's // instance URL. This URL will be used by choose the common Relay server in case if the peers are in different Relay @@ -163,6 +177,7 @@ func MarshalHelloResponse(additionalData []byte) ([]byte, error) { return msg, nil } +// Deprecated: Use UnmarshalAuthResponse instead. // UnmarshalHelloResponse extracts the additional data from the hello response message. func UnmarshalHelloResponse(msg []byte) ([]byte, error) { if len(msg) < headerSizeHelloResp { @@ -171,6 +186,69 @@ func UnmarshalHelloResponse(msg []byte) ([]byte, error) { return msg, nil } +// MarshalAuthMsg initial authentication message +// The Auth message is the first message sent by a client after establishing a connection with the Relay server. This +// message is used to authenticate the client with the server. The authentication is done using an HMAC method. +// The protocol does not limit to use HMAC, it can be any other method. If the authentication failed the server will +// close the network connection without any response. +func MarshalAuthMsg(peerID []byte, authPayload []byte) ([]byte, error) { + if len(peerID) != IDSize { + return nil, fmt.Errorf("invalid peerID length: %d", len(peerID)) + } + + msg := make([]byte, SizeOfProtoHeader+sizeOfMagicByte, SizeOfProtoHeader+headerSizeAuth+len(authPayload)) + + msg[0] = byte(CurrentProtocolVersion) + msg[1] = byte(MsgTypeAuth) + + copy(msg[SizeOfProtoHeader:SizeOfProtoHeader+sizeOfMagicByte], magicHeader) + + msg = append(msg, peerID...) + msg = append(msg, authPayload...) + + return msg, nil +} + +// UnmarshalAuthMsg extracts peerID and the auth payload from the message +func UnmarshalAuthMsg(msg []byte) ([]byte, []byte, error) { + if len(msg) < headerSizeAuth { + return nil, nil, ErrInvalidMessageLength + } + if !bytes.Equal(msg[:sizeOfMagicByte], magicHeader) { + return nil, nil, errors.New("invalid magic header") + } + + return msg[sizeOfMagicByte:headerSizeAuth], msg[headerSizeAuth:], nil +} + +// MarshalAuthResponse creates a response message to the auth. +// In case of success connection the server response with a AuthResponse message. This message contains the server's +// instance URL. This URL will be used by choose the common Relay server in case if the peers are in different Relay +// servers. +func MarshalAuthResponse(address string) ([]byte, error) { + ab := []byte(address) + msg := make([]byte, SizeOfProtoHeader, SizeOfProtoHeader+headerSizeAuthResp+len(ab)) + + msg[0] = byte(CurrentProtocolVersion) + msg[1] = byte(MsgTypeAuthResponse) + + msg = append(msg, ab...) + + if len(msg) > MaxHandshakeRespSize { + return nil, fmt.Errorf("invalid message length: %d", len(msg)) + } + + return msg, nil +} + +// UnmarshalAuthResponse it is a confirmation message to auth success +func UnmarshalAuthResponse(msg []byte) (string, error) { + if len(msg) < headerSizeAuthResp+1 { + return "", ErrInvalidMessageLength + } + return string(msg), nil +} + // MarshalCloseMsg creates a close message. // The close message is used to close the connection gracefully between the client and the server. The server and the // client can send this message. After receiving this message, the server or client will close the connection. diff --git a/relay/messages/message_test.go b/relay/messages/message_test.go index a4e7d9fae..6e917da71 100644 --- a/relay/messages/message_test.go +++ b/relay/messages/message_test.go @@ -20,6 +20,22 @@ func TestMarshalHelloMsg(t *testing.T) { } } +func TestMarshalAuthMsg(t *testing.T) { + peerID := []byte("abdFAaBcawquEiCMzAabYosuUaGLtSNhKxz+") + bHello, err := MarshalAuthMsg(peerID, []byte{}) + if err != nil { + t.Fatalf("error: %v", err) + } + + receivedPeerID, _, err := UnmarshalAuthMsg(bHello[SizeOfProtoHeader:]) + if err != nil { + t.Fatalf("error: %v", err) + } + if string(receivedPeerID) != string(peerID) { + t.Errorf("expected %s, got %s", peerID, receivedPeerID) + } +} + func TestMarshalTransportMsg(t *testing.T) { peerID := []byte("abdFAaBcawquEiCMzAabYosuUaGLtSNhKxz+") payload := []byte("payload") diff --git a/relay/server/relay.go b/relay/server/relay.go index 6d88cbbb2..76c01a697 100644 --- a/relay/server/relay.go +++ b/relay/server/relay.go @@ -2,7 +2,6 @@ package server import ( "context" - "crypto/sha256" "fmt" "net" "net/url" @@ -14,7 +13,9 @@ import ( "github.com/netbirdio/netbird/relay/auth" "github.com/netbirdio/netbird/relay/messages" + //nolint:staticcheck "github.com/netbirdio/netbird/relay/messages/address" + //nolint:staticcheck authmsg "github.com/netbirdio/netbird/relay/messages/auth" "github.com/netbirdio/netbird/relay/metrics" ) @@ -168,39 +169,81 @@ func (r *Relay) handshake(conn net.Conn) ([]byte, error) { return nil, fmt.Errorf("determine message type from %s: %w", conn.RemoteAddr(), err) } - if msgType != messages.MsgTypeHello { - return nil, fmt.Errorf("invalid message type from %s", conn.RemoteAddr()) + var ( + responseMsg []byte + peerID []byte + ) + switch msgType { + //nolint:staticcheck + case messages.MsgTypeHello: + peerID, responseMsg, err = r.handleHelloMsg(buf[messages.SizeOfProtoHeader:n], conn.RemoteAddr()) + case messages.MsgTypeAuth: + peerID, responseMsg, err = r.handleAuthMsg(buf[messages.SizeOfProtoHeader:n], conn.RemoteAddr()) + default: + return nil, fmt.Errorf("invalid message type %d from %s", msgType, conn.RemoteAddr()) } - - peerID, authData, err := messages.UnmarshalHelloMsg(buf[messages.SizeOfProtoHeader:n]) if err != nil { - return nil, fmt.Errorf("unmarshal hello message: %w", err) + return nil, err } - authMsg, err := authmsg.UnmarshalMsg(authData) - if err != nil { - return nil, fmt.Errorf("unmarshal auth message: %w", err) - } - - if err := r.validator.Validate(sha256.New, authMsg.AdditionalData); err != nil { - return nil, fmt.Errorf("validate %s (%s): %w", peerID, conn.RemoteAddr(), err) - } - - addr := &address.Address{URL: r.instanceURL} - addrData, err := addr.Marshal() - if err != nil { - return nil, fmt.Errorf("marshal addressc to %s (%s): %w", peerID, conn.RemoteAddr(), err) - } - - msg, err := messages.MarshalHelloResponse(addrData) - if err != nil { - return nil, fmt.Errorf("marshal hello response to %s (%s): %w", peerID, conn.RemoteAddr(), err) - } - - _, err = conn.Write(msg) + _, err = conn.Write(responseMsg) if err != nil { return nil, fmt.Errorf("write to %s (%s): %w", peerID, conn.RemoteAddr(), err) } return peerID, nil } + +func (r *Relay) handleHelloMsg(buf []byte, remoteAddr net.Addr) ([]byte, []byte, error) { + //nolint:staticcheck + rawPeerID, authData, err := messages.UnmarshalHelloMsg(buf) + if err != nil { + return nil, nil, fmt.Errorf("unmarshal hello message: %w", err) + } + + peerID := messages.HashIDToString(rawPeerID) + log.Warnf("peer %s (%s) is using deprecated initial message type", peerID, remoteAddr) + + authMsg, err := authmsg.UnmarshalMsg(authData) + if err != nil { + return nil, nil, fmt.Errorf("unmarshal auth message: %w", err) + } + + //nolint:staticcheck + if err := r.validator.ValidateHelloMsgType(authMsg.AdditionalData); err != nil { + return nil, nil, fmt.Errorf("validate %s (%s): %w", peerID, remoteAddr, err) + } + + addr := &address.Address{URL: r.instanceURL} + addrData, err := addr.Marshal() + if err != nil { + return nil, nil, fmt.Errorf("marshal addressc to %s (%s): %w", peerID, remoteAddr, err) + } + + //nolint:staticcheck + responseMsg, err := messages.MarshalHelloResponse(addrData) + if err != nil { + return nil, nil, fmt.Errorf("marshal hello response to %s (%s): %w", peerID, remoteAddr, err) + } + return rawPeerID, responseMsg, nil +} + +func (r *Relay) handleAuthMsg(buf []byte, addr net.Addr) ([]byte, []byte, error) { + rawPeerID, authPayload, err := messages.UnmarshalAuthMsg(buf) + if err != nil { + return nil, nil, fmt.Errorf("unmarshal hello message: %w", err) + } + + peerID := messages.HashIDToString(rawPeerID) + + if err := r.validator.Validate(authPayload); err != nil { + return nil, nil, fmt.Errorf("validate %s (%s): %w", peerID, addr, err) + } + + responseMsg, err := messages.MarshalAuthResponse(r.instanceURL) + if err != nil { + return nil, nil, fmt.Errorf("marshal hello response to %s (%s): %w", peerID, addr, err) + } + + return rawPeerID, responseMsg, nil +}