get all proxy endpoints when a proxy connects

This commit is contained in:
Alisdair MacLeod
2026-01-28 16:55:05 +00:00
parent 95bf97dc3c
commit c98dcf5ef9
3 changed files with 37 additions and 10 deletions

View File

@@ -6,6 +6,7 @@ import (
"sync" "sync"
"time" "time"
"github.com/netbirdio/netbird/management/server/activity"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"google.golang.org/grpc/codes" "google.golang.org/grpc/codes"
"google.golang.org/grpc/peer" "google.golang.org/grpc/peer"
@@ -18,10 +19,12 @@ import (
) )
type reverseProxyStore interface { type reverseProxyStore interface {
GetReverseProxies(ctx context.Context, lockStrength store.LockingStrength) ([]*reverseproxy.ReverseProxy, error)
GetAccountReverseProxies(ctx context.Context, lockStrength store.LockingStrength, accountID string) ([]*reverseproxy.ReverseProxy, error) GetAccountReverseProxies(ctx context.Context, lockStrength store.LockingStrength, accountID string) ([]*reverseproxy.ReverseProxy, error)
} }
type keyStore interface { type keyStore interface {
GetGroupByName(ctx context.Context, groupName string, accountID string) (*types.Group, error)
CreateSetupKey(ctx context.Context, accountID string, keyName string, keyType types.SetupKeyType, expiresIn time.Duration, autoGroups []string, usageLimit int, userID string, ephemeral bool, allowExtraDNSLabels bool) (*types.SetupKey, error) CreateSetupKey(ctx context.Context, accountID string, keyName string, keyType types.SetupKeyType, expiresIn time.Duration, autoGroups []string, usageLimit int, userID string, ephemeral bool, allowExtraDNSLabels bool) (*types.SetupKey, error)
} }
@@ -114,7 +117,7 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest
// sendSnapshot sends the initial snapshot of all reverse proxies to proxy // sendSnapshot sends the initial snapshot of all reverse proxies to proxy
func (s *ProxyServiceServer) sendSnapshot(ctx context.Context, conn *proxyConnection) error { func (s *ProxyServiceServer) sendSnapshot(ctx context.Context, conn *proxyConnection) error {
reverseProxies, err := s.reverseProxyStore.GetAccountReverseProxies(ctx, store.LockingStrengthNone, "accountID") // TODO: check locking strength and accountID. reverseProxies, err := s.reverseProxyStore.GetReverseProxies(ctx, store.LockingStrengthNone) // TODO: check locking strength.
if err != nil { if err != nil {
// TODO: something? // TODO: something?
return fmt.Errorf("get account reverse proxies from store: %w", err) return fmt.Errorf("get account reverse proxies from store: %w", err)
@@ -160,20 +163,27 @@ func (s *ProxyServiceServer) sendSnapshot(ctx context.Context, conn *proxyConnec
}) })
} }
group, err := s.keyStore.GetGroupByName(ctx, rp.Name, rp.AccountID)
if err != nil {
// TODO: log this?
continue
}
// TODO: should this even be here? We're running in a loop, and on each proxy, this will create a LOT of setup key entries that we currently have no way to remove. // TODO: should this even be here? We're running in a loop, and on each proxy, this will create a LOT of setup key entries that we currently have no way to remove.
key, err := s.keyStore.CreateSetupKey(ctx, key, err := s.keyStore.CreateSetupKey(ctx,
"accountID", // TODO: get an account ID from somewhere, likely needs to be passed in from higher up. rp.AccountID,
"keyname", // TODO: define a sensible key name to make cleanup easier. rp.Name,
types.SetupKeyOneOff, // TODO: is this correct? Might make cleanup simpler and we're going to generate a new key every time the proxy connects. types.SetupKeyReusable,
time.Minute, // TODO: only provide just enough time for the proxy to make the connection before this key becomes invalid. Should help with cleanup as well as protection against these leaking in transit. 0,
[]string{"auto", "groups"}, // TODO: join a group for proxy to simplify adding rules to proxies? []string{group.ID},
1, // TODO: usage limit, how is this different from the OneOff key type? 0,
"userID", // TODO: use a set userID for proxy peers? activity.SystemInitiator,
false, // TODO: ephemeral peers are different...right? true,
false, // TODO: not sure but I think this should be false. false,
) )
if err != nil { if err != nil {
// TODO: how to handle this? // TODO: how to handle this?
continue
} }
if err := conn.stream.Send(&proto.GetMappingUpdateResponse{ if err := conn.stream.Send(&proto.GetMappingUpdateResponse{

View File

@@ -4675,6 +4675,22 @@ func (s *SqlStore) GetReverseProxyByDomain(ctx context.Context, accountID, domai
return proxy, nil return proxy, nil
} }
func (s *SqlStore) GetReverseProxies(ctx context.Context, lockStrength LockingStrength) ([]*reverseproxy.ReverseProxy, error) {
tx := s.db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var proxyList []*reverseproxy.ReverseProxy
result := tx.Find(&proxyList)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to get reverse proxy from the store: %s", result.Error)
return nil, status.Errorf(status.Internal, "failed to get reverse proxy from store")
}
return proxyList, nil
}
func (s *SqlStore) GetAccountReverseProxies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*reverseproxy.ReverseProxy, error) { func (s *SqlStore) GetAccountReverseProxies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*reverseproxy.ReverseProxy, error) {
tx := s.db tx := s.db
if lockStrength != LockingStrengthNone { if lockStrength != LockingStrengthNone {

View File

@@ -248,6 +248,7 @@ type Store interface {
DeleteReverseProxy(ctx context.Context, accountID, serviceID string) error DeleteReverseProxy(ctx context.Context, accountID, serviceID string) error
GetReverseProxyByID(ctx context.Context, lockStrength LockingStrength, accountID, serviceID string) (*reverseproxy.ReverseProxy, error) GetReverseProxyByID(ctx context.Context, lockStrength LockingStrength, accountID, serviceID string) (*reverseproxy.ReverseProxy, error)
GetReverseProxyByDomain(ctx context.Context, accountID, domain string) (*reverseproxy.ReverseProxy, error) GetReverseProxyByDomain(ctx context.Context, accountID, domain string) (*reverseproxy.ReverseProxy, error)
GetReverseProxies(ctx context.Context, lockStrength LockingStrength) ([]*reverseproxy.ReverseProxy, error)
GetAccountReverseProxies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*reverseproxy.ReverseProxy, error) GetAccountReverseProxies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*reverseproxy.ReverseProxy, error)
GetCustomDomain(ctx context.Context, accountID string, domainID string) (*domain.Domain, error) GetCustomDomain(ctx context.Context, accountID string, domainID string) (*domain.Domain, error)