get all proxy endpoints when a proxy connects

This commit is contained in:
Alisdair MacLeod
2026-01-28 16:55:05 +00:00
parent 95bf97dc3c
commit c98dcf5ef9
3 changed files with 37 additions and 10 deletions

View File

@@ -6,6 +6,7 @@ import (
"sync"
"time"
"github.com/netbirdio/netbird/management/server/activity"
log "github.com/sirupsen/logrus"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/peer"
@@ -18,10 +19,12 @@ import (
)
type reverseProxyStore interface {
GetReverseProxies(ctx context.Context, lockStrength store.LockingStrength) ([]*reverseproxy.ReverseProxy, error)
GetAccountReverseProxies(ctx context.Context, lockStrength store.LockingStrength, accountID string) ([]*reverseproxy.ReverseProxy, error)
}
type keyStore interface {
GetGroupByName(ctx context.Context, groupName string, accountID string) (*types.Group, error)
CreateSetupKey(ctx context.Context, accountID string, keyName string, keyType types.SetupKeyType, expiresIn time.Duration, autoGroups []string, usageLimit int, userID string, ephemeral bool, allowExtraDNSLabels bool) (*types.SetupKey, error)
}
@@ -114,7 +117,7 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest
// sendSnapshot sends the initial snapshot of all reverse proxies to proxy
func (s *ProxyServiceServer) sendSnapshot(ctx context.Context, conn *proxyConnection) error {
reverseProxies, err := s.reverseProxyStore.GetAccountReverseProxies(ctx, store.LockingStrengthNone, "accountID") // TODO: check locking strength and accountID.
reverseProxies, err := s.reverseProxyStore.GetReverseProxies(ctx, store.LockingStrengthNone) // TODO: check locking strength.
if err != nil {
// TODO: something?
return fmt.Errorf("get account reverse proxies from store: %w", err)
@@ -160,20 +163,27 @@ func (s *ProxyServiceServer) sendSnapshot(ctx context.Context, conn *proxyConnec
})
}
group, err := s.keyStore.GetGroupByName(ctx, rp.Name, rp.AccountID)
if err != nil {
// TODO: log this?
continue
}
// TODO: should this even be here? We're running in a loop, and on each proxy, this will create a LOT of setup key entries that we currently have no way to remove.
key, err := s.keyStore.CreateSetupKey(ctx,
"accountID", // TODO: get an account ID from somewhere, likely needs to be passed in from higher up.
"keyname", // TODO: define a sensible key name to make cleanup easier.
types.SetupKeyOneOff, // TODO: is this correct? Might make cleanup simpler and we're going to generate a new key every time the proxy connects.
time.Minute, // TODO: only provide just enough time for the proxy to make the connection before this key becomes invalid. Should help with cleanup as well as protection against these leaking in transit.
[]string{"auto", "groups"}, // TODO: join a group for proxy to simplify adding rules to proxies?
1, // TODO: usage limit, how is this different from the OneOff key type?
"userID", // TODO: use a set userID for proxy peers?
false, // TODO: ephemeral peers are different...right?
false, // TODO: not sure but I think this should be false.
rp.AccountID,
rp.Name,
types.SetupKeyReusable,
0,
[]string{group.ID},
0,
activity.SystemInitiator,
true,
false,
)
if err != nil {
// TODO: how to handle this?
continue
}
if err := conn.stream.Send(&proto.GetMappingUpdateResponse{

View File

@@ -4675,6 +4675,22 @@ func (s *SqlStore) GetReverseProxyByDomain(ctx context.Context, accountID, domai
return proxy, nil
}
func (s *SqlStore) GetReverseProxies(ctx context.Context, lockStrength LockingStrength) ([]*reverseproxy.ReverseProxy, error) {
tx := s.db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var proxyList []*reverseproxy.ReverseProxy
result := tx.Find(&proxyList)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to get reverse proxy from the store: %s", result.Error)
return nil, status.Errorf(status.Internal, "failed to get reverse proxy from store")
}
return proxyList, nil
}
func (s *SqlStore) GetAccountReverseProxies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*reverseproxy.ReverseProxy, error) {
tx := s.db
if lockStrength != LockingStrengthNone {

View File

@@ -248,6 +248,7 @@ type Store interface {
DeleteReverseProxy(ctx context.Context, accountID, serviceID string) error
GetReverseProxyByID(ctx context.Context, lockStrength LockingStrength, accountID, serviceID string) (*reverseproxy.ReverseProxy, error)
GetReverseProxyByDomain(ctx context.Context, accountID, domain string) (*reverseproxy.ReverseProxy, error)
GetReverseProxies(ctx context.Context, lockStrength LockingStrength) ([]*reverseproxy.ReverseProxy, error)
GetAccountReverseProxies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*reverseproxy.ReverseProxy, error)
GetCustomDomain(ctx context.Context, accountID string, domainID string) (*domain.Domain, error)