mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-16 07:16:38 +00:00
Extract common server encryption logic (#65)
* refactor: extract common message encryption logic * refactor: move letsencrypt logic to common * refactor: rename common package to encryption * test: add encryption tests
This commit is contained in:
50
encryption/encryption.go
Normal file
50
encryption/encryption.go
Normal file
@@ -0,0 +1,50 @@
|
||||
package encryption
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"fmt"
|
||||
"golang.org/x/crypto/nacl/box"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
)
|
||||
|
||||
// A set of tools to encrypt/decrypt messages being sent through the Signal Exchange Service or Management Service
|
||||
// These tools use Golang crypto package (Curve25519, XSalsa20 and Poly1305 to encrypt and authenticate)
|
||||
// Wireguard keys are used for encryption
|
||||
|
||||
// Encrypt encrypts a message using local Wireguard private key and remote peer's public key.
|
||||
func Encrypt(msg []byte, peerPublicKey wgtypes.Key, privateKey wgtypes.Key) ([]byte, error) {
|
||||
nonce, err := genNonce()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return box.Seal(nonce[:], msg, nonce, toByte32(peerPublicKey), toByte32(privateKey)), nil
|
||||
}
|
||||
|
||||
// Decrypt decrypts a message that has been encrypted by the remote peer using Wireguard private key and remote peer's public key.
|
||||
func Decrypt(encryptedMsg []byte, peerPublicKey wgtypes.Key, privateKey wgtypes.Key) ([]byte, error) {
|
||||
nonce, err := genNonce()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
copy(nonce[:], encryptedMsg[:24])
|
||||
opened, ok := box.Open(nil, encryptedMsg[24:], nonce, toByte32(peerPublicKey), toByte32(privateKey))
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("failed to decrypt message from peer %s", peerPublicKey.String())
|
||||
}
|
||||
|
||||
return opened, nil
|
||||
}
|
||||
|
||||
// Generates nonce of size 24
|
||||
func genNonce() (*[24]byte, error) {
|
||||
var nonce [24]byte
|
||||
if _, err := rand.Read(nonce[:]); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &nonce, nil
|
||||
}
|
||||
|
||||
// Converts Wireguard key to byte array of size 32 (a format used by the golang crypto package)
|
||||
func toByte32(key wgtypes.Key) *[32]byte {
|
||||
return (*[32]byte)(&key)
|
||||
}
|
||||
13
encryption/encryption_suite_test.go
Normal file
13
encryption/encryption_suite_test.go
Normal file
@@ -0,0 +1,13 @@
|
||||
package encryption_test
|
||||
|
||||
import (
|
||||
. "github.com/onsi/ginkgo"
|
||||
. "github.com/onsi/gomega"
|
||||
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestManagement(t *testing.T) {
|
||||
RegisterFailHandler(Fail)
|
||||
RunSpecs(t, "Management Service Suite")
|
||||
}
|
||||
60
encryption/encryption_test.go
Normal file
60
encryption/encryption_test.go
Normal file
@@ -0,0 +1,60 @@
|
||||
package encryption_test
|
||||
|
||||
import (
|
||||
. "github.com/onsi/ginkgo"
|
||||
. "github.com/onsi/gomega"
|
||||
"github.com/wiretrustee/wiretrustee/encryption"
|
||||
"github.com/wiretrustee/wiretrustee/encryption/testprotos"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
)
|
||||
|
||||
const ()
|
||||
|
||||
var _ = Describe("Encryption", func() {
|
||||
|
||||
var (
|
||||
encryptionKey wgtypes.Key
|
||||
decryptionKey wgtypes.Key
|
||||
)
|
||||
|
||||
BeforeEach(func() {
|
||||
var err error
|
||||
encryptionKey, err = wgtypes.GenerateKey()
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
decryptionKey, err = wgtypes.GenerateKey()
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
})
|
||||
|
||||
Context("decrypting a plain message", func() {
|
||||
Context("when it was encrypted with Wireguard keys", func() {
|
||||
Specify("should be successful", func() {
|
||||
msg := "message"
|
||||
encryptedMsg, err := encryption.Encrypt([]byte(msg), decryptionKey.PublicKey(), encryptionKey)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
decryptedMsg, err := encryption.Decrypt(encryptedMsg, encryptionKey.PublicKey(), decryptionKey)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
Expect(string(decryptedMsg)).To(BeEquivalentTo(msg))
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
Context("decrypting a protobuf message", func() {
|
||||
Context("when it was encrypted with Wireguard keys", func() {
|
||||
Specify("should be successful", func() {
|
||||
|
||||
protoMsg := &testprotos.TestMessage{Body: "message"}
|
||||
encryptedMsg, err := encryption.EncryptMessage(decryptionKey.PublicKey(), encryptionKey, protoMsg)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
decryptedMsg := &testprotos.TestMessage{}
|
||||
err = encryption.DecryptMessage(encryptionKey.PublicKey(), decryptionKey, encryptedMsg, decryptedMsg)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
Expect(decryptedMsg.GetBody()).To(BeEquivalentTo(protoMsg.GetBody()))
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
})
|
||||
40
encryption/letsencrypt.go
Normal file
40
encryption/letsencrypt.go
Normal file
@@ -0,0 +1,40 @@
|
||||
package encryption
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/crypto/acme/autocert"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
)
|
||||
|
||||
// EnableLetsEncrypt wraps common logic of generating Let's encrypt certificate.
|
||||
// Includes a HTTP handler and listener to solve the Let's encrypt challenge
|
||||
func EnableLetsEncrypt(datadir string, letsencryptDomain string) *tls.Config {
|
||||
certDir := filepath.Join(datadir, "letsencrypt")
|
||||
|
||||
if _, err := os.Stat(certDir); os.IsNotExist(err) {
|
||||
err = os.MkdirAll(certDir, os.ModeDir)
|
||||
if err != nil {
|
||||
log.Fatalf("failed creating Let's encrypt certdir: %s: %v", certDir, err)
|
||||
}
|
||||
}
|
||||
|
||||
log.Infof("running with Let's encrypt with domain %s. Cert will be stored in %s", letsencryptDomain, certDir)
|
||||
|
||||
certManager := autocert.Manager{
|
||||
Prompt: autocert.AcceptTOS,
|
||||
Cache: autocert.DirCache(certDir),
|
||||
HostPolicy: autocert.HostWhitelist(letsencryptDomain),
|
||||
}
|
||||
|
||||
// listener to handle Let's encrypt certificate challenge
|
||||
go func() {
|
||||
if err := http.Serve(certManager.Listener(), certManager.HTTPHandler(nil)); err != nil {
|
||||
log.Fatalf("failed to serve letsencrypt handler: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
return &tls.Config{GetCertificate: certManager.GetCertificate}
|
||||
}
|
||||
40
encryption/message.go
Normal file
40
encryption/message.go
Normal file
@@ -0,0 +1,40 @@
|
||||
package encryption
|
||||
|
||||
import (
|
||||
pb "github.com/golang/protobuf/proto" //nolint
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
)
|
||||
|
||||
// EncryptMessage encrypts a body of the given protobuf Message
|
||||
func EncryptMessage(remotePubKey wgtypes.Key, ourPrivateKey wgtypes.Key, message pb.Message) ([]byte, error) {
|
||||
byteResp, err := pb.Marshal(message)
|
||||
if err != nil {
|
||||
log.Errorf("failed marshalling message %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
encryptedBytes, err := Encrypt(byteResp, remotePubKey, ourPrivateKey)
|
||||
if err != nil {
|
||||
log.Errorf("failed encrypting SyncResponse %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return encryptedBytes, nil
|
||||
}
|
||||
|
||||
// DecryptMessage decrypts an encrypted message into given protobuf Message
|
||||
func DecryptMessage(remotePubKey wgtypes.Key, ourPrivateKey wgtypes.Key, encryptedMessage []byte, message pb.Message) error {
|
||||
decrypted, err := Decrypt(encryptedMessage, remotePubKey, ourPrivateKey)
|
||||
if err != nil {
|
||||
log.Warnf("error while decrypting Sync request message from peer %s", remotePubKey.String())
|
||||
return err
|
||||
}
|
||||
|
||||
err = pb.Unmarshal(decrypted, message)
|
||||
if err != nil {
|
||||
log.Warnf("error while umarshalling Sync request message from peer %s", remotePubKey.String())
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
2
encryption/testprotos/generate.sh
Executable file
2
encryption/testprotos/generate.sh
Executable file
@@ -0,0 +1,2 @@
|
||||
#!/bin/bash
|
||||
protoc -I testprotos/ testprotos/testproto.proto --go_out=.
|
||||
142
encryption/testprotos/testproto.pb.go
Normal file
142
encryption/testprotos/testproto.pb.go
Normal file
@@ -0,0 +1,142 @@
|
||||
// Code generated by protoc-gen-go. DO NOT EDIT.
|
||||
// versions:
|
||||
// protoc-gen-go v1.26.0
|
||||
// protoc v3.12.4
|
||||
// source: testproto.proto
|
||||
|
||||
package testprotos
|
||||
|
||||
import (
|
||||
protoreflect "google.golang.org/protobuf/reflect/protoreflect"
|
||||
protoimpl "google.golang.org/protobuf/runtime/protoimpl"
|
||||
reflect "reflect"
|
||||
sync "sync"
|
||||
)
|
||||
|
||||
const (
|
||||
// Verify that this generated code is sufficiently up-to-date.
|
||||
_ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion)
|
||||
// Verify that runtime/protoimpl is sufficiently up-to-date.
|
||||
_ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20)
|
||||
)
|
||||
|
||||
type TestMessage struct {
|
||||
state protoimpl.MessageState
|
||||
sizeCache protoimpl.SizeCache
|
||||
unknownFields protoimpl.UnknownFields
|
||||
|
||||
Body string `protobuf:"bytes,1,opt,name=body,proto3" json:"body,omitempty"`
|
||||
}
|
||||
|
||||
func (x *TestMessage) Reset() {
|
||||
*x = TestMessage{}
|
||||
if protoimpl.UnsafeEnabled {
|
||||
mi := &file_testproto_proto_msgTypes[0]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
}
|
||||
|
||||
func (x *TestMessage) String() string {
|
||||
return protoimpl.X.MessageStringOf(x)
|
||||
}
|
||||
|
||||
func (*TestMessage) ProtoMessage() {}
|
||||
|
||||
func (x *TestMessage) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_testproto_proto_msgTypes[0]
|
||||
if protoimpl.UnsafeEnabled && x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
return ms
|
||||
}
|
||||
return mi.MessageOf(x)
|
||||
}
|
||||
|
||||
// Deprecated: Use TestMessage.ProtoReflect.Descriptor instead.
|
||||
func (*TestMessage) Descriptor() ([]byte, []int) {
|
||||
return file_testproto_proto_rawDescGZIP(), []int{0}
|
||||
}
|
||||
|
||||
func (x *TestMessage) GetBody() string {
|
||||
if x != nil {
|
||||
return x.Body
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
var File_testproto_proto protoreflect.FileDescriptor
|
||||
|
||||
var file_testproto_proto_rawDesc = []byte{
|
||||
0x0a, 0x0f, 0x74, 0x65, 0x73, 0x74, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x70, 0x72, 0x6f, 0x74,
|
||||
0x6f, 0x12, 0x0a, 0x74, 0x65, 0x73, 0x74, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x73, 0x22, 0x21, 0x0a,
|
||||
0x0b, 0x54, 0x65, 0x73, 0x74, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x12, 0x12, 0x0a, 0x04,
|
||||
0x62, 0x6f, 0x64, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x62, 0x6f, 0x64, 0x79,
|
||||
0x42, 0x0d, 0x5a, 0x0b, 0x2f, 0x74, 0x65, 0x73, 0x74, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x73, 0x62,
|
||||
0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
|
||||
}
|
||||
|
||||
var (
|
||||
file_testproto_proto_rawDescOnce sync.Once
|
||||
file_testproto_proto_rawDescData = file_testproto_proto_rawDesc
|
||||
)
|
||||
|
||||
func file_testproto_proto_rawDescGZIP() []byte {
|
||||
file_testproto_proto_rawDescOnce.Do(func() {
|
||||
file_testproto_proto_rawDescData = protoimpl.X.CompressGZIP(file_testproto_proto_rawDescData)
|
||||
})
|
||||
return file_testproto_proto_rawDescData
|
||||
}
|
||||
|
||||
var file_testproto_proto_msgTypes = make([]protoimpl.MessageInfo, 1)
|
||||
var file_testproto_proto_goTypes = []interface{}{
|
||||
(*TestMessage)(nil), // 0: testprotos.TestMessage
|
||||
}
|
||||
var file_testproto_proto_depIdxs = []int32{
|
||||
0, // [0:0] is the sub-list for method output_type
|
||||
0, // [0:0] is the sub-list for method input_type
|
||||
0, // [0:0] is the sub-list for extension type_name
|
||||
0, // [0:0] is the sub-list for extension extendee
|
||||
0, // [0:0] is the sub-list for field type_name
|
||||
}
|
||||
|
||||
func init() { file_testproto_proto_init() }
|
||||
func file_testproto_proto_init() {
|
||||
if File_testproto_proto != nil {
|
||||
return
|
||||
}
|
||||
if !protoimpl.UnsafeEnabled {
|
||||
file_testproto_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} {
|
||||
switch v := v.(*TestMessage); i {
|
||||
case 0:
|
||||
return &v.state
|
||||
case 1:
|
||||
return &v.sizeCache
|
||||
case 2:
|
||||
return &v.unknownFields
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
type x struct{}
|
||||
out := protoimpl.TypeBuilder{
|
||||
File: protoimpl.DescBuilder{
|
||||
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
|
||||
RawDescriptor: file_testproto_proto_rawDesc,
|
||||
NumEnums: 0,
|
||||
NumMessages: 1,
|
||||
NumExtensions: 0,
|
||||
NumServices: 0,
|
||||
},
|
||||
GoTypes: file_testproto_proto_goTypes,
|
||||
DependencyIndexes: file_testproto_proto_depIdxs,
|
||||
MessageInfos: file_testproto_proto_msgTypes,
|
||||
}.Build()
|
||||
File_testproto_proto = out.File
|
||||
file_testproto_proto_rawDesc = nil
|
||||
file_testproto_proto_goTypes = nil
|
||||
file_testproto_proto_depIdxs = nil
|
||||
}
|
||||
9
encryption/testprotos/testproto.proto
Normal file
9
encryption/testprotos/testproto.proto
Normal file
@@ -0,0 +1,9 @@
|
||||
syntax = "proto3";
|
||||
|
||||
option go_package = "/testprotos";
|
||||
|
||||
package testprotos;
|
||||
|
||||
message TestMessage {
|
||||
string body = 1;
|
||||
}
|
||||
Reference in New Issue
Block a user