mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-28 05:06:38 +00:00
Compare commits
9 Commits
v0.21.5
...
feature/ke
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
d68a4a7d21 | ||
|
|
ca1722ed10 | ||
|
|
649dbf2bed | ||
|
|
70076b98d2 | ||
|
|
551455f314 | ||
|
|
a6431e053b | ||
|
|
520c7b5d37 | ||
|
|
e376541745 | ||
|
|
774d8e955c |
@@ -42,7 +42,7 @@ type DefaultServer struct {
|
||||
ctx context.Context
|
||||
ctxCancel context.CancelFunc
|
||||
mux sync.Mutex
|
||||
fakeResolverWG sync.WaitGroup
|
||||
udpFilterHookID string
|
||||
server *dns.Server
|
||||
dnsMux *dns.ServeMux
|
||||
dnsMuxMap registeredHandlerMap
|
||||
@@ -105,7 +105,10 @@ func NewDefaultServer(ctx context.Context, wgInterface *iface.WGIface, customAdd
|
||||
defaultServer.enabled = hasValidDnsServer(initialDnsCfg)
|
||||
}
|
||||
|
||||
defaultServer.evalRuntimeAddress()
|
||||
if wgInterface.IsUserspaceBind() {
|
||||
defaultServer.evelRuntimeAddressForUserspace()
|
||||
}
|
||||
|
||||
return defaultServer, nil
|
||||
}
|
||||
|
||||
@@ -118,6 +121,9 @@ func (s *DefaultServer) Initialize() (err error) {
|
||||
return nil
|
||||
}
|
||||
|
||||
if !s.wgInterface.IsUserspaceBind() {
|
||||
s.evalRuntimeAddress()
|
||||
}
|
||||
s.hostManager, err = newHostManager(s.wgInterface)
|
||||
return
|
||||
}
|
||||
@@ -126,17 +132,8 @@ func (s *DefaultServer) Initialize() (err error) {
|
||||
func (s *DefaultServer) listen() {
|
||||
// nil check required in unit tests
|
||||
if s.wgInterface != nil && s.wgInterface.IsUserspaceBind() {
|
||||
s.fakeResolverWG.Add(1)
|
||||
go func() {
|
||||
s.setListenerStatus(true)
|
||||
defer s.setListenerStatus(false)
|
||||
|
||||
hookID := s.filterDNSTraffic()
|
||||
s.fakeResolverWG.Wait()
|
||||
if err := s.wgInterface.GetFilter().RemovePacketHook(hookID); err != nil {
|
||||
log.Errorf("unable to remove DNS packet hook: %s", err)
|
||||
}
|
||||
}()
|
||||
s.udpFilterHookID = s.filterDNSTraffic()
|
||||
s.setListenerStatus(true)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -153,6 +150,10 @@ func (s *DefaultServer) listen() {
|
||||
}()
|
||||
}
|
||||
|
||||
// DnsIP returns the DNS resolver server IP address
|
||||
//
|
||||
// When kernel space interface used it return real DNS server listener IP address
|
||||
// For bind interface, fake DNS resolver address returned (second last IP address from Nebird network)
|
||||
func (s *DefaultServer) DnsIP() string {
|
||||
if !s.enabled {
|
||||
return ""
|
||||
@@ -201,10 +202,6 @@ func (s *DefaultServer) Stop() {
|
||||
}
|
||||
}
|
||||
|
||||
if s.wgInterface != nil && s.wgInterface.IsUserspaceBind() && s.listenerIsRunning {
|
||||
s.fakeResolverWG.Done()
|
||||
}
|
||||
|
||||
err := s.stopListener()
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
@@ -212,6 +209,18 @@ func (s *DefaultServer) Stop() {
|
||||
}
|
||||
|
||||
func (s *DefaultServer) stopListener() error {
|
||||
if s.wgInterface != nil && s.wgInterface.IsUserspaceBind() && s.listenerIsRunning {
|
||||
// udpFilterHookID here empty only in the unit tests
|
||||
if filter := s.wgInterface.GetFilter(); filter != nil && s.udpFilterHookID != "" {
|
||||
if err := filter.RemovePacketHook(s.udpFilterHookID); err != nil {
|
||||
log.Errorf("unable to remove DNS packet hook: %s", err)
|
||||
}
|
||||
}
|
||||
s.udpFilterHookID = ""
|
||||
s.listenerIsRunning = false
|
||||
return nil
|
||||
}
|
||||
|
||||
if !s.listenerIsRunning {
|
||||
return nil
|
||||
}
|
||||
@@ -275,12 +284,8 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
|
||||
// is the service should be disabled, we stop the listener or fake resolver
|
||||
// and proceed with a regular update to clean up the handlers and records
|
||||
if !update.ServiceEnable {
|
||||
if s.wgInterface != nil && s.wgInterface.IsUserspaceBind() {
|
||||
s.fakeResolverWG.Done()
|
||||
} else {
|
||||
if err := s.stopListener(); err != nil {
|
||||
log.Error(err)
|
||||
}
|
||||
if err := s.stopListener(); err != nil {
|
||||
log.Error(err)
|
||||
}
|
||||
} else if !s.listenerIsRunning {
|
||||
s.listen()
|
||||
@@ -555,17 +560,17 @@ func (s *DefaultServer) filterDNSTraffic() string {
|
||||
return filter.AddUDPPacketHook(false, net.ParseIP(s.runtimeIP), uint16(s.runtimePort), hook)
|
||||
}
|
||||
|
||||
func (s *DefaultServer) evelRuntimeAddressForUserspace() {
|
||||
s.runtimeIP = getLastIPFromNetwork(s.wgInterface.Address().Network, 1)
|
||||
s.runtimePort = defaultPort
|
||||
s.server.Addr = fmt.Sprintf("%s:%d", s.runtimeIP, s.runtimePort)
|
||||
}
|
||||
|
||||
func (s *DefaultServer) evalRuntimeAddress() {
|
||||
defer func() {
|
||||
s.server.Addr = fmt.Sprintf("%s:%d", s.runtimeIP, s.runtimePort)
|
||||
}()
|
||||
|
||||
if s.wgInterface != nil && s.wgInterface.IsUserspaceBind() {
|
||||
s.runtimeIP = getLastIPFromNetwork(s.wgInterface.Address().Network, 1)
|
||||
s.runtimePort = defaultPort
|
||||
return
|
||||
}
|
||||
|
||||
if s.customAddress != nil {
|
||||
s.runtimeIP = s.customAddress.Addr().String()
|
||||
s.runtimePort = int(s.customAddress.Port())
|
||||
|
||||
@@ -5,15 +5,18 @@ import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/miekg/dns"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
pfmock "github.com/netbirdio/netbird/iface/mocks"
|
||||
)
|
||||
|
||||
var zoneRecords = []nbdns.SimpleRecord{
|
||||
@@ -241,7 +244,6 @@ func TestUpdateDNSServer(t *testing.T) {
|
||||
dnsServer.updateSerial = testCase.initSerial
|
||||
// pretend we are running
|
||||
dnsServer.listenerIsRunning = true
|
||||
dnsServer.fakeResolverWG.Add(1)
|
||||
|
||||
err = dnsServer.UpdateDNSServer(testCase.inputSerial, testCase.inputUpdate)
|
||||
if err != nil {
|
||||
@@ -276,6 +278,133 @@ func TestUpdateDNSServer(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNSFakeResolverHandleUpdates(t *testing.T) {
|
||||
ov := os.Getenv("NB_WG_KERNEL_DISABLED")
|
||||
defer os.Setenv("NB_WG_KERNEL_DISABLED", ov)
|
||||
|
||||
os.Setenv("NB_WG_KERNEL_DISABLED", "true")
|
||||
newNet, err := stdnet.NewNet(nil)
|
||||
if err != nil {
|
||||
t.Errorf("create stdnet: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
wgIface, err := iface.NewWGIFace("utun2301", "100.66.100.1/32", iface.DefaultMTU, nil, newNet)
|
||||
if err != nil {
|
||||
t.Errorf("build interface wireguard: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
err = wgIface.Create()
|
||||
if err != nil {
|
||||
t.Errorf("crate and init wireguard interface: %v", err)
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
if err = wgIface.Close(); err != nil {
|
||||
t.Logf("close wireguard interface: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
_, ipNet, err := net.ParseCIDR("100.66.100.1/32")
|
||||
if err != nil {
|
||||
t.Errorf("parse CIDR: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
packetfilter := pfmock.NewMockPacketFilter(ctrl)
|
||||
packetfilter.EXPECT().SetNetwork(ipNet)
|
||||
packetfilter.EXPECT().DropOutgoing(gomock.Any()).AnyTimes()
|
||||
packetfilter.EXPECT().AddUDPPacketHook(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
|
||||
packetfilter.EXPECT().RemovePacketHook(gomock.Any()).AnyTimes()
|
||||
|
||||
if err := wgIface.SetFilter(packetfilter); err != nil {
|
||||
t.Errorf("set packet filter: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
dnsServer, err := NewDefaultServer(context.Background(), wgIface, "", nil)
|
||||
if err != nil {
|
||||
t.Errorf("create DNS server: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
err = dnsServer.Initialize()
|
||||
if err != nil {
|
||||
t.Errorf("run DNS server: %v", err)
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
if err = dnsServer.hostManager.restoreHostDNS(); err != nil {
|
||||
t.Logf("restore DNS settings on the host: %v", err)
|
||||
return
|
||||
}
|
||||
}()
|
||||
|
||||
dnsServer.dnsMuxMap = registeredHandlerMap{zoneRecords[0].Name: &localResolver{}}
|
||||
dnsServer.localResolver.registeredMap = registrationMap{"netbird.cloud": struct{}{}}
|
||||
dnsServer.updateSerial = 0
|
||||
|
||||
nameServers := []nbdns.NameServer{
|
||||
{
|
||||
IP: netip.MustParseAddr("8.8.8.8"),
|
||||
NSType: nbdns.UDPNameServerType,
|
||||
Port: 53,
|
||||
},
|
||||
{
|
||||
IP: netip.MustParseAddr("8.8.4.4"),
|
||||
NSType: nbdns.UDPNameServerType,
|
||||
Port: 53,
|
||||
},
|
||||
}
|
||||
|
||||
update := nbdns.Config{
|
||||
ServiceEnable: true,
|
||||
CustomZones: []nbdns.CustomZone{
|
||||
{
|
||||
Domain: "netbird.cloud",
|
||||
Records: zoneRecords,
|
||||
},
|
||||
},
|
||||
NameServerGroups: []*nbdns.NameServerGroup{
|
||||
{
|
||||
Domains: []string{"netbird.io"},
|
||||
NameServers: nameServers,
|
||||
},
|
||||
{
|
||||
NameServers: nameServers,
|
||||
Primary: true,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Start the server with regular configuration
|
||||
if err := dnsServer.UpdateDNSServer(1, update); err != nil {
|
||||
t.Fatalf("update dns server should not fail, got error: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
update2 := update
|
||||
update2.ServiceEnable = false
|
||||
// Disable the server, stop the listener
|
||||
if err := dnsServer.UpdateDNSServer(2, update2); err != nil {
|
||||
t.Fatalf("update dns server should not fail, got error: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
update3 := update2
|
||||
update3.NameServerGroups = update3.NameServerGroups[:1]
|
||||
// But service still get updates and we checking that we handle
|
||||
// internal state in the right way
|
||||
if err := dnsServer.UpdateDNSServer(3, update3); err != nil {
|
||||
t.Fatalf("update dns server should not fail, got error: %v", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNSServerStartStop(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
|
||||
@@ -888,7 +888,6 @@ func (e *Engine) receiveSignalEvents() {
|
||||
err := e.signal.Receive(func(msg *sProto.Message) error {
|
||||
e.syncMsgMux.Lock()
|
||||
defer e.syncMsgMux.Unlock()
|
||||
|
||||
conn := e.peerConns[msg.Key]
|
||||
if conn == nil {
|
||||
return fmt.Errorf("wrongly addressed message %s", msg.Key)
|
||||
|
||||
5
keepalive/client.go
Normal file
5
keepalive/client.go
Normal file
@@ -0,0 +1,5 @@
|
||||
package keepalive
|
||||
|
||||
func IsKeepAliveMsg(body []byte) bool {
|
||||
return len(body) == 0
|
||||
}
|
||||
155
keepalive/keep_alive.go
Normal file
155
keepalive/keep_alive.go
Normal file
@@ -0,0 +1,155 @@
|
||||
package keepalive
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/metadata"
|
||||
)
|
||||
|
||||
const (
|
||||
GrpcVersionHeaderKey = "x-netbird-version"
|
||||
|
||||
reversProxyHeaderKey = "x-netbird-peer"
|
||||
keepAliveInterval = 30 * time.Second
|
||||
)
|
||||
|
||||
type KeepAlive struct {
|
||||
sync.RWMutex
|
||||
ticker *time.Ticker
|
||||
done chan struct{}
|
||||
streams map[string]*ioMonitor
|
||||
keepAliveMsg interface{}
|
||||
}
|
||||
|
||||
func NewKeepAlive(keepAliveMsg interface{}) *KeepAlive {
|
||||
ka := &KeepAlive{
|
||||
ticker: time.NewTicker(1 * time.Second),
|
||||
done: make(chan struct{}),
|
||||
streams: make(map[string]*ioMonitor),
|
||||
keepAliveMsg: keepAliveMsg,
|
||||
}
|
||||
go ka.start()
|
||||
return ka
|
||||
}
|
||||
|
||||
func (k *KeepAlive) StreamInterceptor() grpc.StreamServerInterceptor {
|
||||
return func(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
|
||||
address, supported := k.keepAliveIsSupported(stream.Context())
|
||||
if !supported {
|
||||
return handler(srv, stream)
|
||||
}
|
||||
|
||||
m := &ioMonitor{
|
||||
sync.Mutex{},
|
||||
sync.Mutex{},
|
||||
stream,
|
||||
time.Now(),
|
||||
}
|
||||
|
||||
k.addIoMonitor(address, m)
|
||||
|
||||
return handler(srv, m)
|
||||
}
|
||||
}
|
||||
|
||||
func (k *KeepAlive) UnaryInterceptor() grpc.UnaryServerInterceptor {
|
||||
return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) {
|
||||
address, supported := k.keepAliveIsSupported(ctx)
|
||||
if supported {
|
||||
k.updateLastSeen(address)
|
||||
}
|
||||
return handler(ctx, req)
|
||||
}
|
||||
}
|
||||
|
||||
func (k *KeepAlive) Stop() {
|
||||
select {
|
||||
case k.done <- struct{}{}:
|
||||
k.ticker.Stop()
|
||||
return
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
func (k *KeepAlive) start() {
|
||||
for {
|
||||
select {
|
||||
case <-k.done:
|
||||
return
|
||||
case t := <-k.ticker.C:
|
||||
k.checkKeepAlive(t)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (k *KeepAlive) checkKeepAlive(now time.Time) {
|
||||
k.Lock()
|
||||
defer k.Unlock()
|
||||
for addr, m := range k.streams {
|
||||
if k.isKeepAliveOutDated(now, m) {
|
||||
continue
|
||||
}
|
||||
err := k.sendKeepAlive(m)
|
||||
if err != nil {
|
||||
log.Debugf("stop keepalive for: %s", addr)
|
||||
delete(k.streams, addr)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (k *KeepAlive) keepAliveIsSupported(ctx context.Context) (string, bool) {
|
||||
md, ok := metadata.FromIncomingContext(ctx)
|
||||
if !ok {
|
||||
log.Warnf("metadata not found")
|
||||
return "", false
|
||||
}
|
||||
|
||||
peerAddress := k.addressFromHeader(md)
|
||||
if peerAddress == "" {
|
||||
log.Debugf("peer is not using reverse proxy")
|
||||
return "", false
|
||||
}
|
||||
|
||||
if len(md.Get(GrpcVersionHeaderKey)) == 0 {
|
||||
log.Debugf("version info not found")
|
||||
return "", false
|
||||
}
|
||||
return peerAddress, true
|
||||
}
|
||||
|
||||
func (k *KeepAlive) addIoMonitor(address string, m *ioMonitor) {
|
||||
k.Lock()
|
||||
defer k.Unlock()
|
||||
log.Debugf("add stream address for keepalive list: %s", address)
|
||||
k.streams[address] = m
|
||||
}
|
||||
|
||||
func (k *KeepAlive) sendKeepAlive(m *ioMonitor) error {
|
||||
return m.sendMsg(k.keepAliveMsg)
|
||||
}
|
||||
|
||||
func (k *KeepAlive) updateLastSeen(address string) {
|
||||
k.RLock()
|
||||
m, ok := k.streams[address]
|
||||
k.RUnlock()
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
m.updateLastSeen()
|
||||
}
|
||||
|
||||
func (k *KeepAlive) addressFromHeader(md metadata.MD) string {
|
||||
peer := md.Get(reversProxyHeaderKey)
|
||||
if len(peer) == 0 {
|
||||
return ""
|
||||
}
|
||||
return peer[0]
|
||||
}
|
||||
|
||||
func (k *KeepAlive) isKeepAliveOutDated(now time.Time, m *ioMonitor) bool {
|
||||
return now.Sub(m.getLastSeen()) < keepAliveInterval
|
||||
}
|
||||
35
keepalive/monitor.go
Normal file
35
keepalive/monitor.go
Normal file
@@ -0,0 +1,35 @@
|
||||
package keepalive
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"google.golang.org/grpc"
|
||||
)
|
||||
|
||||
type ioMonitor struct {
|
||||
mu sync.Mutex
|
||||
streamLock sync.Mutex
|
||||
grpc.ServerStream
|
||||
lastSeen time.Time
|
||||
}
|
||||
|
||||
func (l *ioMonitor) sendMsg(m interface{}) error {
|
||||
l.updateLastSeen()
|
||||
l.streamLock.Lock()
|
||||
defer l.streamLock.Unlock()
|
||||
return l.ServerStream.SendMsg(m)
|
||||
}
|
||||
|
||||
func (l *ioMonitor) updateLastSeen() {
|
||||
l.mu.Lock()
|
||||
defer l.mu.Unlock()
|
||||
l.lastSeen = time.Now()
|
||||
}
|
||||
|
||||
func (l *ioMonitor) getLastSeen() time.Time {
|
||||
l.mu.Lock()
|
||||
t := l.lastSeen
|
||||
l.mu.Unlock()
|
||||
return t
|
||||
}
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"time"
|
||||
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/metadata"
|
||||
gstatus "google.golang.org/grpc/status"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
@@ -23,7 +24,9 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
"github.com/netbirdio/netbird/encryption"
|
||||
appKeepAlive "github.com/netbirdio/netbird/keepalive"
|
||||
"github.com/netbirdio/netbird/management/proto"
|
||||
"github.com/netbirdio/netbird/version"
|
||||
)
|
||||
|
||||
// ConnStateNotifier is a wrapper interface of the status recorders
|
||||
@@ -67,6 +70,9 @@ func NewClient(ctx context.Context, addr string, ourPrivateKey wgtypes.Key, tlsE
|
||||
|
||||
realClient := proto.NewManagementServiceClient(conn)
|
||||
|
||||
md := metadata.Pairs(appKeepAlive.GrpcVersionHeaderKey, version.NetbirdVersion())
|
||||
ctx = metadata.NewOutgoingContext(ctx, md)
|
||||
|
||||
return &GrpcClient{
|
||||
key: ourPrivateKey,
|
||||
realClient: realClient,
|
||||
@@ -131,6 +137,7 @@ func (c *GrpcClient) Sync(msgHandler func(msg *proto.SyncResponse) error) error
|
||||
|
||||
ctx, cancelStream := context.WithCancel(c.ctx)
|
||||
defer cancelStream()
|
||||
|
||||
stream, err := c.connectToStream(ctx, *serverPubKey)
|
||||
if err != nil {
|
||||
log.Debugf("failed to open Management Service stream: %s", err)
|
||||
@@ -246,6 +253,10 @@ func (c *GrpcClient) receiveEvents(stream proto.ManagementService_SyncClient, se
|
||||
return err
|
||||
}
|
||||
|
||||
if appKeepAlive.IsKeepAliveMsg(update.Body) {
|
||||
continue
|
||||
}
|
||||
|
||||
log.Debugf("got an update message from Management Service")
|
||||
decryptedResp := &proto.SyncResponse{}
|
||||
err = encryption.DecryptMessage(serverPubKey, c.key, update.Body, decryptedResp)
|
||||
|
||||
@@ -40,6 +40,7 @@ import (
|
||||
"google.golang.org/grpc/keepalive"
|
||||
|
||||
"github.com/netbirdio/netbird/encryption"
|
||||
grpcKeepAlive "github.com/netbirdio/netbird/keepalive"
|
||||
mgmtProto "github.com/netbirdio/netbird/management/proto"
|
||||
)
|
||||
|
||||
@@ -202,11 +203,17 @@ var (
|
||||
return fmt.Errorf("failed creating HTTP API handler: %v", err)
|
||||
}
|
||||
|
||||
gRPCAPIHandler := grpc.NewServer(gRPCOpts...)
|
||||
srv, err := server.NewServer(config, accountManager, peersUpdateManager, turnManager, appMetrics)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed creating gRPC API handler: %v", err)
|
||||
}
|
||||
|
||||
ka := grpcKeepAlive.NewKeepAlive(&mgmtProto.KeepAlive{})
|
||||
defer ka.Stop()
|
||||
sInterc := grpc.StreamInterceptor(ka.StreamInterceptor())
|
||||
uInterc := grpc.UnaryInterceptor(ka.UnaryInterceptor())
|
||||
gRPCOpts = append(gRPCOpts, sInterc, uInterc)
|
||||
gRPCAPIHandler := grpc.NewServer(gRPCOpts...)
|
||||
mgmtProto.RegisterManagementServiceServer(gRPCAPIHandler, srv)
|
||||
|
||||
installationID, err := getInstallationID(store)
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -329,3 +329,5 @@ message FirewallRule {
|
||||
ICMP = 4;
|
||||
}
|
||||
}
|
||||
|
||||
message KeepAlive {}
|
||||
|
||||
@@ -21,7 +21,9 @@ import (
|
||||
"google.golang.org/grpc/status"
|
||||
|
||||
"github.com/netbirdio/netbird/encryption"
|
||||
appKeepAlive "github.com/netbirdio/netbird/keepalive"
|
||||
"github.com/netbirdio/netbird/signal/proto"
|
||||
"github.com/netbirdio/netbird/version"
|
||||
)
|
||||
|
||||
const defaultSendTimeout = 5 * time.Second
|
||||
@@ -73,6 +75,7 @@ func NewClient(ctx context.Context, addr string, key wgtypes.Key, tlsEnabled boo
|
||||
|
||||
sigCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
conn, err := grpc.DialContext(
|
||||
sigCtx,
|
||||
addr,
|
||||
@@ -87,9 +90,14 @@ func NewClient(ctx context.Context, addr string, key wgtypes.Key, tlsEnabled boo
|
||||
log.Errorf("failed to connect to the signalling server %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
log.Debugf("connected to Signal Service: %v", conn.Target())
|
||||
|
||||
md := metadata.New(map[string]string{
|
||||
proto.HeaderId: key.PublicKey().String(), // add key fingerprint to the request header to be identified on the server side
|
||||
appKeepAlive.GrpcVersionHeaderKey: version.NetbirdVersion(), // add version info to ensure keep alive is supported
|
||||
})
|
||||
ctx = metadata.NewOutgoingContext(ctx, md)
|
||||
|
||||
return &GrpcClient{
|
||||
realClient: proto.NewSignalExchangeClient(conn),
|
||||
ctx: ctx,
|
||||
@@ -146,7 +154,7 @@ func (c *GrpcClient) Receive(msgHandler func(msg *proto.Message) error) error {
|
||||
// todo once the key rotation logic has been implemented, consider changing to some other identifier (received from management)
|
||||
ctx, cancelStream := context.WithCancel(c.ctx)
|
||||
defer cancelStream()
|
||||
stream, err := c.connect(ctx, c.key.PublicKey().String())
|
||||
stream, err := c.connect(ctx)
|
||||
if err != nil {
|
||||
log.Warnf("disconnected from the Signal Exchange due to an error: %v", err)
|
||||
return err
|
||||
@@ -208,13 +216,9 @@ func (c *GrpcClient) getStreamStatusChan() <-chan struct{} {
|
||||
return c.connectedCh
|
||||
}
|
||||
|
||||
func (c *GrpcClient) connect(ctx context.Context, key string) (proto.SignalExchange_ConnectStreamClient, error) {
|
||||
func (c *GrpcClient) connect(ctx context.Context) (proto.SignalExchange_ConnectStreamClient, error) {
|
||||
c.stream = nil
|
||||
|
||||
// add key fingerprint to the request header to be identified on the server side
|
||||
md := metadata.New(map[string]string{proto.HeaderId: key})
|
||||
metaCtx := metadata.NewOutgoingContext(ctx, md)
|
||||
stream, err := c.realClient.ConnectStream(metaCtx, grpc.WaitForReady(true))
|
||||
stream, err := c.realClient.ConnectStream(ctx, grpc.WaitForReady(true))
|
||||
c.stream = stream
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -366,6 +370,12 @@ func (c *GrpcClient) receive(stream proto.SignalExchange_ConnectStreamClient,
|
||||
} else if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if appKeepAlive.IsKeepAliveMsg(msg.Body) {
|
||||
log.Tracef("received keepalive")
|
||||
continue
|
||||
}
|
||||
|
||||
log.Tracef("received a new message from Peer [fingerprint: %s]", msg.Key)
|
||||
|
||||
decryptedMessage, err := c.decryptMessage(msg)
|
||||
|
||||
@@ -4,7 +4,6 @@ import (
|
||||
"errors"
|
||||
"flag"
|
||||
"fmt"
|
||||
"golang.org/x/crypto/acme/autocert"
|
||||
"io"
|
||||
"io/fs"
|
||||
"net"
|
||||
@@ -14,15 +13,18 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/netbirdio/netbird/encryption"
|
||||
"github.com/netbirdio/netbird/signal/proto"
|
||||
"github.com/netbirdio/netbird/signal/server"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/spf13/cobra"
|
||||
"golang.org/x/crypto/acme/autocert"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/credentials"
|
||||
"google.golang.org/grpc/keepalive"
|
||||
|
||||
"github.com/netbirdio/netbird/encryption"
|
||||
appKeepAlive "github.com/netbirdio/netbird/keepalive"
|
||||
"github.com/netbirdio/netbird/signal/proto"
|
||||
"github.com/netbirdio/netbird/signal/server"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -93,6 +95,13 @@ var (
|
||||
}
|
||||
|
||||
opts = append(opts, signalKaep, signalKasp)
|
||||
|
||||
ka := appKeepAlive.NewKeepAlive(&proto.KeepAlive{})
|
||||
defer ka.Stop()
|
||||
sInterc := grpc.StreamInterceptor(ka.StreamInterceptor())
|
||||
uInterc := grpc.UnaryInterceptor(ka.UnaryInterceptor())
|
||||
opts = append(opts, sInterc, uInterc)
|
||||
|
||||
grpcServer := grpc.NewServer(opts...)
|
||||
proto.RegisterSignalExchangeServer(grpcServer, server.NewServer())
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
// Code generated by protoc-gen-go. DO NOT EDIT.
|
||||
// versions:
|
||||
// protoc-gen-go v1.26.0
|
||||
// protoc v3.21.9
|
||||
// protoc v3.21.12
|
||||
// source: signalexchange.proto
|
||||
|
||||
package proto
|
||||
@@ -347,6 +347,44 @@ func (x *Mode) GetDirect() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
type KeepAlive struct {
|
||||
state protoimpl.MessageState
|
||||
sizeCache protoimpl.SizeCache
|
||||
unknownFields protoimpl.UnknownFields
|
||||
}
|
||||
|
||||
func (x *KeepAlive) Reset() {
|
||||
*x = KeepAlive{}
|
||||
if protoimpl.UnsafeEnabled {
|
||||
mi := &file_signalexchange_proto_msgTypes[4]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
}
|
||||
|
||||
func (x *KeepAlive) String() string {
|
||||
return protoimpl.X.MessageStringOf(x)
|
||||
}
|
||||
|
||||
func (*KeepAlive) ProtoMessage() {}
|
||||
|
||||
func (x *KeepAlive) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_signalexchange_proto_msgTypes[4]
|
||||
if protoimpl.UnsafeEnabled && x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
return ms
|
||||
}
|
||||
return mi.MessageOf(x)
|
||||
}
|
||||
|
||||
// Deprecated: Use KeepAlive.ProtoReflect.Descriptor instead.
|
||||
func (*KeepAlive) Descriptor() ([]byte, []int) {
|
||||
return file_signalexchange_proto_rawDescGZIP(), []int{4}
|
||||
}
|
||||
|
||||
var File_signalexchange_proto protoreflect.FileDescriptor
|
||||
|
||||
var file_signalexchange_proto_rawDesc = []byte{
|
||||
@@ -388,20 +426,20 @@ var file_signalexchange_proto_rawDesc = []byte{
|
||||
0x45, 0x10, 0x04, 0x22, 0x2e, 0x0a, 0x04, 0x4d, 0x6f, 0x64, 0x65, 0x12, 0x1b, 0x0a, 0x06, 0x64,
|
||||
0x69, 0x72, 0x65, 0x63, 0x74, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x48, 0x00, 0x52, 0x06, 0x64,
|
||||
0x69, 0x72, 0x65, 0x63, 0x74, 0x88, 0x01, 0x01, 0x42, 0x09, 0x0a, 0x07, 0x5f, 0x64, 0x69, 0x72,
|
||||
0x65, 0x63, 0x74, 0x32, 0xb9, 0x01, 0x0a, 0x0e, 0x53, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x45, 0x78,
|
||||
0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x12, 0x4c, 0x0a, 0x04, 0x53, 0x65, 0x6e, 0x64, 0x12, 0x20,
|
||||
0x2e, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x65, 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x2e,
|
||||
0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65,
|
||||
0x1a, 0x20, 0x2e, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x65, 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67,
|
||||
0x65, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61,
|
||||
0x67, 0x65, 0x22, 0x00, 0x12, 0x59, 0x0a, 0x0d, 0x43, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x53,
|
||||
0x74, 0x72, 0x65, 0x61, 0x6d, 0x12, 0x20, 0x2e, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x65, 0x78,
|
||||
0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64,
|
||||
0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x20, 0x2e, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c,
|
||||
0x65, 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74,
|
||||
0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x28, 0x01, 0x30, 0x01, 0x42,
|
||||
0x08, 0x5a, 0x06, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f,
|
||||
0x33,
|
||||
0x65, 0x63, 0x74, 0x22, 0x0b, 0x0a, 0x09, 0x4b, 0x65, 0x65, 0x70, 0x41, 0x6c, 0x69, 0x76, 0x65,
|
||||
0x32, 0xb9, 0x01, 0x0a, 0x0e, 0x53, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x45, 0x78, 0x63, 0x68, 0x61,
|
||||
0x6e, 0x67, 0x65, 0x12, 0x4c, 0x0a, 0x04, 0x53, 0x65, 0x6e, 0x64, 0x12, 0x20, 0x2e, 0x73, 0x69,
|
||||
0x67, 0x6e, 0x61, 0x6c, 0x65, 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x2e, 0x45, 0x6e, 0x63,
|
||||
0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x20, 0x2e,
|
||||
0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x65, 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x2e, 0x45,
|
||||
0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22,
|
||||
0x00, 0x12, 0x59, 0x0a, 0x0d, 0x43, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x53, 0x74, 0x72, 0x65,
|
||||
0x61, 0x6d, 0x12, 0x20, 0x2e, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x65, 0x78, 0x63, 0x68, 0x61,
|
||||
0x6e, 0x67, 0x65, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73,
|
||||
0x73, 0x61, 0x67, 0x65, 0x1a, 0x20, 0x2e, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x65, 0x78, 0x63,
|
||||
0x68, 0x61, 0x6e, 0x67, 0x65, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d,
|
||||
0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x28, 0x01, 0x30, 0x01, 0x42, 0x08, 0x5a, 0x06,
|
||||
0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
|
||||
}
|
||||
|
||||
var (
|
||||
@@ -417,13 +455,14 @@ func file_signalexchange_proto_rawDescGZIP() []byte {
|
||||
}
|
||||
|
||||
var file_signalexchange_proto_enumTypes = make([]protoimpl.EnumInfo, 1)
|
||||
var file_signalexchange_proto_msgTypes = make([]protoimpl.MessageInfo, 4)
|
||||
var file_signalexchange_proto_msgTypes = make([]protoimpl.MessageInfo, 5)
|
||||
var file_signalexchange_proto_goTypes = []interface{}{
|
||||
(Body_Type)(0), // 0: signalexchange.Body.Type
|
||||
(*EncryptedMessage)(nil), // 1: signalexchange.EncryptedMessage
|
||||
(*Message)(nil), // 2: signalexchange.Message
|
||||
(*Body)(nil), // 3: signalexchange.Body
|
||||
(*Mode)(nil), // 4: signalexchange.Mode
|
||||
(*KeepAlive)(nil), // 5: signalexchange.KeepAlive
|
||||
}
|
||||
var file_signalexchange_proto_depIdxs = []int32{
|
||||
3, // 0: signalexchange.Message.body:type_name -> signalexchange.Body
|
||||
@@ -494,6 +533,18 @@ func file_signalexchange_proto_init() {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
file_signalexchange_proto_msgTypes[4].Exporter = func(v interface{}, i int) interface{} {
|
||||
switch v := v.(*KeepAlive); i {
|
||||
case 0:
|
||||
return &v.state
|
||||
case 1:
|
||||
return &v.sizeCache
|
||||
case 2:
|
||||
return &v.unknownFields
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
file_signalexchange_proto_msgTypes[3].OneofWrappers = []interface{}{}
|
||||
type x struct{}
|
||||
@@ -502,7 +553,7 @@ func file_signalexchange_proto_init() {
|
||||
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
|
||||
RawDescriptor: file_signalexchange_proto_rawDesc,
|
||||
NumEnums: 1,
|
||||
NumMessages: 4,
|
||||
NumMessages: 5,
|
||||
NumExtensions: 0,
|
||||
NumServices: 1,
|
||||
},
|
||||
|
||||
@@ -63,3 +63,5 @@ message Body {
|
||||
message Mode {
|
||||
optional bool direct = 1;
|
||||
}
|
||||
|
||||
message KeepAlive {}
|
||||
Reference in New Issue
Block a user