diff --git a/relay/client/client.go b/relay/client/client.go index 1160d1c9e..7e5379ba5 100644 --- a/relay/client/client.go +++ b/relay/client/client.go @@ -14,8 +14,6 @@ import ( "github.com/netbirdio/netbird/relay/client/dialer/ws" "github.com/netbirdio/netbird/relay/healthcheck" "github.com/netbirdio/netbird/relay/messages" - "github.com/netbirdio/netbird/relay/messages/address" - auth2 "github.com/netbirdio/netbird/relay/messages/auth" ) const ( @@ -240,31 +238,21 @@ func (c *Client) connect() error { } func (c *Client) handShake() error { - authMsg := &auth2.Msg{ - AuthAlgorithm: auth2.AlgoHMACSHA256, - AdditionalData: c.authTokenStore.TokenBinary(), - } - - authData, err := authMsg.Marshal() + msg, err := messages.MarshalAuthMsg(c.hashedID, c.authTokenStore.TokenBinary()) if err != nil { - return fmt.Errorf("marshal auth message: %w", err) - } - - msg, err := messages.MarshalHelloMsg(c.hashedID, authData) - if err != nil { - log.Errorf("failed to marshal hello message: %s", err) + log.Errorf("failed to marshal auth message: %s", err) return err } _, err = c.relayConn.Write(msg) if err != nil { - log.Errorf("failed to send hello message: %s", err) + log.Errorf("failed to send auth message: %s", err) return err } buf := make([]byte, messages.MaxHandshakeSize) n, err := c.readWithTimeout(buf) if err != nil { - log.Errorf("failed to read hello response: %s", err) + log.Errorf("failed to read auth response: %s", err) return err } @@ -279,23 +267,18 @@ func (c *Client) handShake() error { return err } - if msgType != messages.MsgTypeHelloResponse { + if msgType != messages.MsgTypeAuthResponse { log.Errorf("unexpected message type: %s", msgType) return fmt.Errorf("unexpected message type") } - additionalData, err := messages.UnmarshalHelloResponse(buf[messages.SizeOfProtoHeader:n]) + addr, err := messages.UnmarshalAuthResponse(buf[messages.SizeOfProtoHeader:n]) if err != nil { return err } - addr, err := address.Unmarshal(additionalData) - if err != nil { - return fmt.Errorf("unmarshal address: %w", err) - } - c.muInstanceURL.Lock() - c.instanceURL = &RelayAddr{addr: addr.URL} + c.instanceURL = &RelayAddr{addr: addr} c.muInstanceURL.Unlock() return nil } diff --git a/relay/messages/address/address.go b/relay/messages/address/address.go index 829206294..707e73e55 100644 --- a/relay/messages/address/address.go +++ b/relay/messages/address/address.go @@ -1,3 +1,4 @@ +// Deprecated: This package is deprecated and will be removed in a future release. package address import ( @@ -18,13 +19,3 @@ func (addr *Address) Marshal() ([]byte, error) { } return buf.Bytes(), nil } - -func Unmarshal(data []byte) (*Address, error) { - var addr Address - buf := bytes.NewBuffer(data) - dec := gob.NewDecoder(buf) - if err := dec.Decode(&addr); err != nil { - return nil, fmt.Errorf("decode Address: %w", err) - } - return &addr, nil -} diff --git a/relay/messages/auth/auth.go b/relay/messages/auth/auth.go index 8230bccf2..9c2511f2f 100644 --- a/relay/messages/auth/auth.go +++ b/relay/messages/auth/auth.go @@ -1,3 +1,4 @@ +// Deprecated: This package is deprecated and will be removed in a future release. package auth import ( @@ -30,15 +31,6 @@ type Msg struct { AdditionalData []byte } -func (msg *Msg) Marshal() ([]byte, error) { - var buf bytes.Buffer - enc := gob.NewEncoder(&buf) - if err := enc.Encode(msg); err != nil { - return nil, fmt.Errorf("encode Msg: %w", err) - } - return buf.Bytes(), nil -} - func UnmarshalMsg(data []byte) (*Msg, error) { var msg *Msg diff --git a/relay/messages/message.go b/relay/messages/message.go index cfcac3f72..f2c52ad4e 100644 --- a/relay/messages/message.go +++ b/relay/messages/message.go @@ -7,12 +7,16 @@ import ( ) const ( - MsgTypeUnknown MsgType = 0 - MsgTypeHello MsgType = 1 + MsgTypeUnknown MsgType = 0 + // Deprecated: Use MsgTypeAuth instead. + MsgTypeHello MsgType = 1 + // Deprecated: Use MsgTypeAuthResponse instead. MsgTypeHelloResponse MsgType = 2 MsgTypeTransport MsgType = 3 MsgTypeClose MsgType = 4 MsgTypeHealthCheck MsgType = 5 + MsgTypeAuth = 6 + MsgTypeAuthResponse = 7 SizeOfVersionByte = 1 SizeOfMsgType = 1 @@ -47,6 +51,10 @@ func (m MsgType) String() string { return "hello" case MsgTypeHelloResponse: return "hello response" + case MsgTypeAuth: + return "auth" + case MsgTypeAuthResponse: + return "auth response" case MsgTypeTransport: return "transport" case MsgTypeClose: @@ -58,10 +66,6 @@ func (m MsgType) String() string { } } -type HelloResponse struct { - InstanceAddress string -} - // ValidateVersion checks if the given version is supported by the protocol func ValidateVersion(msg []byte) (int, error) { if len(msg) < SizeOfVersionByte { @@ -84,6 +88,7 @@ func DetermineClientMessageType(msg []byte) (MsgType, error) { switch msgType { case MsgTypeHello, + MsgTypeAuth, MsgTypeTransport, MsgTypeClose, MsgTypeHealthCheck: @@ -103,6 +108,7 @@ func DetermineServerMessageType(msg []byte) (MsgType, error) { switch msgType { case MsgTypeHelloResponse, + MsgTypeAuthResponse, MsgTypeTransport, MsgTypeClose, MsgTypeHealthCheck: @@ -112,6 +118,7 @@ func DetermineServerMessageType(msg []byte) (MsgType, error) { } } +// Deprecated: Use MarshalAuthMsg instead. // MarshalHelloMsg initial hello message // The Hello message is the first message sent by a client after establishing a connection with the Relay server. This // message is used to authenticate the client with the server. The authentication is done using an HMAC method. @@ -135,6 +142,7 @@ func MarshalHelloMsg(peerID []byte, additions []byte) ([]byte, error) { return msg, nil } +// Deprecated: Use UnmarshalAuthMsg instead. // UnmarshalHelloMsg extracts peerID and the additional data from the hello message. The Additional data is used to // authenticate the client with the server. func UnmarshalHelloMsg(msg []byte) ([]byte, []byte, error) { @@ -148,6 +156,7 @@ func UnmarshalHelloMsg(msg []byte) ([]byte, []byte, error) { return msg[sizeOfMagicByte:headerSizeHello], msg[headerSizeHello:], nil } +// Deprecated: Use MarshalAuthResponse instead. // MarshalHelloResponse creates a response message to the hello message. // In case of success connection the server response with a Hello Response message. This message contains the server's // instance URL. This URL will be used by choose the common Relay server in case if the peers are in different Relay @@ -163,6 +172,7 @@ func MarshalHelloResponse(additionalData []byte) ([]byte, error) { return msg, nil } +// Deprecated: Use UnmarshalAuthResponse instead. // UnmarshalHelloResponse extracts the additional data from the hello response message. func UnmarshalHelloResponse(msg []byte) ([]byte, error) { if len(msg) < headerSizeHelloResp { @@ -171,6 +181,65 @@ func UnmarshalHelloResponse(msg []byte) ([]byte, error) { return msg, nil } +// MarshalAuthMsg initial authentication message +// The Auth message is the first message sent by a client after establishing a connection with the Relay server. This +// message is used to authenticate the client with the server. The authentication is done using an HMAC method. +// The protocol does not limit to use HMAC, it can be any other method. If the authentication failed the server will +// close the network connection without any response. +func MarshalAuthMsg(peerID []byte, authPayload []byte) ([]byte, error) { + if len(peerID) != IDSize { + return nil, fmt.Errorf("invalid peerID length: %d", len(peerID)) + } + + msg := make([]byte, SizeOfProtoHeader+sizeOfMagicByte, SizeOfProtoHeader+headerSizeHello+len(authPayload)) + + msg[0] = byte(CurrentProtocolVersion) + msg[1] = byte(MsgTypeHello) + + copy(msg[SizeOfProtoHeader:SizeOfProtoHeader+sizeOfMagicByte], magicHeader) + + msg = append(msg, peerID...) + msg = append(msg, authPayload...) + + return msg, nil +} + +// UnmarshalAuthMsg extracts peerID and the auth payload from the message +func UnmarshalAuthMsg(msg []byte) ([]byte, []byte, error) { + if len(msg) < headerSizeHello { + return nil, nil, ErrInvalidMessageLength + } + if !bytes.Equal(msg[:sizeOfMagicByte], magicHeader) { + return nil, nil, errors.New("invalid magic header") + } + + return msg[sizeOfMagicByte:headerSizeHello], msg[headerSizeHello:], nil +} + +// MarshalAuthResponse creates a response message to the auth. +// In case of success connection the server response with a AuthResponse message. This message contains the server's +// instance URL. This URL will be used by choose the common Relay server in case if the peers are in different Relay +// servers. +func MarshalAuthResponse(address string) ([]byte, error) { + ab := []byte(address) + msg := make([]byte, SizeOfProtoHeader, SizeOfProtoHeader+headerSizeHelloResp+len(ab)) + + msg[0] = byte(CurrentProtocolVersion) + msg[1] = byte(MsgTypeHelloResponse) + + msg = append(msg, ab...) + + return msg, nil +} + +// UnmarshalAuthResponse it is a confirmation message to auth success +func UnmarshalAuthResponse(msg []byte) (string, error) { + if len(msg) < headerSizeHelloResp+1 { + return "", ErrInvalidMessageLength + } + return string(msg), nil +} + // MarshalCloseMsg creates a close message. // The close message is used to close the connection gracefully between the client and the server. The server and the // client can send this message. After receiving this message, the server or client will close the connection. diff --git a/relay/server/relay.go b/relay/server/relay.go index 6d88cbbb2..4dc262904 100644 --- a/relay/server/relay.go +++ b/relay/server/relay.go @@ -21,9 +21,10 @@ import ( // Relay represents the relay server type Relay struct { - metrics *metrics.Metrics - metricsCancel context.CancelFunc - validator auth.Validator + metrics *metrics.Metrics + metricsCancel context.CancelFunc + validator auth.Validator + validatorDummy auth.Validator // todo: this is just a dummy variable. Replace it with the proper validator store *Store instanceURL string @@ -168,14 +169,36 @@ func (r *Relay) handshake(conn net.Conn) ([]byte, error) { return nil, fmt.Errorf("determine message type from %s: %w", conn.RemoteAddr(), err) } - if msgType != messages.MsgTypeHello { - return nil, fmt.Errorf("invalid message type from %s", conn.RemoteAddr()) + var ( + responseMsg []byte + peerID []byte + ) + switch msgType { + case messages.MsgTypeHello: + responseMsg, err = r.handleHelloMsg(buf[messages.SizeOfProtoHeader:n], conn.RemoteAddr()) + case messages.MsgTypeAuth: + responseMsg, err = r.handleAuthMsg(buf[messages.SizeOfProtoHeader:n], conn.RemoteAddr()) + default: + return nil, fmt.Errorf("invalid message type %d from %s", msgType, conn.RemoteAddr()) + } + if err != nil { + return nil, err } - peerID, authData, err := messages.UnmarshalHelloMsg(buf[messages.SizeOfProtoHeader:n]) + _, err = conn.Write(responseMsg) + if err != nil { + return nil, fmt.Errorf("write to %s (%s): %w", peerID, conn.RemoteAddr(), err) + } + + return peerID, nil +} + +func (r *Relay) handleHelloMsg(buf []byte, remoteAddr net.Addr) ([]byte, error) { + peerID, authData, err := messages.UnmarshalHelloMsg(buf) if err != nil { return nil, fmt.Errorf("unmarshal hello message: %w", err) } + log.Warnf("peer is using depracated initial message type: %s (%s)", peerID, remoteAddr) authMsg, err := authmsg.UnmarshalMsg(authData) if err != nil { @@ -183,24 +206,36 @@ func (r *Relay) handshake(conn net.Conn) ([]byte, error) { } if err := r.validator.Validate(sha256.New, authMsg.AdditionalData); err != nil { - return nil, fmt.Errorf("validate %s (%s): %w", peerID, conn.RemoteAddr(), err) + return nil, fmt.Errorf("validate %s (%s): %w", peerID, remoteAddr, err) } addr := &address.Address{URL: r.instanceURL} addrData, err := addr.Marshal() if err != nil { - return nil, fmt.Errorf("marshal addressc to %s (%s): %w", peerID, conn.RemoteAddr(), err) + return nil, fmt.Errorf("marshal addressc to %s (%s): %w", peerID, remoteAddr, err) } - msg, err := messages.MarshalHelloResponse(addrData) + responseMsg, err := messages.MarshalHelloResponse(addrData) if err != nil { - return nil, fmt.Errorf("marshal hello response to %s (%s): %w", peerID, conn.RemoteAddr(), err) + return nil, fmt.Errorf("marshal hello response to %s (%s): %w", peerID, remoteAddr, err) } - - _, err = conn.Write(msg) - if err != nil { - return nil, fmt.Errorf("write to %s (%s): %w", peerID, conn.RemoteAddr(), err) - } - - return peerID, nil + return responseMsg, nil +} + +func (r *Relay) handleAuthMsg(buf []byte, addr net.Addr) ([]byte, error) { + peerID, authPayload, err := messages.UnmarshalAuthMsg(buf) + if err != nil { + return nil, fmt.Errorf("unmarshal hello message: %w", err) + } + + // todo use the proper validator + if err := r.validatorDummy.Validate(sha256.New, authPayload); err != nil { + return nil, fmt.Errorf("validate %s (%s): %w", peerID, addr, err) + } + + responseMsg, err := messages.MarshalAuthResponse(r.instanceURL) + if err != nil { + return nil, fmt.Errorf("marshal hello response to %s (%s): %w", peerID, addr, err) + } + return responseMsg, nil }