Extract common server encryption logic (#65)

* refactor: extract common message encryption logic
* refactor: move letsencrypt logic to common
* refactor: rename common package to encryption
* test: add encryption tests
This commit is contained in:
Mikhail Bragin
2021-07-22 15:23:24 +02:00
committed by GitHub
parent c98be683bf
commit 2172d6f1b9
16 changed files with 343 additions and 141 deletions

50
encryption/encryption.go Normal file
View File

@@ -0,0 +1,50 @@
package encryption
import (
"crypto/rand"
"fmt"
"golang.org/x/crypto/nacl/box"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
// A set of tools to encrypt/decrypt messages being sent through the Signal Exchange Service or Management Service
// These tools use Golang crypto package (Curve25519, XSalsa20 and Poly1305 to encrypt and authenticate)
// Wireguard keys are used for encryption
// Encrypt encrypts a message using local Wireguard private key and remote peer's public key.
func Encrypt(msg []byte, peerPublicKey wgtypes.Key, privateKey wgtypes.Key) ([]byte, error) {
nonce, err := genNonce()
if err != nil {
return nil, err
}
return box.Seal(nonce[:], msg, nonce, toByte32(peerPublicKey), toByte32(privateKey)), nil
}
// Decrypt decrypts a message that has been encrypted by the remote peer using Wireguard private key and remote peer's public key.
func Decrypt(encryptedMsg []byte, peerPublicKey wgtypes.Key, privateKey wgtypes.Key) ([]byte, error) {
nonce, err := genNonce()
if err != nil {
return nil, err
}
copy(nonce[:], encryptedMsg[:24])
opened, ok := box.Open(nil, encryptedMsg[24:], nonce, toByte32(peerPublicKey), toByte32(privateKey))
if !ok {
return nil, fmt.Errorf("failed to decrypt message from peer %s", peerPublicKey.String())
}
return opened, nil
}
// Generates nonce of size 24
func genNonce() (*[24]byte, error) {
var nonce [24]byte
if _, err := rand.Read(nonce[:]); err != nil {
return nil, err
}
return &nonce, nil
}
// Converts Wireguard key to byte array of size 32 (a format used by the golang crypto package)
func toByte32(key wgtypes.Key) *[32]byte {
return (*[32]byte)(&key)
}

View File

@@ -0,0 +1,13 @@
package encryption_test
import (
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
"testing"
)
func TestManagement(t *testing.T) {
RegisterFailHandler(Fail)
RunSpecs(t, "Management Service Suite")
}

View File

@@ -0,0 +1,60 @@
package encryption_test
import (
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
"github.com/wiretrustee/wiretrustee/encryption"
"github.com/wiretrustee/wiretrustee/encryption/testprotos"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
const ()
var _ = Describe("Encryption", func() {
var (
encryptionKey wgtypes.Key
decryptionKey wgtypes.Key
)
BeforeEach(func() {
var err error
encryptionKey, err = wgtypes.GenerateKey()
Expect(err).NotTo(HaveOccurred())
decryptionKey, err = wgtypes.GenerateKey()
Expect(err).NotTo(HaveOccurred())
})
Context("decrypting a plain message", func() {
Context("when it was encrypted with Wireguard keys", func() {
Specify("should be successful", func() {
msg := "message"
encryptedMsg, err := encryption.Encrypt([]byte(msg), decryptionKey.PublicKey(), encryptionKey)
Expect(err).NotTo(HaveOccurred())
decryptedMsg, err := encryption.Decrypt(encryptedMsg, encryptionKey.PublicKey(), decryptionKey)
Expect(err).NotTo(HaveOccurred())
Expect(string(decryptedMsg)).To(BeEquivalentTo(msg))
})
})
})
Context("decrypting a protobuf message", func() {
Context("when it was encrypted with Wireguard keys", func() {
Specify("should be successful", func() {
protoMsg := &testprotos.TestMessage{Body: "message"}
encryptedMsg, err := encryption.EncryptMessage(decryptionKey.PublicKey(), encryptionKey, protoMsg)
Expect(err).NotTo(HaveOccurred())
decryptedMsg := &testprotos.TestMessage{}
err = encryption.DecryptMessage(encryptionKey.PublicKey(), decryptionKey, encryptedMsg, decryptedMsg)
Expect(err).NotTo(HaveOccurred())
Expect(decryptedMsg.GetBody()).To(BeEquivalentTo(protoMsg.GetBody()))
})
})
})
})

40
encryption/letsencrypt.go Normal file
View File

@@ -0,0 +1,40 @@
package encryption
import (
"crypto/tls"
log "github.com/sirupsen/logrus"
"golang.org/x/crypto/acme/autocert"
"net/http"
"os"
"path/filepath"
)
// EnableLetsEncrypt wraps common logic of generating Let's encrypt certificate.
// Includes a HTTP handler and listener to solve the Let's encrypt challenge
func EnableLetsEncrypt(datadir string, letsencryptDomain string) *tls.Config {
certDir := filepath.Join(datadir, "letsencrypt")
if _, err := os.Stat(certDir); os.IsNotExist(err) {
err = os.MkdirAll(certDir, os.ModeDir)
if err != nil {
log.Fatalf("failed creating Let's encrypt certdir: %s: %v", certDir, err)
}
}
log.Infof("running with Let's encrypt with domain %s. Cert will be stored in %s", letsencryptDomain, certDir)
certManager := autocert.Manager{
Prompt: autocert.AcceptTOS,
Cache: autocert.DirCache(certDir),
HostPolicy: autocert.HostWhitelist(letsencryptDomain),
}
// listener to handle Let's encrypt certificate challenge
go func() {
if err := http.Serve(certManager.Listener(), certManager.HTTPHandler(nil)); err != nil {
log.Fatalf("failed to serve letsencrypt handler: %v", err)
}
}()
return &tls.Config{GetCertificate: certManager.GetCertificate}
}

40
encryption/message.go Normal file
View File

@@ -0,0 +1,40 @@
package encryption
import (
pb "github.com/golang/protobuf/proto" //nolint
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
// EncryptMessage encrypts a body of the given protobuf Message
func EncryptMessage(remotePubKey wgtypes.Key, ourPrivateKey wgtypes.Key, message pb.Message) ([]byte, error) {
byteResp, err := pb.Marshal(message)
if err != nil {
log.Errorf("failed marshalling message %v", err)
return nil, err
}
encryptedBytes, err := Encrypt(byteResp, remotePubKey, ourPrivateKey)
if err != nil {
log.Errorf("failed encrypting SyncResponse %v", err)
return nil, err
}
return encryptedBytes, nil
}
// DecryptMessage decrypts an encrypted message into given protobuf Message
func DecryptMessage(remotePubKey wgtypes.Key, ourPrivateKey wgtypes.Key, encryptedMessage []byte, message pb.Message) error {
decrypted, err := Decrypt(encryptedMessage, remotePubKey, ourPrivateKey)
if err != nil {
log.Warnf("error while decrypting Sync request message from peer %s", remotePubKey.String())
return err
}
err = pb.Unmarshal(decrypted, message)
if err != nil {
log.Warnf("error while umarshalling Sync request message from peer %s", remotePubKey.String())
return err
}
return nil
}

View File

@@ -0,0 +1,2 @@
#!/bin/bash
protoc -I testprotos/ testprotos/testproto.proto --go_out=.

View File

@@ -0,0 +1,142 @@
// Code generated by protoc-gen-go. DO NOT EDIT.
// versions:
// protoc-gen-go v1.26.0
// protoc v3.12.4
// source: testproto.proto
package testprotos
import (
protoreflect "google.golang.org/protobuf/reflect/protoreflect"
protoimpl "google.golang.org/protobuf/runtime/protoimpl"
reflect "reflect"
sync "sync"
)
const (
// Verify that this generated code is sufficiently up-to-date.
_ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion)
// Verify that runtime/protoimpl is sufficiently up-to-date.
_ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20)
)
type TestMessage struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
Body string `protobuf:"bytes,1,opt,name=body,proto3" json:"body,omitempty"`
}
func (x *TestMessage) Reset() {
*x = TestMessage{}
if protoimpl.UnsafeEnabled {
mi := &file_testproto_proto_msgTypes[0]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *TestMessage) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*TestMessage) ProtoMessage() {}
func (x *TestMessage) ProtoReflect() protoreflect.Message {
mi := &file_testproto_proto_msgTypes[0]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use TestMessage.ProtoReflect.Descriptor instead.
func (*TestMessage) Descriptor() ([]byte, []int) {
return file_testproto_proto_rawDescGZIP(), []int{0}
}
func (x *TestMessage) GetBody() string {
if x != nil {
return x.Body
}
return ""
}
var File_testproto_proto protoreflect.FileDescriptor
var file_testproto_proto_rawDesc = []byte{
0x0a, 0x0f, 0x74, 0x65, 0x73, 0x74, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x70, 0x72, 0x6f, 0x74,
0x6f, 0x12, 0x0a, 0x74, 0x65, 0x73, 0x74, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x73, 0x22, 0x21, 0x0a,
0x0b, 0x54, 0x65, 0x73, 0x74, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x12, 0x12, 0x0a, 0x04,
0x62, 0x6f, 0x64, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x62, 0x6f, 0x64, 0x79,
0x42, 0x0d, 0x5a, 0x0b, 0x2f, 0x74, 0x65, 0x73, 0x74, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x73, 0x62,
0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
}
var (
file_testproto_proto_rawDescOnce sync.Once
file_testproto_proto_rawDescData = file_testproto_proto_rawDesc
)
func file_testproto_proto_rawDescGZIP() []byte {
file_testproto_proto_rawDescOnce.Do(func() {
file_testproto_proto_rawDescData = protoimpl.X.CompressGZIP(file_testproto_proto_rawDescData)
})
return file_testproto_proto_rawDescData
}
var file_testproto_proto_msgTypes = make([]protoimpl.MessageInfo, 1)
var file_testproto_proto_goTypes = []interface{}{
(*TestMessage)(nil), // 0: testprotos.TestMessage
}
var file_testproto_proto_depIdxs = []int32{
0, // [0:0] is the sub-list for method output_type
0, // [0:0] is the sub-list for method input_type
0, // [0:0] is the sub-list for extension type_name
0, // [0:0] is the sub-list for extension extendee
0, // [0:0] is the sub-list for field type_name
}
func init() { file_testproto_proto_init() }
func file_testproto_proto_init() {
if File_testproto_proto != nil {
return
}
if !protoimpl.UnsafeEnabled {
file_testproto_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*TestMessage); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
}
type x struct{}
out := protoimpl.TypeBuilder{
File: protoimpl.DescBuilder{
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
RawDescriptor: file_testproto_proto_rawDesc,
NumEnums: 0,
NumMessages: 1,
NumExtensions: 0,
NumServices: 0,
},
GoTypes: file_testproto_proto_goTypes,
DependencyIndexes: file_testproto_proto_depIdxs,
MessageInfos: file_testproto_proto_msgTypes,
}.Build()
File_testproto_proto = out.File
file_testproto_proto_rawDesc = nil
file_testproto_proto_goTypes = nil
file_testproto_proto_depIdxs = nil
}

View File

@@ -0,0 +1,9 @@
syntax = "proto3";
option go_package = "/testprotos";
package testprotos;
message TestMessage {
string body = 1;
}