mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-18 08:16:39 +00:00
[misc] Separate shared code dependencies (#4288)
* Separate shared code dependencies * Fix import * Test respective shared code * Update openapi ref * Fix test * Fix test path
This commit is contained in:
21
shared/relay/messages/address/address.go
Normal file
21
shared/relay/messages/address/address.go
Normal file
@@ -0,0 +1,21 @@
|
||||
// Deprecated: This package is deprecated and will be removed in a future release.
|
||||
package address
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/gob"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
type Address struct {
|
||||
URL string
|
||||
}
|
||||
|
||||
func (addr *Address) Marshal() ([]byte, error) {
|
||||
var buf bytes.Buffer
|
||||
enc := gob.NewEncoder(&buf)
|
||||
if err := enc.Encode(addr); err != nil {
|
||||
return nil, fmt.Errorf("encode Address: %w", err)
|
||||
}
|
||||
return buf.Bytes(), nil
|
||||
}
|
||||
43
shared/relay/messages/auth/auth.go
Normal file
43
shared/relay/messages/auth/auth.go
Normal file
@@ -0,0 +1,43 @@
|
||||
// Deprecated: This package is deprecated and will be removed in a future release.
|
||||
package auth
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/gob"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
type Algorithm int
|
||||
|
||||
const (
|
||||
AlgoUnknown Algorithm = iota
|
||||
AlgoHMACSHA256
|
||||
AlgoHMACSHA512
|
||||
)
|
||||
|
||||
func (a Algorithm) String() string {
|
||||
switch a {
|
||||
case AlgoHMACSHA256:
|
||||
return "HMAC-SHA256"
|
||||
case AlgoHMACSHA512:
|
||||
return "HMAC-SHA512"
|
||||
default:
|
||||
return "Unknown"
|
||||
}
|
||||
}
|
||||
|
||||
type Msg struct {
|
||||
AuthAlgorithm Algorithm
|
||||
AdditionalData []byte
|
||||
}
|
||||
|
||||
func UnmarshalMsg(data []byte) (*Msg, error) {
|
||||
var msg *Msg
|
||||
|
||||
buf := bytes.NewBuffer(data)
|
||||
dec := gob.NewDecoder(buf)
|
||||
if err := dec.Decode(&msg); err != nil {
|
||||
return nil, fmt.Errorf("decode Msg: %w", err)
|
||||
}
|
||||
return msg, nil
|
||||
}
|
||||
5
shared/relay/messages/doc.go
Normal file
5
shared/relay/messages/doc.go
Normal file
@@ -0,0 +1,5 @@
|
||||
/*
|
||||
Package messages provides the message types that are used to communicate between the relay and the client.
|
||||
This package is used to determine the type of message that is being sent and received between the relay and the client.
|
||||
*/
|
||||
package messages
|
||||
31
shared/relay/messages/id.go
Normal file
31
shared/relay/messages/id.go
Normal file
@@ -0,0 +1,31 @@
|
||||
package messages
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
const (
|
||||
prefixLength = 4
|
||||
peerIDSize = prefixLength + sha256.Size
|
||||
)
|
||||
|
||||
var (
|
||||
prefix = []byte("sha-") // 4 bytes
|
||||
)
|
||||
|
||||
type PeerID [peerIDSize]byte
|
||||
|
||||
func (p PeerID) String() string {
|
||||
return fmt.Sprintf("%s%s", p[:prefixLength], base64.StdEncoding.EncodeToString(p[prefixLength:]))
|
||||
}
|
||||
|
||||
// HashID generates a sha256 hash from the peerID and returns the hash and the human-readable string
|
||||
func HashID(peerID string) PeerID {
|
||||
idHash := sha256.Sum256([]byte(peerID))
|
||||
var prefixedHash [peerIDSize]byte
|
||||
copy(prefixedHash[:prefixLength], prefix)
|
||||
copy(prefixedHash[prefixLength:], idHash[:])
|
||||
return prefixedHash
|
||||
}
|
||||
337
shared/relay/messages/message.go
Normal file
337
shared/relay/messages/message.go
Normal file
@@ -0,0 +1,337 @@
|
||||
package messages
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
const (
|
||||
MaxHandshakeSize = 212
|
||||
MaxHandshakeRespSize = 8192
|
||||
MaxMessageSize = 8820
|
||||
|
||||
CurrentProtocolVersion = 1
|
||||
|
||||
MsgTypeUnknown MsgType = 0
|
||||
// Deprecated: Use MsgTypeAuth instead.
|
||||
MsgTypeHello = 1
|
||||
// Deprecated: Use MsgTypeAuthResponse instead.
|
||||
MsgTypeHelloResponse = 2
|
||||
MsgTypeTransport = 3
|
||||
MsgTypeClose = 4
|
||||
MsgTypeHealthCheck = 5
|
||||
MsgTypeAuth = 6
|
||||
MsgTypeAuthResponse = 7
|
||||
|
||||
// Peers state messages
|
||||
MsgTypeSubscribePeerState = 8
|
||||
MsgTypeUnsubscribePeerState = 9
|
||||
MsgTypePeersOnline = 10
|
||||
MsgTypePeersWentOffline = 11
|
||||
|
||||
// base size of the message
|
||||
sizeOfVersionByte = 1
|
||||
sizeOfMsgType = 1
|
||||
sizeOfProtoHeader = sizeOfVersionByte + sizeOfMsgType
|
||||
|
||||
// auth message
|
||||
sizeOfMagicByte = 4
|
||||
headerSizeAuth = sizeOfMagicByte + peerIDSize
|
||||
offsetMagicByte = sizeOfProtoHeader
|
||||
offsetAuthPeerID = sizeOfProtoHeader + sizeOfMagicByte
|
||||
headerTotalSizeAuth = sizeOfProtoHeader + headerSizeAuth
|
||||
|
||||
// hello message
|
||||
headerSizeHello = sizeOfMagicByte + peerIDSize
|
||||
headerSizeHelloResp = 0
|
||||
|
||||
// transport
|
||||
headerSizeTransport = peerIDSize
|
||||
offsetTransportID = sizeOfProtoHeader
|
||||
headerTotalSizeTransport = sizeOfProtoHeader + headerSizeTransport
|
||||
)
|
||||
|
||||
var (
|
||||
ErrInvalidMessageLength = errors.New("invalid message length")
|
||||
ErrUnsupportedVersion = errors.New("unsupported version")
|
||||
|
||||
magicHeader = []byte{0x21, 0x12, 0xA4, 0x42}
|
||||
|
||||
healthCheckMsg = []byte{byte(CurrentProtocolVersion), byte(MsgTypeHealthCheck)}
|
||||
)
|
||||
|
||||
type MsgType byte
|
||||
|
||||
func (m MsgType) String() string {
|
||||
switch m {
|
||||
case MsgTypeHello:
|
||||
return "hello"
|
||||
case MsgTypeHelloResponse:
|
||||
return "hello response"
|
||||
case MsgTypeAuth:
|
||||
return "auth"
|
||||
case MsgTypeAuthResponse:
|
||||
return "auth response"
|
||||
case MsgTypeTransport:
|
||||
return "transport"
|
||||
case MsgTypeClose:
|
||||
return "close"
|
||||
case MsgTypeHealthCheck:
|
||||
return "health check"
|
||||
case MsgTypeSubscribePeerState:
|
||||
return "subscribe peer state"
|
||||
case MsgTypeUnsubscribePeerState:
|
||||
return "unsubscribe peer state"
|
||||
case MsgTypePeersOnline:
|
||||
return "peers online"
|
||||
case MsgTypePeersWentOffline:
|
||||
return "peers went offline"
|
||||
default:
|
||||
return "unknown"
|
||||
}
|
||||
}
|
||||
|
||||
// ValidateVersion checks if the given version is supported by the protocol
|
||||
func ValidateVersion(msg []byte) (int, error) {
|
||||
if len(msg) < sizeOfProtoHeader {
|
||||
return 0, ErrInvalidMessageLength
|
||||
}
|
||||
version := int(msg[0])
|
||||
if version != CurrentProtocolVersion {
|
||||
return 0, fmt.Errorf("%d: %w", version, ErrUnsupportedVersion)
|
||||
}
|
||||
return version, nil
|
||||
}
|
||||
|
||||
// DetermineClientMessageType determines the message type from the first the message
|
||||
func DetermineClientMessageType(msg []byte) (MsgType, error) {
|
||||
if len(msg) < sizeOfProtoHeader {
|
||||
return 0, ErrInvalidMessageLength
|
||||
}
|
||||
|
||||
msgType := MsgType(msg[1])
|
||||
switch msgType {
|
||||
case
|
||||
MsgTypeHello,
|
||||
MsgTypeAuth,
|
||||
MsgTypeTransport,
|
||||
MsgTypeClose,
|
||||
MsgTypeHealthCheck,
|
||||
MsgTypeSubscribePeerState,
|
||||
MsgTypeUnsubscribePeerState:
|
||||
return msgType, nil
|
||||
default:
|
||||
return MsgTypeUnknown, fmt.Errorf("invalid msg type %d", msgType)
|
||||
}
|
||||
}
|
||||
|
||||
// DetermineServerMessageType determines the message type from the first the message
|
||||
func DetermineServerMessageType(msg []byte) (MsgType, error) {
|
||||
if len(msg) < sizeOfProtoHeader {
|
||||
return 0, ErrInvalidMessageLength
|
||||
}
|
||||
|
||||
msgType := MsgType(msg[1])
|
||||
switch msgType {
|
||||
case
|
||||
MsgTypeHelloResponse,
|
||||
MsgTypeAuthResponse,
|
||||
MsgTypeTransport,
|
||||
MsgTypeClose,
|
||||
MsgTypeHealthCheck,
|
||||
MsgTypePeersOnline,
|
||||
MsgTypePeersWentOffline:
|
||||
return msgType, nil
|
||||
default:
|
||||
return MsgTypeUnknown, fmt.Errorf("invalid msg type %d", msgType)
|
||||
}
|
||||
}
|
||||
|
||||
// Deprecated: Use MarshalAuthMsg instead.
|
||||
// MarshalHelloMsg initial hello message
|
||||
// The Hello message is the first message sent by a client after establishing a connection with the Relay server. This
|
||||
// message is used to authenticate the client with the server. The authentication is done using an HMAC method.
|
||||
// The protocol does not limit to use HMAC, it can be any other method. If the authentication failed the server will
|
||||
// close the network connection without any response.
|
||||
func MarshalHelloMsg(peerID PeerID, additions []byte) ([]byte, error) {
|
||||
msg := make([]byte, sizeOfProtoHeader+sizeOfMagicByte, sizeOfProtoHeader+headerSizeHello+len(additions))
|
||||
|
||||
msg[0] = byte(CurrentProtocolVersion)
|
||||
msg[1] = byte(MsgTypeHello)
|
||||
|
||||
copy(msg[sizeOfProtoHeader:sizeOfProtoHeader+sizeOfMagicByte], magicHeader)
|
||||
|
||||
msg = append(msg, peerID[:]...)
|
||||
msg = append(msg, additions...)
|
||||
|
||||
return msg, nil
|
||||
}
|
||||
|
||||
// Deprecated: Use UnmarshalAuthMsg instead.
|
||||
// UnmarshalHelloMsg extracts peerID and the additional data from the hello message. The Additional data is used to
|
||||
// authenticate the client with the server.
|
||||
func UnmarshalHelloMsg(msg []byte) (*PeerID, []byte, error) {
|
||||
if len(msg) < sizeOfProtoHeader+headerSizeHello {
|
||||
return nil, nil, ErrInvalidMessageLength
|
||||
}
|
||||
if !bytes.Equal(msg[sizeOfProtoHeader:sizeOfProtoHeader+sizeOfMagicByte], magicHeader) {
|
||||
return nil, nil, errors.New("invalid magic header")
|
||||
}
|
||||
|
||||
peerID := PeerID(msg[sizeOfProtoHeader+sizeOfMagicByte : sizeOfProtoHeader+headerSizeHello])
|
||||
|
||||
return &peerID, msg[headerSizeHello:], nil
|
||||
}
|
||||
|
||||
// Deprecated: Use MarshalAuthResponse instead.
|
||||
// MarshalHelloResponse creates a response message to the hello message.
|
||||
// In case of success connection the server response with a Hello Response message. This message contains the server's
|
||||
// instance URL. This URL will be used by choose the common Relay server in case if the peers are in different Relay
|
||||
// servers.
|
||||
func MarshalHelloResponse(additionalData []byte) ([]byte, error) {
|
||||
msg := make([]byte, sizeOfProtoHeader, sizeOfProtoHeader+headerSizeHelloResp+len(additionalData))
|
||||
|
||||
msg[0] = byte(CurrentProtocolVersion)
|
||||
msg[1] = byte(MsgTypeHelloResponse)
|
||||
|
||||
msg = append(msg, additionalData...)
|
||||
|
||||
return msg, nil
|
||||
}
|
||||
|
||||
// Deprecated: Use UnmarshalAuthResponse instead.
|
||||
// UnmarshalHelloResponse extracts the additional data from the hello response message.
|
||||
func UnmarshalHelloResponse(msg []byte) ([]byte, error) {
|
||||
if len(msg) < sizeOfProtoHeader+headerSizeHelloResp {
|
||||
return nil, ErrInvalidMessageLength
|
||||
}
|
||||
return msg, nil
|
||||
}
|
||||
|
||||
// MarshalAuthMsg initial authentication message
|
||||
// The Auth message is the first message sent by a client after establishing a connection with the Relay server. This
|
||||
// message is used to authenticate the client with the server. The authentication is done using an HMAC method.
|
||||
// The protocol does not limit to use HMAC, it can be any other method. If the authentication failed the server will
|
||||
// close the network connection without any response.
|
||||
func MarshalAuthMsg(peerID PeerID, authPayload []byte) ([]byte, error) {
|
||||
if headerTotalSizeAuth+len(authPayload) > MaxHandshakeSize {
|
||||
return nil, fmt.Errorf("too large auth payload")
|
||||
}
|
||||
|
||||
msg := make([]byte, headerTotalSizeAuth+len(authPayload))
|
||||
msg[0] = byte(CurrentProtocolVersion)
|
||||
msg[1] = byte(MsgTypeAuth)
|
||||
copy(msg[sizeOfProtoHeader:], magicHeader)
|
||||
copy(msg[offsetAuthPeerID:], peerID[:])
|
||||
copy(msg[headerTotalSizeAuth:], authPayload)
|
||||
return msg, nil
|
||||
}
|
||||
|
||||
// UnmarshalAuthMsg extracts peerID and the auth payload from the message
|
||||
func UnmarshalAuthMsg(msg []byte) (*PeerID, []byte, error) {
|
||||
if len(msg) < headerTotalSizeAuth {
|
||||
return nil, nil, ErrInvalidMessageLength
|
||||
}
|
||||
|
||||
// Validate the magic header
|
||||
if !bytes.Equal(msg[offsetMagicByte:offsetMagicByte+sizeOfMagicByte], magicHeader) {
|
||||
return nil, nil, errors.New("invalid magic header")
|
||||
}
|
||||
|
||||
peerID := PeerID(msg[offsetAuthPeerID:headerTotalSizeAuth])
|
||||
return &peerID, msg[headerTotalSizeAuth:], nil
|
||||
}
|
||||
|
||||
// MarshalAuthResponse creates a response message to the auth.
|
||||
// In case of success connection the server response with a AuthResponse message. This message contains the server's
|
||||
// instance URL. This URL will be used by choose the common Relay server in case if the peers are in different Relay
|
||||
// servers.
|
||||
func MarshalAuthResponse(address string) ([]byte, error) {
|
||||
ab := []byte(address)
|
||||
msg := make([]byte, sizeOfProtoHeader, sizeOfProtoHeader+len(ab))
|
||||
|
||||
msg[0] = byte(CurrentProtocolVersion)
|
||||
msg[1] = byte(MsgTypeAuthResponse)
|
||||
|
||||
msg = append(msg, ab...)
|
||||
|
||||
if len(msg) > MaxHandshakeRespSize {
|
||||
return nil, fmt.Errorf("invalid message length: %d", len(msg))
|
||||
}
|
||||
|
||||
return msg, nil
|
||||
}
|
||||
|
||||
// UnmarshalAuthResponse it is a confirmation message to auth success
|
||||
func UnmarshalAuthResponse(msg []byte) (string, error) {
|
||||
if len(msg) < sizeOfProtoHeader+1 {
|
||||
return "", ErrInvalidMessageLength
|
||||
}
|
||||
return string(msg[sizeOfProtoHeader:]), nil
|
||||
}
|
||||
|
||||
// MarshalCloseMsg creates a close message.
|
||||
// The close message is used to close the connection gracefully between the client and the server. The server and the
|
||||
// client can send this message. After receiving this message, the server or client will close the connection.
|
||||
func MarshalCloseMsg() []byte {
|
||||
return []byte{
|
||||
byte(CurrentProtocolVersion),
|
||||
byte(MsgTypeClose),
|
||||
}
|
||||
}
|
||||
|
||||
// MarshalTransportMsg creates a transport message.
|
||||
// The transport message is used to exchange data between peers. The message contains the data to be exchanged and the
|
||||
// destination peer hashed ID.
|
||||
func MarshalTransportMsg(peerID PeerID, payload []byte) ([]byte, error) {
|
||||
// todo validate size
|
||||
msg := make([]byte, headerTotalSizeTransport+len(payload))
|
||||
msg[0] = byte(CurrentProtocolVersion)
|
||||
msg[1] = byte(MsgTypeTransport)
|
||||
copy(msg[sizeOfProtoHeader:], peerID[:])
|
||||
copy(msg[sizeOfProtoHeader+peerIDSize:], payload)
|
||||
return msg, nil
|
||||
}
|
||||
|
||||
// UnmarshalTransportMsg extracts the peerID and the payload from the transport message.
|
||||
func UnmarshalTransportMsg(buf []byte) (*PeerID, []byte, error) {
|
||||
if len(buf) < headerTotalSizeTransport {
|
||||
return nil, nil, ErrInvalidMessageLength
|
||||
}
|
||||
|
||||
const offsetEnd = offsetTransportID + peerIDSize
|
||||
var peerID PeerID
|
||||
copy(peerID[:], buf[offsetTransportID:offsetEnd])
|
||||
return &peerID, buf[headerTotalSizeTransport:], nil
|
||||
}
|
||||
|
||||
// UnmarshalTransportID extracts the peerID from the transport message.
|
||||
func UnmarshalTransportID(buf []byte) (*PeerID, error) {
|
||||
if len(buf) < headerTotalSizeTransport {
|
||||
return nil, ErrInvalidMessageLength
|
||||
}
|
||||
|
||||
const offsetEnd = offsetTransportID + peerIDSize
|
||||
var id PeerID
|
||||
copy(id[:], buf[offsetTransportID:offsetEnd])
|
||||
return &id, nil
|
||||
}
|
||||
|
||||
// UpdateTransportMsg updates the peerID in the transport message.
|
||||
// With this function the server can reuse the given byte slice to update the peerID in the transport message. So do
|
||||
// need to allocate a new byte slice.
|
||||
func UpdateTransportMsg(msg []byte, peerID PeerID) error {
|
||||
if len(msg) < offsetTransportID+peerIDSize {
|
||||
return ErrInvalidMessageLength
|
||||
}
|
||||
copy(msg[offsetTransportID:], peerID[:])
|
||||
return nil
|
||||
}
|
||||
|
||||
// MarshalHealthcheck creates a health check message.
|
||||
// Health check message is sent by the server periodically. The client will respond with a health check response
|
||||
// message. If the client does not respond to the health check message, the server will close the connection.
|
||||
func MarshalHealthcheck() []byte {
|
||||
return healthCheckMsg
|
||||
}
|
||||
138
shared/relay/messages/message_test.go
Normal file
138
shared/relay/messages/message_test.go
Normal file
@@ -0,0 +1,138 @@
|
||||
package messages
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestMarshalHelloMsg(t *testing.T) {
|
||||
peerID := HashID("abdFAaBcawquEiCMzAabYosuUaGLtSNhKxz+")
|
||||
msg, err := MarshalHelloMsg(peerID, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("error: %v", err)
|
||||
}
|
||||
|
||||
msgType, err := DetermineClientMessageType(msg)
|
||||
if err != nil {
|
||||
t.Fatalf("error: %v", err)
|
||||
}
|
||||
|
||||
if msgType != MsgTypeHello {
|
||||
t.Errorf("expected %d, got %d", MsgTypeHello, msgType)
|
||||
}
|
||||
|
||||
receivedPeerID, _, err := UnmarshalHelloMsg(msg)
|
||||
if err != nil {
|
||||
t.Fatalf("error: %v", err)
|
||||
}
|
||||
if receivedPeerID.String() != peerID.String() {
|
||||
t.Errorf("expected %s, got %s", peerID, receivedPeerID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMarshalAuthMsg(t *testing.T) {
|
||||
peerID := HashID("abdFAaBcawquEiCMzAabYosuUaGLtSNhKxz+")
|
||||
msg, err := MarshalAuthMsg(peerID, []byte{})
|
||||
if err != nil {
|
||||
t.Fatalf("error: %v", err)
|
||||
}
|
||||
|
||||
msgType, err := DetermineClientMessageType(msg)
|
||||
if err != nil {
|
||||
t.Fatalf("error: %v", err)
|
||||
}
|
||||
|
||||
if msgType != MsgTypeAuth {
|
||||
t.Errorf("expected %d, got %d", MsgTypeAuth, msgType)
|
||||
}
|
||||
|
||||
receivedPeerID, _, err := UnmarshalAuthMsg(msg)
|
||||
if err != nil {
|
||||
t.Fatalf("error: %v", err)
|
||||
}
|
||||
if receivedPeerID.String() != peerID.String() {
|
||||
t.Errorf("expected %s, got %s", peerID, receivedPeerID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMarshalAuthResponse(t *testing.T) {
|
||||
address := "myaddress"
|
||||
msg, err := MarshalAuthResponse(address)
|
||||
if err != nil {
|
||||
t.Fatalf("error: %v", err)
|
||||
}
|
||||
|
||||
msgType, err := DetermineServerMessageType(msg)
|
||||
if err != nil {
|
||||
t.Fatalf("error: %v", err)
|
||||
}
|
||||
|
||||
if msgType != MsgTypeAuthResponse {
|
||||
t.Errorf("expected %d, got %d", MsgTypeAuthResponse, msgType)
|
||||
}
|
||||
|
||||
respAddr, err := UnmarshalAuthResponse(msg)
|
||||
if err != nil {
|
||||
t.Fatalf("error: %v", err)
|
||||
}
|
||||
if respAddr != address {
|
||||
t.Errorf("expected %s, got %s", address, respAddr)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMarshalTransportMsg(t *testing.T) {
|
||||
peerID := HashID("abdFAaBcawquEiCMzAabYosuUaGLtSNhKxz+")
|
||||
payload := []byte("payload")
|
||||
msg, err := MarshalTransportMsg(peerID, payload)
|
||||
if err != nil {
|
||||
t.Fatalf("error: %v", err)
|
||||
}
|
||||
|
||||
msgType, err := DetermineClientMessageType(msg)
|
||||
if err != nil {
|
||||
t.Fatalf("error: %v", err)
|
||||
}
|
||||
|
||||
if msgType != MsgTypeTransport {
|
||||
t.Errorf("expected %d, got %d", MsgTypeTransport, msgType)
|
||||
}
|
||||
|
||||
uPeerID, err := UnmarshalTransportID(msg)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to unmarshal transport id: %v", err)
|
||||
}
|
||||
|
||||
if uPeerID.String() != peerID.String() {
|
||||
t.Errorf("expected %s, got %s", peerID, uPeerID)
|
||||
}
|
||||
|
||||
id, respPayload, err := UnmarshalTransportMsg(msg)
|
||||
if err != nil {
|
||||
t.Fatalf("error: %v", err)
|
||||
}
|
||||
|
||||
if id.String() != peerID.String() {
|
||||
t.Errorf("expected: '%s', got: '%s'", peerID, id)
|
||||
}
|
||||
|
||||
if string(respPayload) != string(payload) {
|
||||
t.Errorf("expected %s, got %s", payload, respPayload)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMarshalHealthcheck(t *testing.T) {
|
||||
msg := MarshalHealthcheck()
|
||||
|
||||
_, err := ValidateVersion(msg)
|
||||
if err != nil {
|
||||
t.Fatalf("error: %v", err)
|
||||
}
|
||||
|
||||
msgType, err := DetermineServerMessageType(msg)
|
||||
if err != nil {
|
||||
t.Fatalf("error: %v", err)
|
||||
}
|
||||
|
||||
if msgType != MsgTypeHealthCheck {
|
||||
t.Errorf("expected %d, got %d", MsgTypeHealthCheck, msgType)
|
||||
}
|
||||
}
|
||||
92
shared/relay/messages/peer_state.go
Normal file
92
shared/relay/messages/peer_state.go
Normal file
@@ -0,0 +1,92 @@
|
||||
package messages
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
)
|
||||
|
||||
func MarshalSubPeerStateMsg(ids []PeerID) ([][]byte, error) {
|
||||
return marshalPeerIDs(ids, byte(MsgTypeSubscribePeerState))
|
||||
}
|
||||
|
||||
func UnmarshalSubPeerStateMsg(buf []byte) ([]PeerID, error) {
|
||||
return unmarshalPeerIDs(buf)
|
||||
}
|
||||
|
||||
func MarshalUnsubPeerStateMsg(ids []PeerID) ([][]byte, error) {
|
||||
return marshalPeerIDs(ids, byte(MsgTypeUnsubscribePeerState))
|
||||
}
|
||||
|
||||
func UnmarshalUnsubPeerStateMsg(buf []byte) ([]PeerID, error) {
|
||||
return unmarshalPeerIDs(buf)
|
||||
}
|
||||
|
||||
func MarshalPeersOnline(ids []PeerID) ([][]byte, error) {
|
||||
return marshalPeerIDs(ids, byte(MsgTypePeersOnline))
|
||||
}
|
||||
|
||||
func UnmarshalPeersOnlineMsg(buf []byte) ([]PeerID, error) {
|
||||
return unmarshalPeerIDs(buf)
|
||||
}
|
||||
|
||||
func MarshalPeersWentOffline(ids []PeerID) ([][]byte, error) {
|
||||
return marshalPeerIDs(ids, byte(MsgTypePeersWentOffline))
|
||||
}
|
||||
|
||||
func UnMarshalPeersWentOffline(buf []byte) ([]PeerID, error) {
|
||||
return unmarshalPeerIDs(buf)
|
||||
}
|
||||
|
||||
// marshalPeerIDs is a generic function to marshal peer IDs with a specific message type
|
||||
func marshalPeerIDs(ids []PeerID, msgType byte) ([][]byte, error) {
|
||||
if len(ids) == 0 {
|
||||
return nil, fmt.Errorf("no list of peer ids provided")
|
||||
}
|
||||
|
||||
const maxPeersPerMessage = (MaxMessageSize - sizeOfProtoHeader) / peerIDSize
|
||||
var messages [][]byte
|
||||
|
||||
for i := 0; i < len(ids); i += maxPeersPerMessage {
|
||||
end := i + maxPeersPerMessage
|
||||
if end > len(ids) {
|
||||
end = len(ids)
|
||||
}
|
||||
chunk := ids[i:end]
|
||||
|
||||
totalSize := sizeOfProtoHeader + len(chunk)*peerIDSize
|
||||
buf := make([]byte, totalSize)
|
||||
buf[0] = byte(CurrentProtocolVersion)
|
||||
buf[1] = msgType
|
||||
|
||||
offset := sizeOfProtoHeader
|
||||
for _, id := range chunk {
|
||||
copy(buf[offset:], id[:])
|
||||
offset += peerIDSize
|
||||
}
|
||||
|
||||
messages = append(messages, buf)
|
||||
}
|
||||
|
||||
return messages, nil
|
||||
}
|
||||
|
||||
// unmarshalPeerIDs is a generic function to unmarshal peer IDs from a buffer
|
||||
func unmarshalPeerIDs(buf []byte) ([]PeerID, error) {
|
||||
if len(buf) < sizeOfProtoHeader {
|
||||
return nil, fmt.Errorf("invalid message format")
|
||||
}
|
||||
|
||||
if (len(buf)-sizeOfProtoHeader)%peerIDSize != 0 {
|
||||
return nil, fmt.Errorf("invalid peer list size: %d", len(buf)-sizeOfProtoHeader)
|
||||
}
|
||||
|
||||
numIDs := (len(buf) - sizeOfProtoHeader) / peerIDSize
|
||||
|
||||
ids := make([]PeerID, numIDs)
|
||||
offset := sizeOfProtoHeader
|
||||
for i := 0; i < numIDs; i++ {
|
||||
copy(ids[i][:], buf[offset:offset+peerIDSize])
|
||||
offset += peerIDSize
|
||||
}
|
||||
|
||||
return ids, nil
|
||||
}
|
||||
144
shared/relay/messages/peer_state_test.go
Normal file
144
shared/relay/messages/peer_state_test.go
Normal file
@@ -0,0 +1,144 @@
|
||||
package messages
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"testing"
|
||||
)
|
||||
|
||||
const (
|
||||
testPeerCount = 10
|
||||
)
|
||||
|
||||
// Helper function to generate test PeerIDs
|
||||
func generateTestPeerIDs(n int) []PeerID {
|
||||
ids := make([]PeerID, n)
|
||||
for i := 0; i < n; i++ {
|
||||
for j := 0; j < peerIDSize; j++ {
|
||||
ids[i][j] = byte(i + j)
|
||||
}
|
||||
}
|
||||
return ids
|
||||
}
|
||||
|
||||
// Helper function to compare slices of PeerID
|
||||
func peerIDEqual(a, b []PeerID) bool {
|
||||
if len(a) != len(b) {
|
||||
return false
|
||||
}
|
||||
for i := range a {
|
||||
if !bytes.Equal(a[i][:], b[i][:]) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func TestMarshalUnmarshalSubPeerState(t *testing.T) {
|
||||
ids := generateTestPeerIDs(testPeerCount)
|
||||
|
||||
msgs, err := MarshalSubPeerStateMsg(ids)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
var allIDs []PeerID
|
||||
for _, msg := range msgs {
|
||||
decoded, err := UnmarshalSubPeerStateMsg(msg)
|
||||
if err != nil {
|
||||
t.Fatalf("unmarshal failed: %v", err)
|
||||
}
|
||||
allIDs = append(allIDs, decoded...)
|
||||
}
|
||||
|
||||
if !peerIDEqual(ids, allIDs) {
|
||||
t.Errorf("expected %v, got %v", ids, allIDs)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMarshalSubPeerState_EmptyInput(t *testing.T) {
|
||||
_, err := MarshalSubPeerStateMsg([]PeerID{})
|
||||
if err == nil {
|
||||
t.Errorf("expected error for empty input")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUnmarshalSubPeerState_Invalid(t *testing.T) {
|
||||
// Too short
|
||||
_, err := UnmarshalSubPeerStateMsg([]byte{1})
|
||||
if err == nil {
|
||||
t.Errorf("expected error for short input")
|
||||
}
|
||||
|
||||
// Misaligned length
|
||||
buf := make([]byte, sizeOfProtoHeader+1)
|
||||
_, err = UnmarshalSubPeerStateMsg(buf)
|
||||
if err == nil {
|
||||
t.Errorf("expected error for misaligned input")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMarshalUnmarshalPeersOnline(t *testing.T) {
|
||||
ids := generateTestPeerIDs(testPeerCount)
|
||||
|
||||
msgs, err := MarshalPeersOnline(ids)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
var allIDs []PeerID
|
||||
for _, msg := range msgs {
|
||||
decoded, err := UnmarshalPeersOnlineMsg(msg)
|
||||
if err != nil {
|
||||
t.Fatalf("unmarshal failed: %v", err)
|
||||
}
|
||||
allIDs = append(allIDs, decoded...)
|
||||
}
|
||||
|
||||
if !peerIDEqual(ids, allIDs) {
|
||||
t.Errorf("expected %v, got %v", ids, allIDs)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMarshalPeersOnline_EmptyInput(t *testing.T) {
|
||||
_, err := MarshalPeersOnline([]PeerID{})
|
||||
if err == nil {
|
||||
t.Errorf("expected error for empty input")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUnmarshalPeersOnline_Invalid(t *testing.T) {
|
||||
_, err := UnmarshalPeersOnlineMsg([]byte{1})
|
||||
if err == nil {
|
||||
t.Errorf("expected error for short input")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMarshalUnmarshalPeersWentOffline(t *testing.T) {
|
||||
ids := generateTestPeerIDs(testPeerCount)
|
||||
|
||||
msgs, err := MarshalPeersWentOffline(ids)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
var allIDs []PeerID
|
||||
for _, msg := range msgs {
|
||||
// MarshalPeersWentOffline shares no unmarshal function, so reuse PeersOnline
|
||||
decoded, err := UnmarshalPeersOnlineMsg(msg)
|
||||
if err != nil {
|
||||
t.Fatalf("unmarshal failed: %v", err)
|
||||
}
|
||||
allIDs = append(allIDs, decoded...)
|
||||
}
|
||||
|
||||
if !peerIDEqual(ids, allIDs) {
|
||||
t.Errorf("expected %v, got %v", ids, allIDs)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMarshalPeersWentOffline_EmptyInput(t *testing.T) {
|
||||
_, err := MarshalPeersWentOffline([]PeerID{})
|
||||
if err == nil {
|
||||
t.Errorf("expected error for empty input")
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user