diff --git a/relay/client/client.go b/relay/client/client.go index 1ba815018..6b8dac9df 100644 --- a/relay/client/client.go +++ b/relay/client/client.go @@ -371,7 +371,11 @@ func (c *Client) writeTo(id string, dstID []byte, payload []byte) (int, error) { return 0, io.EOF } */ - msg := messages.MarshalTransportMsg(dstID, payload) + msg, err := messages.MarshalTransportMsg(dstID, payload) + if err != nil { + log.Errorf("failed to marshal transport message: %s", err) + return 0, err + } n, err := c.relayConn.Write(msg) if err != nil { log.Errorf("failed to write transport message: %s", err) diff --git a/relay/client/manager.go b/relay/client/manager.go index ae44be4b9..42cbcd280 100644 --- a/relay/client/manager.go +++ b/relay/client/manager.go @@ -194,6 +194,7 @@ func (m *Manager) isForeignServer(address string) (bool, error) { if err != nil { return false, fmt.Errorf("relay client not connected") } + log.Debugf("check if foreign server: %s != %s", rAddr.String(), address) return rAddr.String() != address, nil } diff --git a/relay/messages/id.go b/relay/messages/id.go index c2248fb2b..dbb49cc58 100644 --- a/relay/messages/id.go +++ b/relay/messages/id.go @@ -3,18 +3,25 @@ package messages import ( "crypto/sha256" "encoding/base64" + "fmt" ) const ( - IDSize = sha256.Size + prefixLength = 4 + IDSize = sha256.Size + 4 // 4 is equal with len(prefix) +) + +var ( + prefix = []byte("sha-") // 4 bytes ) func HashID(peerID string) ([]byte, string) { idHash := sha256.Sum256([]byte(peerID)) - idHashString := base64.StdEncoding.EncodeToString(idHash[:]) - return idHash[:], idHashString + idHashString := string(prefix) + base64.StdEncoding.EncodeToString(idHash[:]) + prefixedHash := append(prefix, idHash[:]...) + return prefixedHash, idHashString } func HashIDToString(idHash []byte) string { - return base64.StdEncoding.EncodeToString(idHash) + return fmt.Sprintf("%s%s", idHash[:prefixLength], base64.StdEncoding.EncodeToString(idHash[prefixLength:])) } diff --git a/relay/messages/id_test.go b/relay/messages/id_test.go new file mode 100644 index 000000000..25c64955b --- /dev/null +++ b/relay/messages/id_test.go @@ -0,0 +1,26 @@ +package messages + +import ( + "encoding/binary" + "testing" + + log "github.com/sirupsen/logrus" +) + +func TestHashID(t *testing.T) { + hashedID, hashedStringId := HashID("abc") + enc := HashIDToString(hashedID) + if enc != hashedStringId { + t.Errorf("expected %s, got %s", hashedStringId, enc) + } + + var magicHeader uint32 = 0x2112A442 // size 4 byte + + msg := make([]byte, 4) + binary.BigEndian.PutUint32(msg, magicHeader) + + magicHeader2 := []byte{0x21, 0x12, 0xA4, 0x42} + + log.Infof("msg: %v, %v", msg, magicHeader2) + +} diff --git a/relay/messages/message.go b/relay/messages/message.go index b865687ce..92b8bced7 100644 --- a/relay/messages/message.go +++ b/relay/messages/message.go @@ -1,6 +1,7 @@ package messages import ( + "bytes" "fmt" log "github.com/sirupsen/logrus" @@ -11,10 +12,15 @@ const ( MsgTypeHelloResponse MsgType = 1 MsgTypeTransport MsgType = 2 MsgClose MsgType = 3 + + headerSizeTransport = 1 + IDSize // 1 byte for msg type, IDSize for peerID + headerSizeHello = 1 + 4 + IDSize // 1 byte for msg type, 4 byte for magic header, IDSize for peerID ) var ( ErrInvalidMessageLength = fmt.Errorf("invalid message length") + + magicHeader = []byte{0x21, 0x12, 0xA4, 0x42} ) type MsgType byte @@ -35,7 +41,6 @@ func (m MsgType) String() string { } func DetermineClientMsgType(msg []byte) (MsgType, error) { - // todo: validate magic byte msgType := MsgType(msg[0]) switch msgType { case MsgTypeHello: @@ -50,7 +55,6 @@ func DetermineClientMsgType(msg []byte) (MsgType, error) { } func DetermineServerMsgType(msg []byte) (MsgType, error) { - // todo: validate magic byte msgType := MsgType(msg[0]) switch msgType { case MsgTypeHelloResponse: @@ -67,19 +71,21 @@ func DetermineServerMsgType(msg []byte) (MsgType, error) { // MarshalHelloMsg initial hello message func MarshalHelloMsg(peerID []byte) ([]byte, error) { if len(peerID) != IDSize { - return nil, fmt.Errorf("invalid peerID length") + return nil, fmt.Errorf("invalid peerID length: %d", len(peerID)) } - msg := make([]byte, 1, 1+len(peerID)) + msg := make([]byte, 5, headerSizeHello) msg[0] = byte(MsgTypeHello) + copy(msg[1:5], magicHeader) msg = append(msg, peerID...) return msg, nil } func UnmarshalHelloMsg(msg []byte) ([]byte, error) { - if len(msg) < 2 { + if len(msg) < headerSizeHello { return nil, fmt.Errorf("invalid 'hello' messge") } - return msg[1:], nil + bytes.Equal(msg[1:5], magicHeader) + return msg[5:], nil } func MarshalHelloResponse() []byte { @@ -98,34 +104,32 @@ func MarshalCloseMsg() []byte { // Transport message -func MarshalTransportMsg(peerID []byte, payload []byte) []byte { +func MarshalTransportMsg(peerID []byte, payload []byte) ([]byte, error) { if len(peerID) != IDSize { - return nil + return nil, fmt.Errorf("invalid peerID length: %d", len(peerID)) } - msg := make([]byte, 1+IDSize, 1+IDSize+len(payload)) + msg := make([]byte, headerSizeTransport, headerSizeTransport+len(payload)) msg[0] = byte(MsgTypeTransport) copy(msg[1:], peerID) msg = append(msg, payload...) - return msg + return msg, nil } func UnmarshalTransportMsg(buf []byte) ([]byte, []byte, error) { - headerSize := 1 + IDSize - if len(buf) < headerSize { + if len(buf) < headerSizeTransport { return nil, nil, ErrInvalidMessageLength } - return buf[1:headerSize], buf[headerSize:], nil + return buf[1:headerSizeTransport], buf[headerSizeTransport:], nil } func UnmarshalTransportID(buf []byte) ([]byte, error) { - headerSize := 1 + IDSize - if len(buf) < headerSize { - log.Debugf("invalid message length: %d, expected: %d, %x", len(buf), headerSize, buf) + if len(buf) < headerSizeTransport { + log.Debugf("invalid message length: %d, expected: %d, %x", len(buf), headerSizeTransport, buf) return nil, ErrInvalidMessageLength } - return buf[1:headerSize], nil + return buf[1:headerSizeTransport], nil } func UpdateTransportMsg(msg []byte, peerID []byte) error { diff --git a/relay/messages/message_test.go b/relay/messages/message_test.go new file mode 100644 index 000000000..d40963d8b --- /dev/null +++ b/relay/messages/message_test.go @@ -0,0 +1,43 @@ +package messages + +import ( + "testing" +) + +func TestMarshalHelloMsg(t *testing.T) { + peerID := []byte("abdFAaBcawquEiCMzAabYosuUaGLtSNhKxz+") + bHello, err := MarshalHelloMsg(peerID) + if err != nil { + t.Fatalf("error: %v", err) + } + + receivedPeerID, err := UnmarshalHelloMsg(bHello) + if err != nil { + t.Fatalf("error: %v", err) + } + if string(receivedPeerID) != string(peerID) { + t.Errorf("expected %s, got %s", peerID, receivedPeerID) + } +} + +func TestMarshalTransportMsg(t *testing.T) { + peerID := []byte("abdFAaBcawquEiCMzAabYosuUaGLtSNhKxz+") + payload := []byte("payload") + msg, err := MarshalTransportMsg(peerID, payload) + if err != nil { + t.Fatalf("error: %v", err) + } + + id, respPayload, err := UnmarshalTransportMsg(msg) + if err != nil { + t.Fatalf("error: %v", err) + } + + if string(id) != string(peerID) { + t.Errorf("expected %s, got %s", peerID, id) + } + + if string(respPayload) != string(payload) { + t.Errorf("expected %s, got %s", payload, respPayload) + } +}