Compare commits

...

12 Commits

Author SHA1 Message Date
crn4
e62521132c Merge remote-tracking branch 'origin/main' into feat/byod-proxy
# Conflicts:
#	management/internals/modules/reverseproxy/domain/manager/manager.go
#	management/internals/modules/reverseproxy/proxy/manager.go
#	management/internals/modules/reverseproxy/proxy/manager/manager.go
#	management/internals/modules/reverseproxy/proxy/manager_mock.go
#	management/internals/shared/grpc/proxy.go
#	management/server/store/sql_store.go
#	proxy/management_integration_test.go
2026-04-13 17:02:24 +03:00
Zoltan Papp
13539543af [client] Fix/grpc retry (#5750)
* [client] Fix flow client Receive retry loop not stopping after Close

Use backoff.Permanent for canceled gRPC errors so Receive returns
immediately instead of retrying until context deadline when the
connection is already closed. Add TestNewClient_PermanentClose to
verify the behavior.

The connectivity.Shutdown check was meaningless because when the connection is
shut down, c.realClient.Events(ctx, grpc.WaitForReady(true)) on the nex line
already fails with codes.Canceled — which is now handled as a permanent error.
The explicit state check was just duplicating what gRPC already reports
through its normal error path.

* [client] remove WaitForReady from stream open call

grpc.WaitForReady(true) parks the RPC call internally until the
connection reaches READY, only unblocking on ctx cancellation.
This means the external backoff.Retry loop in Receive() never gets
control back during a connection outage — it cannot tick, log, or
apply its retry intervals while WaitForReady is blocking.

Removing it restores fail-fast behaviour: Events() returns immediately
with codes.Unavailable when the connection is not ready, which is
exactly what the backoff loop expects. The backoff becomes the single
authority over retry timing and cadence, as originally intended.

* [client] Add connection recreation and improve flow client error handling

Store gRPC dial options on the client to enable connection recreation
on Internal errors (RST_STREAM/PROTOCOL_ERROR). Treat Unauthenticated,
PermissionDenied, and Unimplemented as permanent failures. Unify mutex
usage and add reconnection logging for better observability.

* [client] Remove Unauthenticated, PermissionDenied, and Unimplemented from permanent error handling

* [client] Fix error handling in Receive to properly re-establish stream and improve reconnection messaging

* Fix test

* [client] Add graceful shutdown handling and test for concurrent Close during Receive

Prevent reconnection attempts after client closure by tracking a `closed` flag. Use `backoff.Permanent` for errors caused by operations on a closed client. Add a test to ensure `Close` does not block when `Receive` is actively running.

* [client] Fix connection swap to properly close old gRPC connection

Close the old `gRPC.ClientConn` after successfully swapping to a new connection during reconnection.

* [client] Reset backoff

* [client] Ensure stream closure on error during initialization

* [client] Add test for handling server-side stream closure and reconnection

Introduce `TestReceive_ServerClosesStream` to verify the client's ability to recover and process acknowledgments after the server closes the stream. Enhance test server with a controlled stream closure mechanism.

* [client] Add protocol error simulation and enhance reconnection test

Introduce `connTrackListener` to simulate HTTP/2 RST_STREAM with PROTOCOL_ERROR for testing. Refactor and rename `TestReceive_ServerClosesStream` to `TestReceive_ProtocolErrorStreamReconnect` to verify client recovery on protocol errors.

* [client] Update Close error message in test for clarity

* [client] Fine-tune the tests

* [client] Adjust connection tracking in reconnection test

* [client] Wait for Events handler to exit in RST_STREAM reconnection test

Ensure the old `Events` handler exits fully before proceeding in the reconnection test to avoid dropped acknowledgments on a broken stream. Add a `handlerDone` channel to synchronize handler exits.

* [client] Prevent panic on nil connection during Close

* [client] Refactor connection handling to use explicit target tracking

Introduce `target` field to store the gRPC connection target directly, simplifying reconnections and ensuring consistent connection reuse logic.

* [client] Rename `isCancellation` to `isContextDone` and extend handling for `DeadlineExceeded`

Refactor error handling to include `DeadlineExceeded` scenarios alongside `Canceled`. Update related condition checks for consistency.

* [client] Add connection generation tracking to prevent stale reconnections

Introduce `connGen` to track connection generations and ensure that stale `recreateConnection` calls do not override newer connections. Update stream establishment and reconnection logic to incorporate generation validation.

* [client] Add backoff reset condition to prevent short-lived retry cycles

Refine backoff reset logic to ensure it only occurs for sufficiently long-lived stream connections, avoiding interference with `MaxElapsedTime`.

* [client] Introduce `minHealthyDuration` to refine backoff reset logic

Add `minHealthyDuration` constant to ensure stream retries only reset the backoff timer if the stream survives beyond a minimum duration. Prevents unhealthy, short-lived streams from interfering with `MaxElapsedTime`.

* [client] IPv6 friendly connection

parsedURL.Hostname() strips IPv6 brackets. For http://[::1]:443, this turns it into ::1:443, which is not a valid host:port target for gRPC. Additionally, fmt.Sprintf("%s:%s", hostname, port) produces a trailing colon when the URL has no explicit port—http://example.com becomes example.com:. Both cases break the initial dial and reconnect paths. Use parsedURL.Host directly instead.

* [client] Add `handlerStarted` channel to synchronize stream establishment in tests

Introduce `handlerStarted` channel in the test server to signal when the server-side handler begins, ensuring robust synchronization between client and server during stream establishment. Update relevant test cases to wait for this signal before proceeding.

* [client] Replace `receivedAcks` map with atomic counter and improve stream establishment sync in tests

Refactor acknowledgment tracking in tests to use an `atomic.Int32` counter instead of a map. Replace fixed sleep with robust synchronization by waiting on `handlerStarted` signal for stream establishment.

* [client] Extract `handleReceiveError` to simplify receive logic

Refactor error handling in `receive` to a dedicated `handleReceiveError` method. Streamlines the main logic and isolates error recovery, including backoff reset and connection recreation.

* [client] recreate gRPC ClientConn on every retry to prevent dual backoff

The flow client had two competing retry loops: our custom exponential
backoff and gRPC's internal subchannel reconnection. When establishStream
failed, the same ClientConn was reused, allowing gRPC's internal backoff
state to accumulate and control dial timing independently.

Changes:
- Consolidate error handling into handleRetryableError, which now
 handles context cancellation, permanent errors, backoff reset,
 and connection recreation in a single path
- Call recreateConnection on every retryable error so each retry
 gets a fresh ClientConn with no internal backoff state
- Remove connGen tracking since Receive is sequential and protected
 by a new receiving guard against concurrent calls
- Reduce RandomizationFactor from 1 to 0.5 to avoid near-zero
 backoff intervals
2026-04-13 10:42:24 +02:00
Zoltan Papp
7483fec048 Fix Android internet blackhole caused by stale route re-injection on TUN rebuild (#5865)
extraInitialRoutes() was meant to preserve only the fake IP route
(240.0.0.0/8) across TUN rebuilds, but it re-injected any initial
route missing from the current set. When the management server
advertised exit node routes (0.0.0.0/0) that were later filtered
by the route selector, extraInitialRoutes() re-added them, causing
the Android VPN to capture all traffic with no peer to handle it.

Store the fake IP route explicitly and append only that in notify(),
removing the overly broad initial route diffing.
2026-04-13 09:38:38 +02:00
crn4
de3cb06067 added proxy id to cluster api response 2026-03-31 00:28:19 +02:00
crn4
4fdc39c8f8 review comments 2026-03-24 15:37:31 +01:00
crn4
94149a9441 linter 2026-03-24 14:58:03 +01:00
crn4
38fd73fad6 merge main 2026-03-24 14:50:03 +01:00
crn4
9dd76b5a07 merge main 2026-03-24 14:20:03 +01:00
crn4
0b5380a7dc review comments 2026-03-24 13:32:38 +01:00
crn4
177171e437 change api 2026-03-19 21:49:04 +01:00
crn4
da57b0f276 rename byod to byop 2026-03-19 16:11:57 +01:00
crn4
26ba03f08e [proxy] feature: bring your own proxy 2026-03-19 01:02:46 +01:00
38 changed files with 2798 additions and 172 deletions

View File

@@ -168,6 +168,7 @@ func (m *DefaultManager) setupAndroidRoutes(config ManagerConfig) {
NetworkType: route.IPv4Network,
}
cr = append(cr, fakeIPRoute)
m.notifier.SetFakeIPRoute(fakeIPRoute)
}
m.notifier.SetInitialClientRoutes(cr, routesForComparison)

View File

@@ -16,6 +16,7 @@ import (
type Notifier struct {
initialRoutes []*route.Route
currentRoutes []*route.Route
fakeIPRoute *route.Route
listener listener.NetworkChangeListener
listenerMux sync.Mutex
@@ -31,13 +32,17 @@ func (n *Notifier) SetListener(listener listener.NetworkChangeListener) {
n.listener = listener
}
// SetInitialClientRoutes stores the full initial route set (including fake IP blocks)
// and a separate comparison set (without fake IP blocks) for diff detection.
// SetInitialClientRoutes stores the initial route sets for TUN configuration.
func (n *Notifier) SetInitialClientRoutes(initialRoutes []*route.Route, routesForComparison []*route.Route) {
n.initialRoutes = filterStatic(initialRoutes)
n.currentRoutes = filterStatic(routesForComparison)
}
// SetFakeIPRoute stores the fake IP route to be included in every TUN rebuild.
func (n *Notifier) SetFakeIPRoute(r *route.Route) {
n.fakeIPRoute = r
}
func (n *Notifier) OnNewRoutes(idMap route.HAMap) {
var newRoutes []*route.Route
for _, routes := range idMap {
@@ -69,7 +74,9 @@ func (n *Notifier) notify() {
}
allRoutes := slices.Clone(n.currentRoutes)
allRoutes = append(allRoutes, n.extraInitialRoutes()...)
if n.fakeIPRoute != nil {
allRoutes = append(allRoutes, n.fakeIPRoute)
}
routeStrings := n.routesToStrings(allRoutes)
sort.Strings(routeStrings)
@@ -78,23 +85,6 @@ func (n *Notifier) notify() {
}(n.listener)
}
// extraInitialRoutes returns initialRoutes whose network prefix is absent
// from currentRoutes (e.g. the fake IP block added at setup time).
func (n *Notifier) extraInitialRoutes() []*route.Route {
currentNets := make(map[netip.Prefix]struct{}, len(n.currentRoutes))
for _, r := range n.currentRoutes {
currentNets[r.Network] = struct{}{}
}
var extra []*route.Route
for _, r := range n.initialRoutes {
if _, ok := currentNets[r.Network]; !ok {
extra = append(extra, r)
}
}
return extra
}
func filterStatic(routes []*route.Route) []*route.Route {
out := make([]*route.Route, 0, len(routes))
for _, r := range routes {

View File

@@ -34,6 +34,10 @@ func (n *Notifier) SetInitialClientRoutes([]*route.Route, []*route.Route) {
// iOS doesn't care about initial routes
}
func (n *Notifier) SetFakeIPRoute(*route.Route) {
// Not used on iOS
}
func (n *Notifier) OnNewRoutes(route.HAMap) {
// Not used on iOS
}

View File

@@ -23,6 +23,10 @@ func (n *Notifier) SetInitialClientRoutes([]*route.Route, []*route.Route) {
// Not used on non-mobile platforms
}
func (n *Notifier) SetFakeIPRoute(*route.Route) {
// Not used on non-mobile platforms
}
func (n *Notifier) OnNewRoutes(idMap route.HAMap) {
// Not used on non-mobile platforms
}

View File

@@ -14,7 +14,6 @@ import (
log "github.com/sirupsen/logrus"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/connectivity"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/keepalive"
@@ -26,11 +25,22 @@ import (
"github.com/netbirdio/netbird/util/wsproxy"
)
var ErrClientClosed = errors.New("client is closed")
// minHealthyDuration is the minimum time a stream must survive before a failure
// resets the backoff timer. Streams that fail faster are considered unhealthy and
// should not reset backoff, so that MaxElapsedTime can eventually stop retries.
const minHealthyDuration = 5 * time.Second
type GRPCClient struct {
realClient proto.FlowServiceClient
clientConn *grpc.ClientConn
stream proto.FlowService_EventsClient
streamMu sync.Mutex
target string
opts []grpc.DialOption
closed bool // prevent creating conn in the middle of the Close
receiving bool // prevent concurrent Receive calls
mu sync.Mutex // protects clientConn, realClient, stream, closed, and receiving
}
func NewClient(addr, payload, signature string, interval time.Duration) (*GRPCClient, error) {
@@ -65,7 +75,8 @@ func NewClient(addr, payload, signature string, interval time.Duration) (*GRPCCl
grpc.WithDefaultServiceConfig(`{"healthCheckConfig": {"serviceName": ""}}`),
)
conn, err := grpc.NewClient(fmt.Sprintf("%s:%s", parsedURL.Hostname(), parsedURL.Port()), opts...)
target := parsedURL.Host
conn, err := grpc.NewClient(target, opts...)
if err != nil {
return nil, fmt.Errorf("creating new grpc client: %w", err)
}
@@ -73,30 +84,73 @@ func NewClient(addr, payload, signature string, interval time.Duration) (*GRPCCl
return &GRPCClient{
realClient: proto.NewFlowServiceClient(conn),
clientConn: conn,
target: target,
opts: opts,
}, nil
}
func (c *GRPCClient) Close() error {
c.streamMu.Lock()
defer c.streamMu.Unlock()
c.mu.Lock()
c.closed = true
c.stream = nil
if err := c.clientConn.Close(); err != nil && !errors.Is(err, context.Canceled) {
conn := c.clientConn
c.clientConn = nil
c.mu.Unlock()
if conn == nil {
return nil
}
if err := conn.Close(); err != nil && !errors.Is(err, context.Canceled) {
return fmt.Errorf("close client connection: %w", err)
}
return nil
}
func (c *GRPCClient) Send(event *proto.FlowEvent) error {
c.mu.Lock()
stream := c.stream
c.mu.Unlock()
if stream == nil {
return errors.New("stream not initialized")
}
if err := stream.Send(event); err != nil {
return fmt.Errorf("send flow event: %w", err)
}
return nil
}
func (c *GRPCClient) Receive(ctx context.Context, interval time.Duration, msgHandler func(msg *proto.FlowEventAck) error) error {
c.mu.Lock()
if c.receiving {
c.mu.Unlock()
return errors.New("concurrent Receive calls are not supported")
}
c.receiving = true
c.mu.Unlock()
defer func() {
c.mu.Lock()
c.receiving = false
c.mu.Unlock()
}()
backOff := defaultBackoff(ctx, interval)
operation := func() error {
if err := c.establishStreamAndReceive(ctx, msgHandler); err != nil {
if s, ok := status.FromError(err); ok && s.Code() == codes.Canceled {
return fmt.Errorf("receive: %w: %w", err, context.Canceled)
}
stream, err := c.establishStream(ctx)
if err != nil {
log.Errorf("failed to establish flow stream, retrying: %v", err)
return c.handleRetryableError(err, time.Time{}, backOff)
}
streamStart := time.Now()
if err := c.receive(stream, msgHandler); err != nil {
log.Errorf("receive failed: %v", err)
return fmt.Errorf("receive: %w", err)
return c.handleRetryableError(err, streamStart, backOff)
}
return nil
}
@@ -108,37 +162,106 @@ func (c *GRPCClient) Receive(ctx context.Context, interval time.Duration, msgHan
return nil
}
func (c *GRPCClient) establishStreamAndReceive(ctx context.Context, msgHandler func(msg *proto.FlowEventAck) error) error {
if c.clientConn.GetState() == connectivity.Shutdown {
return errors.New("connection to flow receiver has been shut down")
// handleRetryableError resets the backoff timer if the stream was healthy long
// enough and recreates the underlying ClientConn so that gRPC's internal
// subchannel backoff does not accumulate and compete with our own retry timer.
// A zero streamStart means the stream was never established.
func (c *GRPCClient) handleRetryableError(err error, streamStart time.Time, backOff backoff.BackOff) error {
if isContextDone(err) {
return backoff.Permanent(err)
}
stream, err := c.realClient.Events(ctx, grpc.WaitForReady(true))
if err != nil {
return fmt.Errorf("create event stream: %w", err)
var permErr *backoff.PermanentError
if errors.As(err, &permErr) {
return err
}
err = stream.Send(&proto.FlowEvent{IsInitiator: true})
// Reset the backoff so the next retry starts with a short delay instead of
// continuing the already-elapsed timer. Only do this if the stream was healthy
// long enough; short-lived connect/drop cycles must not defeat MaxElapsedTime.
if !streamStart.IsZero() && time.Since(streamStart) >= minHealthyDuration {
backOff.Reset()
}
if recreateErr := c.recreateConnection(); recreateErr != nil {
log.Errorf("recreate connection: %v", recreateErr)
return recreateErr
}
log.Infof("connection recreated, retrying stream")
return fmt.Errorf("retrying after error: %w", err)
}
func (c *GRPCClient) recreateConnection() error {
c.mu.Lock()
if c.closed {
c.mu.Unlock()
return backoff.Permanent(ErrClientClosed)
}
conn, err := grpc.NewClient(c.target, c.opts...)
if err != nil {
log.Infof("failed to send initiator message to flow receiver but will attempt to continue. Error: %s", err)
c.mu.Unlock()
return fmt.Errorf("create new connection: %w", err)
}
old := c.clientConn
c.clientConn = conn
c.realClient = proto.NewFlowServiceClient(conn)
c.stream = nil
c.mu.Unlock()
_ = old.Close()
return nil
}
func (c *GRPCClient) establishStream(ctx context.Context) (proto.FlowService_EventsClient, error) {
c.mu.Lock()
if c.closed {
c.mu.Unlock()
return nil, backoff.Permanent(ErrClientClosed)
}
cl := c.realClient
c.mu.Unlock()
// open stream outside the lock — blocking operation
stream, err := cl.Events(ctx)
if err != nil {
return nil, fmt.Errorf("create event stream: %w", err)
}
streamReady := false
defer func() {
if !streamReady {
_ = stream.CloseSend()
}
}()
if err = stream.Send(&proto.FlowEvent{IsInitiator: true}); err != nil {
return nil, fmt.Errorf("send initiator: %w", err)
}
if err = checkHeader(stream); err != nil {
return fmt.Errorf("check header: %w", err)
return nil, fmt.Errorf("check header: %w", err)
}
c.streamMu.Lock()
c.mu.Lock()
if c.closed {
c.mu.Unlock()
return nil, backoff.Permanent(ErrClientClosed)
}
c.stream = stream
c.streamMu.Unlock()
c.mu.Unlock()
streamReady = true
return c.receive(stream, msgHandler)
return stream, nil
}
func (c *GRPCClient) receive(stream proto.FlowService_EventsClient, msgHandler func(msg *proto.FlowEventAck) error) error {
for {
msg, err := stream.Recv()
if err != nil {
return fmt.Errorf("receive from stream: %w", err)
return err
}
if msg.IsInitiator {
@@ -169,7 +292,7 @@ func checkHeader(stream proto.FlowService_EventsClient) error {
func defaultBackoff(ctx context.Context, interval time.Duration) backoff.BackOff {
return backoff.WithContext(&backoff.ExponentialBackOff{
InitialInterval: 800 * time.Millisecond,
RandomizationFactor: 1,
RandomizationFactor: 0.5,
Multiplier: 1.7,
MaxInterval: interval / 2,
MaxElapsedTime: 3 * 30 * 24 * time.Hour, // 3 months
@@ -178,18 +301,12 @@ func defaultBackoff(ctx context.Context, interval time.Duration) backoff.BackOff
}, ctx)
}
func (c *GRPCClient) Send(event *proto.FlowEvent) error {
c.streamMu.Lock()
stream := c.stream
c.streamMu.Unlock()
if stream == nil {
return errors.New("stream not initialized")
func isContextDone(err error) bool {
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
return true
}
if err := stream.Send(event); err != nil {
return fmt.Errorf("send flow event: %w", err)
if s, ok := status.FromError(err); ok {
return s.Code() == codes.Canceled || s.Code() == codes.DeadlineExceeded
}
return nil
return false
}

View File

@@ -2,8 +2,11 @@ package client_test
import (
"context"
"encoding/binary"
"errors"
"net"
"sync"
"sync/atomic"
"testing"
"time"
@@ -11,6 +14,8 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
flow "github.com/netbirdio/netbird/flow/client"
"github.com/netbirdio/netbird/flow/proto"
@@ -18,21 +23,89 @@ import (
type testServer struct {
proto.UnimplementedFlowServiceServer
events chan *proto.FlowEvent
acks chan *proto.FlowEventAck
grpcSrv *grpc.Server
addr string
events chan *proto.FlowEvent
acks chan *proto.FlowEventAck
grpcSrv *grpc.Server
addr string
listener *connTrackListener
closeStream chan struct{} // signal server to close the stream
handlerDone chan struct{} // signaled each time Events() exits
handlerStarted chan struct{} // signaled each time Events() begins
}
// connTrackListener wraps a net.Listener to track accepted connections
// so tests can forcefully close them to simulate PROTOCOL_ERROR/RST_STREAM.
type connTrackListener struct {
net.Listener
mu sync.Mutex
conns []net.Conn
}
func (l *connTrackListener) Accept() (net.Conn, error) {
c, err := l.Listener.Accept()
if err != nil {
return nil, err
}
l.mu.Lock()
l.conns = append(l.conns, c)
l.mu.Unlock()
return c, nil
}
// sendRSTStream writes a raw HTTP/2 RST_STREAM frame with PROTOCOL_ERROR
// (error code 0x1) on every tracked connection. This produces the exact error:
//
// rpc error: code = Internal desc = stream terminated by RST_STREAM with error code: PROTOCOL_ERROR
//
// HTTP/2 RST_STREAM frame format (9-byte header + 4-byte payload):
//
// Length (3 bytes): 0x000004
// Type (1 byte): 0x03 (RST_STREAM)
// Flags (1 byte): 0x00
// Stream ID (4 bytes): target stream (must have bit 31 clear)
// Error Code (4 bytes): 0x00000001 (PROTOCOL_ERROR)
func (l *connTrackListener) connCount() int {
l.mu.Lock()
defer l.mu.Unlock()
return len(l.conns)
}
func (l *connTrackListener) sendRSTStream(streamID uint32) {
l.mu.Lock()
defer l.mu.Unlock()
frame := make([]byte, 13) // 9-byte header + 4-byte payload
// Length = 4 (3 bytes, big-endian)
frame[0], frame[1], frame[2] = 0, 0, 4
// Type = RST_STREAM (0x03)
frame[3] = 0x03
// Flags = 0
frame[4] = 0x00
// Stream ID (4 bytes, big-endian, bit 31 reserved = 0)
binary.BigEndian.PutUint32(frame[5:9], streamID)
// Error Code = PROTOCOL_ERROR (0x1)
binary.BigEndian.PutUint32(frame[9:13], 0x1)
for _, c := range l.conns {
_, _ = c.Write(frame)
}
}
func newTestServer(t *testing.T) *testServer {
listener, err := net.Listen("tcp", "127.0.0.1:0")
rawListener, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
listener := &connTrackListener{Listener: rawListener}
s := &testServer{
events: make(chan *proto.FlowEvent, 100),
acks: make(chan *proto.FlowEventAck, 100),
grpcSrv: grpc.NewServer(),
addr: listener.Addr().String(),
events: make(chan *proto.FlowEvent, 100),
acks: make(chan *proto.FlowEventAck, 100),
grpcSrv: grpc.NewServer(),
addr: rawListener.Addr().String(),
listener: listener,
closeStream: make(chan struct{}, 1),
handlerDone: make(chan struct{}, 10),
handlerStarted: make(chan struct{}, 10),
}
proto.RegisterFlowServiceServer(s.grpcSrv, s)
@@ -51,11 +124,23 @@ func newTestServer(t *testing.T) *testServer {
}
func (s *testServer) Events(stream proto.FlowService_EventsServer) error {
defer func() {
select {
case s.handlerDone <- struct{}{}:
default:
}
}()
err := stream.Send(&proto.FlowEventAck{IsInitiator: true})
if err != nil {
return err
}
select {
case s.handlerStarted <- struct{}{}:
default:
}
ctx, cancel := context.WithCancel(stream.Context())
defer cancel()
@@ -91,6 +176,8 @@ func (s *testServer) Events(stream proto.FlowService_EventsServer) error {
if err := stream.Send(ack); err != nil {
return err
}
case <-s.closeStream:
return status.Errorf(codes.Internal, "server closing stream")
case <-ctx.Done():
return ctx.Err()
}
@@ -110,16 +197,13 @@ func TestReceive(t *testing.T) {
assert.NoError(t, err, "failed to close flow")
})
receivedAcks := make(map[string]bool)
var ackCount atomic.Int32
receiveDone := make(chan struct{})
go func() {
err := client.Receive(ctx, 1*time.Second, func(msg *proto.FlowEventAck) error {
if !msg.IsInitiator && len(msg.EventId) > 0 {
id := string(msg.EventId)
receivedAcks[id] = true
if len(receivedAcks) >= 3 {
if ackCount.Add(1) >= 3 {
close(receiveDone)
}
}
@@ -130,7 +214,11 @@ func TestReceive(t *testing.T) {
}
}()
time.Sleep(500 * time.Millisecond)
select {
case <-server.handlerStarted:
case <-time.After(3 * time.Second):
t.Fatal("timeout waiting for stream to be established")
}
for i := 0; i < 3; i++ {
eventID := uuid.New().String()
@@ -153,7 +241,7 @@ func TestReceive(t *testing.T) {
t.Fatal("timeout waiting for acks to be processed")
}
assert.Equal(t, 3, len(receivedAcks))
assert.Equal(t, int32(3), ackCount.Load())
}
func TestReceive_ContextCancellation(t *testing.T) {
@@ -254,3 +342,195 @@ func TestSend(t *testing.T) {
t.Fatal("timeout waiting for ack to be received by flow")
}
}
func TestNewClient_PermanentClose(t *testing.T) {
server := newTestServer(t)
client, err := flow.NewClient("http://"+server.addr, "test-payload", "test-signature", 1*time.Second)
require.NoError(t, err)
err = client.Close()
require.NoError(t, err)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
t.Cleanup(cancel)
done := make(chan error, 1)
go func() {
done <- client.Receive(ctx, 1*time.Second, func(msg *proto.FlowEventAck) error {
return nil
})
}()
select {
case err := <-done:
require.ErrorIs(t, err, flow.ErrClientClosed)
case <-time.After(2 * time.Second):
t.Fatal("Receive did not return after Close — stuck in retry loop")
}
}
func TestNewClient_CloseVerify(t *testing.T) {
server := newTestServer(t)
client, err := flow.NewClient("http://"+server.addr, "test-payload", "test-signature", 1*time.Second)
require.NoError(t, err)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
t.Cleanup(cancel)
done := make(chan error, 1)
go func() {
done <- client.Receive(ctx, 1*time.Second, func(msg *proto.FlowEventAck) error {
return nil
})
}()
closeDone := make(chan struct{}, 1)
go func() {
_ = client.Close()
closeDone <- struct{}{}
}()
select {
case err := <-done:
require.Error(t, err)
case <-time.After(2 * time.Second):
t.Fatal("Receive did not return after Close — stuck in retry loop")
}
select {
case <-closeDone:
return
case <-time.After(2 * time.Second):
t.Fatal("Close did not return — blocked in retry loop")
}
}
func TestClose_WhileReceiving(t *testing.T) {
server := newTestServer(t)
client, err := flow.NewClient("http://"+server.addr, "test-payload", "test-signature", 1*time.Second)
require.NoError(t, err)
ctx := context.Background() // no timeout — intentional
receiveDone := make(chan struct{})
go func() {
_ = client.Receive(ctx, 1*time.Second, func(msg *proto.FlowEventAck) error {
return nil
})
close(receiveDone)
}()
// Wait for the server-side handler to confirm the stream is established.
select {
case <-server.handlerStarted:
case <-time.After(3 * time.Second):
t.Fatal("timeout waiting for stream to be established")
}
closeDone := make(chan struct{})
go func() {
_ = client.Close()
close(closeDone)
}()
select {
case <-closeDone:
// Close returned — good
case <-time.After(2 * time.Second):
t.Fatal("Close blocked forever — Receive stuck in retry loop")
}
select {
case <-receiveDone:
case <-time.After(2 * time.Second):
t.Fatal("Receive did not exit after Close")
}
}
func TestReceive_ProtocolErrorStreamReconnect(t *testing.T) {
server := newTestServer(t)
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
t.Cleanup(cancel)
client, err := flow.NewClient("http://"+server.addr, "test-payload", "test-signature", 1*time.Second)
require.NoError(t, err)
t.Cleanup(func() {
err := client.Close()
assert.NoError(t, err, "failed to close flow")
})
// Track acks received before and after server-side stream close
var ackCount atomic.Int32
receivedFirst := make(chan struct{})
receivedAfterReconnect := make(chan struct{})
go func() {
err := client.Receive(ctx, 1*time.Second, func(msg *proto.FlowEventAck) error {
if msg.IsInitiator || len(msg.EventId) == 0 {
return nil
}
n := ackCount.Add(1)
if n == 1 {
close(receivedFirst)
}
if n == 2 {
close(receivedAfterReconnect)
}
return nil
})
if err != nil && !errors.Is(err, context.Canceled) {
t.Logf("receive error: %v", err)
}
}()
// Wait for stream to be established, then send first ack
select {
case <-server.handlerStarted:
case <-time.After(3 * time.Second):
t.Fatal("timeout waiting for stream to be established")
}
server.acks <- &proto.FlowEventAck{EventId: []byte("before-close")}
select {
case <-receivedFirst:
case <-time.After(3 * time.Second):
t.Fatal("timeout waiting for first ack")
}
// Snapshot connection count before injecting the fault.
connsBefore := server.listener.connCount()
// Send a raw HTTP/2 RST_STREAM frame with PROTOCOL_ERROR on the TCP connection.
// gRPC multiplexes streams on stream IDs 1, 3, 5, ... (odd, client-initiated).
// Stream ID 1 is the client's first stream (our Events bidi stream).
// This produces the exact error the client sees in production:
// "stream terminated by RST_STREAM with error code: PROTOCOL_ERROR"
server.listener.sendRSTStream(1)
// Wait for the old Events() handler to fully exit so it can no longer
// drain s.acks and drop our injected ack on a broken stream.
select {
case <-server.handlerDone:
case <-time.After(5 * time.Second):
t.Fatal("old Events() handler did not exit after RST_STREAM")
}
require.Eventually(t, func() bool {
return server.listener.connCount() > connsBefore
}, 5*time.Second, 50*time.Millisecond, "client did not open a new TCP connection after RST_STREAM")
server.acks <- &proto.FlowEventAck{EventId: []byte("after-close")}
select {
case <-receivedAfterReconnect:
// Client successfully reconnected and received ack after server-side stream close
case <-time.After(5 * time.Second):
t.Fatal("timeout waiting for ack after server-side stream close — client did not reconnect")
}
assert.GreaterOrEqual(t, int(ackCount.Load()), 2, "should have received acks before and after stream close")
assert.GreaterOrEqual(t, server.listener.connCount(), 2, "client should have created at least 2 TCP connections (original + reconnect)")
}

View File

@@ -31,6 +31,7 @@ type store interface {
type proxyManager interface {
GetActiveClusterAddresses(ctx context.Context) ([]string, error)
GetActiveClusterAddressesForAccount(ctx context.Context, accountID string) ([]string, error)
ClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool
ClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool
}
@@ -70,8 +71,8 @@ func (m Manager) GetDomains(ctx context.Context, accountID, userID string) ([]*d
var ret []*domain.Domain
// Add connected proxy clusters as free domains.
// The cluster address itself is the free domain base (e.g., "eu.proxy.netbird.io").
allowList, err := m.proxyManager.GetActiveClusterAddresses(ctx)
// For BYOP accounts, only their own cluster is returned; otherwise shared clusters.
allowList, err := m.getClusterAllowList(ctx, accountID)
if err != nil {
log.WithContext(ctx).Errorf("failed to get active proxy cluster addresses: %v", err)
return nil, err
@@ -123,8 +124,8 @@ func (m Manager) CreateDomain(ctx context.Context, accountID, userID, domainName
return nil, status.NewPermissionDeniedError()
}
// Verify the target cluster is in the available clusters
allowList, err := m.proxyManager.GetActiveClusterAddresses(ctx)
// Verify the target cluster is in the available clusters for this account
allowList, err := m.getClusterAllowList(ctx, accountID)
if err != nil {
return nil, fmt.Errorf("failed to get active proxy cluster addresses: %w", err)
}
@@ -270,7 +271,7 @@ func (m Manager) GetClusterDomains() []string {
// For free domains (those ending with a known cluster suffix), the cluster is extracted from the domain.
// For custom domains, the cluster is determined by checking the registered custom domain's target cluster.
func (m Manager) DeriveClusterFromDomain(ctx context.Context, accountID, domain string) (string, error) {
allowList, err := m.proxyManager.GetActiveClusterAddresses(ctx)
allowList, err := m.getClusterAllowList(ctx, accountID)
if err != nil {
return "", fmt.Errorf("failed to get active proxy cluster addresses: %w", err)
}
@@ -295,6 +296,17 @@ func (m Manager) DeriveClusterFromDomain(ctx context.Context, accountID, domain
return "", fmt.Errorf("domain %s does not match any available proxy cluster", domain)
}
func (m Manager) getClusterAllowList(ctx context.Context, accountID string) ([]string, error) {
byopAddresses, err := m.proxyManager.GetActiveClusterAddressesForAccount(ctx, accountID)
if err != nil {
return nil, fmt.Errorf("get BYOP cluster addresses: %w", err)
}
if len(byopAddresses) > 0 {
return byopAddresses, nil
}
return m.proxyManager.GetActiveClusterAddresses(ctx)
}
func extractClusterFromCustomDomains(serviceDomain string, customDomains []*domain.Domain) (string, bool) {
bestCluster := ""
bestLen := -1

View File

@@ -0,0 +1,106 @@
package manager
import (
"context"
"errors"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
type mockProxyManager struct {
getActiveClusterAddressesFunc func(ctx context.Context) ([]string, error)
getActiveClusterAddressesForAccountFunc func(ctx context.Context, accountID string) ([]string, error)
}
func (m *mockProxyManager) GetActiveClusterAddresses(ctx context.Context) ([]string, error) {
if m.getActiveClusterAddressesFunc != nil {
return m.getActiveClusterAddressesFunc(ctx)
}
return nil, nil
}
func (m *mockProxyManager) GetActiveClusterAddressesForAccount(ctx context.Context, accountID string) ([]string, error) {
if m.getActiveClusterAddressesForAccountFunc != nil {
return m.getActiveClusterAddressesForAccountFunc(ctx, accountID)
}
return nil, nil
}
func (m *mockProxyManager) ClusterSupportsCustomPorts(_ context.Context, _ string) *bool {
return nil
}
func (m *mockProxyManager) ClusterRequireSubdomain(_ context.Context, _ string) *bool {
return nil
}
func TestGetClusterAllowList_BYOPProxy(t *testing.T) {
pm := &mockProxyManager{
getActiveClusterAddressesForAccountFunc: func(_ context.Context, accID string) ([]string, error) {
assert.Equal(t, "acc-123", accID)
return []string{"byop.example.com"}, nil
},
getActiveClusterAddressesFunc: func(_ context.Context) ([]string, error) {
t.Fatal("should not call GetActiveClusterAddresses when BYOP addresses exist")
return nil, nil
},
}
mgr := Manager{proxyManager: pm}
result, err := mgr.getClusterAllowList(context.Background(), "acc-123")
require.NoError(t, err)
assert.Equal(t, []string{"byop.example.com"}, result)
}
func TestGetClusterAllowList_NoBYOP_FallbackToShared(t *testing.T) {
pm := &mockProxyManager{
getActiveClusterAddressesForAccountFunc: func(_ context.Context, _ string) ([]string, error) {
return nil, nil
},
getActiveClusterAddressesFunc: func(_ context.Context) ([]string, error) {
return []string{"eu.proxy.netbird.io", "us.proxy.netbird.io"}, nil
},
}
mgr := Manager{proxyManager: pm}
result, err := mgr.getClusterAllowList(context.Background(), "acc-123")
require.NoError(t, err)
assert.Equal(t, []string{"eu.proxy.netbird.io", "us.proxy.netbird.io"}, result)
}
func TestGetClusterAllowList_BYOPError_ReturnsError(t *testing.T) {
pm := &mockProxyManager{
getActiveClusterAddressesForAccountFunc: func(_ context.Context, _ string) ([]string, error) {
return nil, errors.New("db error")
},
getActiveClusterAddressesFunc: func(_ context.Context) ([]string, error) {
t.Fatal("should not call GetActiveClusterAddresses when BYOP lookup fails")
return nil, nil
},
}
mgr := Manager{proxyManager: pm}
result, err := mgr.getClusterAllowList(context.Background(), "acc-123")
require.Error(t, err)
assert.Nil(t, result)
assert.Contains(t, err.Error(), "BYOP cluster addresses")
}
func TestGetClusterAllowList_BYOPEmptySlice_FallbackToShared(t *testing.T) {
pm := &mockProxyManager{
getActiveClusterAddressesForAccountFunc: func(_ context.Context, _ string) ([]string, error) {
return []string{}, nil
},
getActiveClusterAddressesFunc: func(_ context.Context) ([]string, error) {
return []string{"eu.proxy.netbird.io"}, nil
},
}
mgr := Manager{proxyManager: pm}
result, err := mgr.getClusterAllowList(context.Background(), "acc-123")
require.NoError(t, err)
assert.Equal(t, []string{"eu.proxy.netbird.io"}, result)
}

View File

@@ -11,14 +11,19 @@ import (
// Manager defines the interface for proxy operations
type Manager interface {
Connect(ctx context.Context, proxyID, clusterAddress, ipAddress string, capabilities *Capabilities) error
Connect(ctx context.Context, proxyID, clusterAddress, ipAddress string, accountID *string, capabilities *Capabilities) error
Disconnect(ctx context.Context, proxyID string) error
Heartbeat(ctx context.Context, proxyID, clusterAddress, ipAddress string) error
GetActiveClusterAddresses(ctx context.Context) ([]string, error)
GetActiveClusterAddressesForAccount(ctx context.Context, accountID string) ([]string, error)
GetActiveClusters(ctx context.Context) ([]Cluster, error)
ClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool
ClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool
CleanupStale(ctx context.Context, inactivityDuration time.Duration) error
GetAccountProxy(ctx context.Context, accountID string) (*Proxy, error)
CountAccountProxies(ctx context.Context, accountID string) (int64, error)
IsClusterAddressAvailable(ctx context.Context, clusterAddress, accountID string) (bool, error)
DeleteProxy(ctx context.Context, proxyID string) error
}
// OIDCValidationConfig contains the OIDC configuration needed for token validation.

View File

@@ -13,12 +13,18 @@ import (
// store defines the interface for proxy persistence operations
type store interface {
SaveProxy(ctx context.Context, p *proxy.Proxy) error
DisconnectProxy(ctx context.Context, proxyID string) error
UpdateProxyHeartbeat(ctx context.Context, proxyID, clusterAddress, ipAddress string) error
GetActiveProxyClusterAddresses(ctx context.Context) ([]string, error)
GetActiveProxyClusterAddressesForAccount(ctx context.Context, accountID string) ([]string, error)
GetActiveProxyClusters(ctx context.Context) ([]proxy.Cluster, error)
GetClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool
GetClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool
CleanupStaleProxies(ctx context.Context, inactivityDuration time.Duration) error
GetProxyByAccountID(ctx context.Context, accountID string) (*proxy.Proxy, error)
CountProxiesByAccountID(ctx context.Context, accountID string) (int64, error)
IsClusterAddressConflicting(ctx context.Context, clusterAddress, accountID string) (bool, error)
DeleteProxy(ctx context.Context, proxyID string) error
}
// Manager handles all proxy operations
@@ -42,7 +48,7 @@ func NewManager(store store, meter metric.Meter) (*Manager, error) {
// Connect registers a new proxy connection in the database.
// capabilities may be nil for old proxies that do not report them.
func (m Manager) Connect(ctx context.Context, proxyID, clusterAddress, ipAddress string, capabilities *proxy.Capabilities) error {
func (m *Manager) Connect(ctx context.Context, proxyID, clusterAddress, ipAddress string, accountID *string, capabilities *proxy.Capabilities) error {
now := time.Now()
var caps proxy.Capabilities
if capabilities != nil {
@@ -52,9 +58,10 @@ func (m Manager) Connect(ctx context.Context, proxyID, clusterAddress, ipAddress
ID: proxyID,
ClusterAddress: clusterAddress,
IPAddress: ipAddress,
AccountID: accountID,
LastSeen: now,
ConnectedAt: &now,
Status: "connected",
Status: proxy.StatusConnected,
Capabilities: caps,
}
@@ -73,16 +80,8 @@ func (m Manager) Connect(ctx context.Context, proxyID, clusterAddress, ipAddress
}
// Disconnect marks a proxy as disconnected in the database
func (m Manager) Disconnect(ctx context.Context, proxyID string) error {
now := time.Now()
p := &proxy.Proxy{
ID: proxyID,
Status: "disconnected",
DisconnectedAt: &now,
LastSeen: now,
}
if err := m.store.SaveProxy(ctx, p); err != nil {
func (m *Manager) Disconnect(ctx context.Context, proxyID string) error {
if err := m.store.DisconnectProxy(ctx, proxyID); err != nil {
log.WithContext(ctx).Errorf("failed to disconnect proxy %s: %v", proxyID, err)
return err
}
@@ -95,7 +94,7 @@ func (m Manager) Disconnect(ctx context.Context, proxyID string) error {
}
// Heartbeat updates the proxy's last seen timestamp
func (m Manager) Heartbeat(ctx context.Context, proxyID, clusterAddress, ipAddress string) error {
func (m *Manager) Heartbeat(ctx context.Context, proxyID, clusterAddress, ipAddress string) error {
if err := m.store.UpdateProxyHeartbeat(ctx, proxyID, clusterAddress, ipAddress); err != nil {
log.WithContext(ctx).Debugf("failed to update proxy %s heartbeat: %v", proxyID, err)
return err
@@ -107,7 +106,7 @@ func (m Manager) Heartbeat(ctx context.Context, proxyID, clusterAddress, ipAddre
}
// GetActiveClusterAddresses returns all unique cluster addresses for active proxies
func (m Manager) GetActiveClusterAddresses(ctx context.Context) ([]string, error) {
func (m *Manager) GetActiveClusterAddresses(ctx context.Context) ([]string, error) {
addresses, err := m.store.GetActiveProxyClusterAddresses(ctx)
if err != nil {
log.WithContext(ctx).Errorf("failed to get active proxy cluster addresses: %v", err)
@@ -139,10 +138,44 @@ func (m Manager) ClusterRequireSubdomain(ctx context.Context, clusterAddr string
}
// CleanupStale removes proxies that haven't sent heartbeat in the specified duration
func (m Manager) CleanupStale(ctx context.Context, inactivityDuration time.Duration) error {
func (m *Manager) CleanupStale(ctx context.Context, inactivityDuration time.Duration) error {
if err := m.store.CleanupStaleProxies(ctx, inactivityDuration); err != nil {
log.WithContext(ctx).Errorf("failed to cleanup stale proxies: %v", err)
return err
}
return nil
}
func (m *Manager) GetActiveClusterAddressesForAccount(ctx context.Context, accountID string) ([]string, error) {
addresses, err := m.store.GetActiveProxyClusterAddressesForAccount(ctx, accountID)
if err != nil {
log.WithContext(ctx).Errorf("failed to get active proxy cluster addresses for account %s: %v", accountID, err)
return nil, err
}
return addresses, nil
}
func (m *Manager) GetAccountProxy(ctx context.Context, accountID string) (*proxy.Proxy, error) {
return m.store.GetProxyByAccountID(ctx, accountID)
}
func (m *Manager) CountAccountProxies(ctx context.Context, accountID string) (int64, error) {
return m.store.CountProxiesByAccountID(ctx, accountID)
}
func (m *Manager) IsClusterAddressAvailable(ctx context.Context, clusterAddress, accountID string) (bool, error) {
conflicting, err := m.store.IsClusterAddressConflicting(ctx, clusterAddress, accountID)
if err != nil {
return false, err
}
return !conflicting, nil
}
func (m *Manager) DeleteProxy(ctx context.Context, proxyID string) error {
if err := m.store.DeleteProxy(ctx, proxyID); err != nil {
log.WithContext(ctx).Errorf("failed to delete proxy %s: %v", proxyID, err)
return err
}
return nil
}

View File

@@ -0,0 +1,331 @@
package manager
import (
"context"
"errors"
"fmt"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.opentelemetry.io/otel/metric/noop"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
)
type mockStore struct {
saveProxyFunc func(ctx context.Context, p *proxy.Proxy) error
disconnectProxyFunc func(ctx context.Context, proxyID string) error
updateProxyHeartbeatFunc func(ctx context.Context, proxyID, clusterAddress, ipAddress string) error
getActiveProxyClusterAddressesFunc func(ctx context.Context) ([]string, error)
getActiveProxyClusterAddressesForAccFunc func(ctx context.Context, accountID string) ([]string, error)
cleanupStaleProxiesFunc func(ctx context.Context, d time.Duration) error
getProxyByAccountIDFunc func(ctx context.Context, accountID string) (*proxy.Proxy, error)
countProxiesByAccountIDFunc func(ctx context.Context, accountID string) (int64, error)
isClusterAddressConflictingFunc func(ctx context.Context, clusterAddress, accountID string) (bool, error)
deleteProxyFunc func(ctx context.Context, proxyID string) error
}
func (m *mockStore) SaveProxy(ctx context.Context, p *proxy.Proxy) error {
if m.saveProxyFunc != nil {
return m.saveProxyFunc(ctx, p)
}
return nil
}
func (m *mockStore) DisconnectProxy(ctx context.Context, proxyID string) error {
if m.disconnectProxyFunc != nil {
return m.disconnectProxyFunc(ctx, proxyID)
}
return nil
}
func (m *mockStore) UpdateProxyHeartbeat(ctx context.Context, proxyID, clusterAddress, ipAddress string) error {
if m.updateProxyHeartbeatFunc != nil {
return m.updateProxyHeartbeatFunc(ctx, proxyID, clusterAddress, ipAddress)
}
return nil
}
func (m *mockStore) GetActiveProxyClusterAddresses(ctx context.Context) ([]string, error) {
if m.getActiveProxyClusterAddressesFunc != nil {
return m.getActiveProxyClusterAddressesFunc(ctx)
}
return nil, nil
}
func (m *mockStore) GetActiveProxyClusterAddressesForAccount(ctx context.Context, accountID string) ([]string, error) {
if m.getActiveProxyClusterAddressesForAccFunc != nil {
return m.getActiveProxyClusterAddressesForAccFunc(ctx, accountID)
}
return nil, nil
}
func (m *mockStore) GetActiveProxyClusters(_ context.Context) ([]proxy.Cluster, error) {
return nil, nil
}
func (m *mockStore) CleanupStaleProxies(ctx context.Context, d time.Duration) error {
if m.cleanupStaleProxiesFunc != nil {
return m.cleanupStaleProxiesFunc(ctx, d)
}
return nil
}
func (m *mockStore) GetProxyByAccountID(ctx context.Context, accountID string) (*proxy.Proxy, error) {
if m.getProxyByAccountIDFunc != nil {
return m.getProxyByAccountIDFunc(ctx, accountID)
}
return nil, fmt.Errorf("proxy not found for account %s", accountID)
}
func (m *mockStore) CountProxiesByAccountID(ctx context.Context, accountID string) (int64, error) {
if m.countProxiesByAccountIDFunc != nil {
return m.countProxiesByAccountIDFunc(ctx, accountID)
}
return 0, nil
}
func (m *mockStore) IsClusterAddressConflicting(ctx context.Context, clusterAddress, accountID string) (bool, error) {
if m.isClusterAddressConflictingFunc != nil {
return m.isClusterAddressConflictingFunc(ctx, clusterAddress, accountID)
}
return false, nil
}
func (m *mockStore) DeleteProxy(ctx context.Context, proxyID string) error {
if m.deleteProxyFunc != nil {
return m.deleteProxyFunc(ctx, proxyID)
}
return nil
}
func (m *mockStore) GetClusterSupportsCustomPorts(_ context.Context, _ string) *bool {
return nil
}
func (m *mockStore) GetClusterRequireSubdomain(_ context.Context, _ string) *bool {
return nil
}
func newTestManager(s store) *Manager {
meter := noop.NewMeterProvider().Meter("test")
m, err := NewManager(s, meter)
if err != nil {
panic(err)
}
return m
}
func TestConnect_WithAccountID(t *testing.T) {
accountID := "acc-123"
var savedProxy *proxy.Proxy
s := &mockStore{
saveProxyFunc: func(_ context.Context, p *proxy.Proxy) error {
savedProxy = p
return nil
},
}
mgr := newTestManager(s)
err := mgr.Connect(context.Background(), "proxy-1", "cluster.example.com", "10.0.0.1", &accountID, nil)
require.NoError(t, err)
require.NotNil(t, savedProxy)
assert.Equal(t, "proxy-1", savedProxy.ID)
assert.Equal(t, "cluster.example.com", savedProxy.ClusterAddress)
assert.Equal(t, "10.0.0.1", savedProxy.IPAddress)
assert.Equal(t, &accountID, savedProxy.AccountID)
assert.Equal(t, proxy.StatusConnected, savedProxy.Status)
assert.NotNil(t, savedProxy.ConnectedAt)
}
func TestConnect_WithoutAccountID(t *testing.T) {
var savedProxy *proxy.Proxy
s := &mockStore{
saveProxyFunc: func(_ context.Context, p *proxy.Proxy) error {
savedProxy = p
return nil
},
}
mgr := newTestManager(s)
err := mgr.Connect(context.Background(), "proxy-1", "eu.proxy.netbird.io", "10.0.0.1", nil, nil)
require.NoError(t, err)
require.NotNil(t, savedProxy)
assert.Nil(t, savedProxy.AccountID)
assert.Equal(t, proxy.StatusConnected, savedProxy.Status)
}
func TestConnect_StoreError(t *testing.T) {
s := &mockStore{
saveProxyFunc: func(_ context.Context, _ *proxy.Proxy) error {
return errors.New("db error")
},
}
mgr := newTestManager(s)
err := mgr.Connect(context.Background(), "proxy-1", "cluster.example.com", "10.0.0.1", nil, nil)
assert.Error(t, err)
}
func TestIsClusterAddressAvailable(t *testing.T) {
tests := []struct {
name string
conflicting bool
storeErr error
wantResult bool
wantErr bool
}{
{
name: "available - no conflict",
conflicting: false,
wantResult: true,
},
{
name: "not available - conflict exists",
conflicting: true,
wantResult: false,
},
{
name: "store error",
storeErr: errors.New("db error"),
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
s := &mockStore{
isClusterAddressConflictingFunc: func(_ context.Context, _, _ string) (bool, error) {
return tt.conflicting, tt.storeErr
},
}
mgr := newTestManager(s)
result, err := mgr.IsClusterAddressAvailable(context.Background(), "cluster.example.com", "acc-123")
if tt.wantErr {
assert.Error(t, err)
return
}
require.NoError(t, err)
assert.Equal(t, tt.wantResult, result)
})
}
}
func TestCountAccountProxies(t *testing.T) {
tests := []struct {
name string
count int64
storeErr error
wantCount int64
wantErr bool
}{
{
name: "no proxies",
count: 0,
wantCount: 0,
},
{
name: "one proxy",
count: 1,
wantCount: 1,
},
{
name: "store error",
storeErr: errors.New("db error"),
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
s := &mockStore{
countProxiesByAccountIDFunc: func(_ context.Context, _ string) (int64, error) {
return tt.count, tt.storeErr
},
}
mgr := newTestManager(s)
count, err := mgr.CountAccountProxies(context.Background(), "acc-123")
if tt.wantErr {
assert.Error(t, err)
return
}
require.NoError(t, err)
assert.Equal(t, tt.wantCount, count)
})
}
}
func TestGetAccountProxy(t *testing.T) {
accountID := "acc-123"
t.Run("found", func(t *testing.T) {
expected := &proxy.Proxy{
ID: "proxy-1",
ClusterAddress: "byop.example.com",
AccountID: &accountID,
Status: proxy.StatusConnected,
}
s := &mockStore{
getProxyByAccountIDFunc: func(_ context.Context, accID string) (*proxy.Proxy, error) {
assert.Equal(t, accountID, accID)
return expected, nil
},
}
mgr := newTestManager(s)
p, err := mgr.GetAccountProxy(context.Background(), accountID)
require.NoError(t, err)
assert.Equal(t, expected, p)
})
t.Run("not found", func(t *testing.T) {
s := &mockStore{
getProxyByAccountIDFunc: func(_ context.Context, _ string) (*proxy.Proxy, error) {
return nil, errors.New("not found")
},
}
mgr := newTestManager(s)
_, err := mgr.GetAccountProxy(context.Background(), accountID)
assert.Error(t, err)
})
}
func TestDeleteProxy(t *testing.T) {
t.Run("success", func(t *testing.T) {
var deletedID string
s := &mockStore{
deleteProxyFunc: func(_ context.Context, proxyID string) error {
deletedID = proxyID
return nil
},
}
mgr := newTestManager(s)
err := mgr.DeleteProxy(context.Background(), "proxy-1")
require.NoError(t, err)
assert.Equal(t, "proxy-1", deletedID)
})
t.Run("store error", func(t *testing.T) {
s := &mockStore{
deleteProxyFunc: func(_ context.Context, _ string) error {
return errors.New("db error")
},
}
mgr := newTestManager(s)
err := mgr.DeleteProxy(context.Background(), "proxy-1")
assert.Error(t, err)
})
}
func TestGetActiveClusterAddressesForAccount(t *testing.T) {
expected := []string{"byop.example.com"}
s := &mockStore{
getActiveProxyClusterAddressesForAccFunc: func(_ context.Context, accID string) ([]string, error) {
assert.Equal(t, "acc-123", accID)
return expected, nil
},
}
mgr := newTestManager(s)
result, err := mgr.GetActiveClusterAddressesForAccount(context.Background(), "acc-123")
require.NoError(t, err)
assert.Equal(t, expected, result)
}

View File

@@ -79,17 +79,17 @@ func (mr *MockManagerMockRecorder) ClusterRequireSubdomain(ctx, clusterAddr inte
}
// Connect mocks base method.
func (m *MockManager) Connect(ctx context.Context, proxyID, clusterAddress, ipAddress string, capabilities *Capabilities) error {
func (m *MockManager) Connect(ctx context.Context, proxyID, clusterAddress, ipAddress string, accountID *string, capabilities *Capabilities) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Connect", ctx, proxyID, clusterAddress, ipAddress, capabilities)
ret := m.ctrl.Call(m, "Connect", ctx, proxyID, clusterAddress, ipAddress, accountID, capabilities)
ret0, _ := ret[0].(error)
return ret0
}
// Connect indicates an expected call of Connect.
func (mr *MockManagerMockRecorder) Connect(ctx, proxyID, clusterAddress, ipAddress, capabilities interface{}) *gomock.Call {
func (mr *MockManagerMockRecorder) Connect(ctx, proxyID, clusterAddress, ipAddress, accountID, capabilities interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Connect", reflect.TypeOf((*MockManager)(nil).Connect), ctx, proxyID, clusterAddress, ipAddress, capabilities)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Connect", reflect.TypeOf((*MockManager)(nil).Connect), ctx, proxyID, clusterAddress, ipAddress, accountID, capabilities)
}
// Disconnect mocks base method.
@@ -121,7 +121,19 @@ func (mr *MockManagerMockRecorder) GetActiveClusterAddresses(ctx interface{}) *g
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetActiveClusterAddresses", reflect.TypeOf((*MockManager)(nil).GetActiveClusterAddresses), ctx)
}
// GetActiveClusters mocks base method.
func (m *MockManager) GetActiveClusterAddressesForAccount(ctx context.Context, accountID string) ([]string, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetActiveClusterAddressesForAccount", ctx, accountID)
ret0, _ := ret[0].([]string)
ret1, _ := ret[1].(error)
return ret0, ret1
}
func (mr *MockManagerMockRecorder) GetActiveClusterAddressesForAccount(ctx, accountID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetActiveClusterAddressesForAccount", reflect.TypeOf((*MockManager)(nil).GetActiveClusterAddressesForAccount), ctx, accountID)
}
func (m *MockManager) GetActiveClusters(ctx context.Context) ([]Cluster, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetActiveClusters", ctx)
@@ -150,6 +162,65 @@ func (mr *MockManagerMockRecorder) Heartbeat(ctx, proxyID, clusterAddress, ipAdd
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Heartbeat", reflect.TypeOf((*MockManager)(nil).Heartbeat), ctx, proxyID, clusterAddress, ipAddress)
}
// GetAccountProxy mocks base method.
func (m *MockManager) GetAccountProxy(ctx context.Context, accountID string) (*Proxy, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetAccountProxy", ctx, accountID)
ret0, _ := ret[0].(*Proxy)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetAccountProxy indicates an expected call of GetAccountProxy.
func (mr *MockManagerMockRecorder) GetAccountProxy(ctx, accountID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAccountProxy", reflect.TypeOf((*MockManager)(nil).GetAccountProxy), ctx, accountID)
}
// CountAccountProxies mocks base method.
func (m *MockManager) CountAccountProxies(ctx context.Context, accountID string) (int64, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "CountAccountProxies", ctx, accountID)
ret0, _ := ret[0].(int64)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// CountAccountProxies indicates an expected call of CountAccountProxies.
func (mr *MockManagerMockRecorder) CountAccountProxies(ctx, accountID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CountAccountProxies", reflect.TypeOf((*MockManager)(nil).CountAccountProxies), ctx, accountID)
}
// IsClusterAddressAvailable mocks base method.
func (m *MockManager) IsClusterAddressAvailable(ctx context.Context, clusterAddress, accountID string) (bool, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "IsClusterAddressAvailable", ctx, clusterAddress, accountID)
ret0, _ := ret[0].(bool)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// IsClusterAddressAvailable indicates an expected call of IsClusterAddressAvailable.
func (mr *MockManagerMockRecorder) IsClusterAddressAvailable(ctx, clusterAddress, accountID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsClusterAddressAvailable", reflect.TypeOf((*MockManager)(nil).IsClusterAddressAvailable), ctx, clusterAddress, accountID)
}
// DeleteProxy mocks base method.
func (m *MockManager) DeleteProxy(ctx context.Context, proxyID string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DeleteProxy", ctx, proxyID)
ret0, _ := ret[0].(error)
return ret0
}
// DeleteProxy indicates an expected call of DeleteProxy.
func (mr *MockManagerMockRecorder) DeleteProxy(ctx, proxyID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteProxy", reflect.TypeOf((*MockManager)(nil).DeleteProxy), ctx, proxyID)
}
// MockController is a mock of Controller interface.
type MockController struct {
ctrl *gomock.Controller

View File

@@ -1,6 +1,16 @@
package proxy
import "time"
import (
"errors"
"time"
)
var ErrAccountProxyAlreadyExists = errors.New("account already has a registered proxy")
const (
StatusConnected = "connected"
StatusDisconnected = "disconnected"
)
// Capabilities describes what a proxy can handle, as reported via gRPC.
// Nil fields mean the proxy never reported this capability.
@@ -18,6 +28,7 @@ type Proxy struct {
ID string `gorm:"primaryKey;type:varchar(255)"`
ClusterAddress string `gorm:"type:varchar(255);not null;index:idx_proxy_cluster_status"`
IPAddress string `gorm:"type:varchar(45)"`
AccountID *string `gorm:"type:varchar(255);uniqueIndex:idx_proxy_account_id_unique"`
LastSeen time.Time `gorm:"not null;index:idx_proxy_last_seen"`
ConnectedAt *time.Time
DisconnectedAt *time.Time
@@ -33,6 +44,7 @@ func (Proxy) TableName() string {
// Cluster represents a group of proxy nodes serving the same address.
type Cluster struct {
ID string
Address string
ConnectedProxies int
}

View File

@@ -0,0 +1,195 @@
package proxytoken
import (
"encoding/json"
"net/http"
"time"
"github.com/gorilla/mux"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util"
"github.com/netbirdio/netbird/shared/management/status"
)
type handler struct {
store store.Store
permissionsManager permissions.Manager
}
func RegisterEndpoints(s store.Store, permissionsManager permissions.Manager, router *mux.Router) {
h := &handler{store: s, permissionsManager: permissionsManager}
router.HandleFunc("/reverse-proxies/proxy-tokens", h.listTokens).Methods("GET", "OPTIONS")
router.HandleFunc("/reverse-proxies/proxy-tokens", h.createToken).Methods("POST", "OPTIONS")
router.HandleFunc("/reverse-proxies/proxy-tokens/{tokenId}", h.revokeToken).Methods("DELETE", "OPTIONS")
}
func (h *handler) createToken(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
ok, err := h.permissionsManager.ValidateUserPermissions(r.Context(), userAuth.AccountId, userAuth.UserId, modules.Services, operations.Create)
if err != nil {
util.WriteErrorResponse("failed to validate permissions", http.StatusInternalServerError, w)
return
}
if !ok {
util.WriteErrorResponse("permission denied", http.StatusForbidden, w)
return
}
var req api.ProxyTokenRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return
}
if req.Name == "" || len(req.Name) > 255 {
util.WriteErrorResponse("name is required and must be at most 255 characters", http.StatusBadRequest, w)
return
}
var expiresIn time.Duration
if req.ExpiresIn != nil {
if *req.ExpiresIn < 0 {
util.WriteErrorResponse("expires_in must be non-negative", http.StatusBadRequest, w)
return
}
if *req.ExpiresIn > 0 {
expiresIn = time.Duration(*req.ExpiresIn) * time.Second
}
}
accountID := userAuth.AccountId
generated, err := types.CreateNewProxyAccessToken(req.Name, expiresIn, &accountID, userAuth.UserId)
if err != nil {
util.WriteErrorResponse("failed to generate token", http.StatusInternalServerError, w)
return
}
if err := h.store.SaveProxyAccessToken(r.Context(), &generated.ProxyAccessToken); err != nil {
util.WriteErrorResponse("failed to save token", http.StatusInternalServerError, w)
return
}
resp := toProxyTokenCreatedResponse(generated)
util.WriteJSONObject(r.Context(), w, resp)
}
func (h *handler) listTokens(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
ok, err := h.permissionsManager.ValidateUserPermissions(r.Context(), userAuth.AccountId, userAuth.UserId, modules.Services, operations.Read)
if err != nil {
util.WriteErrorResponse("failed to validate permissions", http.StatusInternalServerError, w)
return
}
if !ok {
util.WriteErrorResponse("permission denied", http.StatusForbidden, w)
return
}
tokens, err := h.store.GetProxyAccessTokensByAccountID(r.Context(), store.LockingStrengthNone, userAuth.AccountId)
if err != nil {
util.WriteErrorResponse("failed to list tokens", http.StatusInternalServerError, w)
return
}
resp := make([]api.ProxyToken, 0, len(tokens))
for _, token := range tokens {
resp = append(resp, toProxyTokenResponse(token))
}
util.WriteJSONObject(r.Context(), w, resp)
}
func (h *handler) revokeToken(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
ok, err := h.permissionsManager.ValidateUserPermissions(r.Context(), userAuth.AccountId, userAuth.UserId, modules.Services, operations.Delete)
if err != nil {
util.WriteErrorResponse("failed to validate permissions", http.StatusInternalServerError, w)
return
}
if !ok {
util.WriteErrorResponse("permission denied", http.StatusForbidden, w)
return
}
tokenID := mux.Vars(r)["tokenId"]
if tokenID == "" {
util.WriteErrorResponse("token ID is required", http.StatusBadRequest, w)
return
}
token, err := h.store.GetProxyAccessTokenByID(r.Context(), store.LockingStrengthNone, tokenID)
if err != nil {
if s, ok := status.FromError(err); ok && s.ErrorType == status.NotFound {
util.WriteErrorResponse("token not found", http.StatusNotFound, w)
} else {
util.WriteErrorResponse("failed to retrieve token", http.StatusInternalServerError, w)
}
return
}
if token.AccountID == nil || *token.AccountID != userAuth.AccountId {
util.WriteErrorResponse("token not found", http.StatusNotFound, w)
return
}
if err := h.store.RevokeProxyAccessToken(r.Context(), tokenID); err != nil {
util.WriteErrorResponse("failed to revoke token", http.StatusInternalServerError, w)
return
}
util.WriteJSONObject(r.Context(), w, util.EmptyObject{})
}
func toProxyTokenResponse(token *types.ProxyAccessToken) api.ProxyToken {
resp := api.ProxyToken{
Id: token.ID,
Name: token.Name,
Revoked: token.Revoked,
}
if !token.CreatedAt.IsZero() {
resp.CreatedAt = token.CreatedAt
}
if token.ExpiresAt != nil {
resp.ExpiresAt = token.ExpiresAt
}
if token.LastUsed != nil {
resp.LastUsed = token.LastUsed
}
return resp
}
func toProxyTokenCreatedResponse(generated *types.ProxyAccessTokenGenerated) api.ProxyTokenCreated {
base := toProxyTokenResponse(&generated.ProxyAccessToken)
plainToken := string(generated.PlainToken)
return api.ProxyTokenCreated{
Id: base.Id,
Name: base.Name,
CreatedAt: base.CreatedAt,
ExpiresAt: base.ExpiresAt,
LastUsed: base.LastUsed,
Revoked: base.Revoked,
PlainToken: plainToken,
}
}

View File

@@ -0,0 +1,275 @@
package proxytoken
import (
"bytes"
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/golang/mock/gomock"
"github.com/gorilla/mux"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/auth"
"github.com/netbirdio/netbird/shared/management/http/api"
)
func authContext(accountID, userID string) context.Context {
return nbcontext.SetUserAuthInContext(context.Background(), auth.UserAuth{
AccountId: accountID,
UserId: userID,
})
}
func TestCreateToken_AccountScoped(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
accountID := "acc-123"
var savedToken *types.ProxyAccessToken
mockStore := store.NewMockStore(ctrl)
mockStore.EXPECT().SaveProxyAccessToken(gomock.Any(), gomock.Any()).DoAndReturn(
func(_ context.Context, token *types.ProxyAccessToken) error {
savedToken = token
return nil
},
)
permsMgr := permissions.NewMockManager(ctrl)
permsMgr.EXPECT().ValidateUserPermissions(gomock.Any(), accountID, "user-1", modules.Services, operations.Create).Return(true, nil)
h := &handler{
store: mockStore,
permissionsManager: permsMgr,
}
body := `{"name": "my-token"}`
req := httptest.NewRequest("POST", "/reverse-proxies/proxy-tokens", bytes.NewBufferString(body))
req = req.WithContext(authContext(accountID, "user-1"))
w := httptest.NewRecorder()
h.createToken(w, req)
assert.Equal(t, http.StatusOK, w.Code)
var resp api.ProxyTokenCreated
require.NoError(t, json.NewDecoder(w.Body).Decode(&resp))
assert.NotEmpty(t, resp.PlainToken)
assert.Equal(t, "my-token", resp.Name)
assert.False(t, resp.Revoked)
require.NotNil(t, savedToken)
require.NotNil(t, savedToken.AccountID)
assert.Equal(t, accountID, *savedToken.AccountID)
assert.Equal(t, "user-1", savedToken.CreatedBy)
}
func TestCreateToken_WithExpiration(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
var savedToken *types.ProxyAccessToken
mockStore := store.NewMockStore(ctrl)
mockStore.EXPECT().SaveProxyAccessToken(gomock.Any(), gomock.Any()).DoAndReturn(
func(_ context.Context, token *types.ProxyAccessToken) error {
savedToken = token
return nil
},
)
permsMgr := permissions.NewMockManager(ctrl)
permsMgr.EXPECT().ValidateUserPermissions(gomock.Any(), "acc-123", "user-1", modules.Services, operations.Create).Return(true, nil)
h := &handler{
store: mockStore,
permissionsManager: permsMgr,
}
body := `{"name": "expiring-token", "expires_in": 3600}`
req := httptest.NewRequest("POST", "/reverse-proxies/proxy-tokens", bytes.NewBufferString(body))
req = req.WithContext(authContext("acc-123", "user-1"))
w := httptest.NewRecorder()
h.createToken(w, req)
assert.Equal(t, http.StatusOK, w.Code)
require.NotNil(t, savedToken)
require.NotNil(t, savedToken.ExpiresAt)
assert.True(t, savedToken.ExpiresAt.After(time.Now()))
}
func TestCreateToken_EmptyName(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
permsMgr := permissions.NewMockManager(ctrl)
permsMgr.EXPECT().ValidateUserPermissions(gomock.Any(), "acc-123", "user-1", modules.Services, operations.Create).Return(true, nil)
h := &handler{
permissionsManager: permsMgr,
}
body := `{"name": ""}`
req := httptest.NewRequest("POST", "/reverse-proxies/proxy-tokens", bytes.NewBufferString(body))
req = req.WithContext(authContext("acc-123", "user-1"))
w := httptest.NewRecorder()
h.createToken(w, req)
assert.Equal(t, http.StatusBadRequest, w.Code)
}
func TestCreateToken_PermissionDenied(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
permsMgr := permissions.NewMockManager(ctrl)
permsMgr.EXPECT().ValidateUserPermissions(gomock.Any(), "acc-123", "user-1", modules.Services, operations.Create).Return(false, nil)
h := &handler{
permissionsManager: permsMgr,
}
body := `{"name": "test"}`
req := httptest.NewRequest("POST", "/reverse-proxies/proxy-tokens", bytes.NewBufferString(body))
req = req.WithContext(authContext("acc-123", "user-1"))
w := httptest.NewRecorder()
h.createToken(w, req)
assert.Equal(t, http.StatusForbidden, w.Code)
}
func TestListTokens(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
accountID := "acc-123"
now := time.Now()
mockStore := store.NewMockStore(ctrl)
mockStore.EXPECT().GetProxyAccessTokensByAccountID(gomock.Any(), store.LockingStrengthNone, accountID).Return([]*types.ProxyAccessToken{
{ID: "tok-1", Name: "token-1", AccountID: &accountID, CreatedAt: now, Revoked: false},
{ID: "tok-2", Name: "token-2", AccountID: &accountID, CreatedAt: now, Revoked: true},
}, nil)
permsMgr := permissions.NewMockManager(ctrl)
permsMgr.EXPECT().ValidateUserPermissions(gomock.Any(), accountID, "user-1", modules.Services, operations.Read).Return(true, nil)
h := &handler{
store: mockStore,
permissionsManager: permsMgr,
}
req := httptest.NewRequest("GET", "/reverse-proxies/proxy-tokens", nil)
req = req.WithContext(authContext(accountID, "user-1"))
w := httptest.NewRecorder()
h.listTokens(w, req)
assert.Equal(t, http.StatusOK, w.Code)
var resp []api.ProxyToken
require.NoError(t, json.NewDecoder(w.Body).Decode(&resp))
require.Len(t, resp, 2)
assert.Equal(t, "tok-1", resp[0].Id)
assert.False(t, resp[0].Revoked)
assert.Equal(t, "tok-2", resp[1].Id)
assert.True(t, resp[1].Revoked)
}
func TestRevokeToken_Success(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
accountID := "acc-123"
mockStore := store.NewMockStore(ctrl)
mockStore.EXPECT().GetProxyAccessTokenByID(gomock.Any(), store.LockingStrengthNone, "tok-1").Return(&types.ProxyAccessToken{
ID: "tok-1",
Name: "test-token",
AccountID: &accountID,
}, nil)
mockStore.EXPECT().RevokeProxyAccessToken(gomock.Any(), "tok-1").Return(nil)
permsMgr := permissions.NewMockManager(ctrl)
permsMgr.EXPECT().ValidateUserPermissions(gomock.Any(), accountID, "user-1", modules.Services, operations.Delete).Return(true, nil)
h := &handler{
store: mockStore,
permissionsManager: permsMgr,
}
req := httptest.NewRequest("DELETE", "/reverse-proxies/proxy-tokens/tok-1", nil)
req = req.WithContext(authContext(accountID, "user-1"))
req = mux.SetURLVars(req, map[string]string{"tokenId": "tok-1"})
w := httptest.NewRecorder()
h.revokeToken(w, req)
assert.Equal(t, http.StatusOK, w.Code)
}
func TestRevokeToken_WrongAccount(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
otherAccount := "acc-other"
mockStore := store.NewMockStore(ctrl)
mockStore.EXPECT().GetProxyAccessTokenByID(gomock.Any(), store.LockingStrengthNone, "tok-1").Return(&types.ProxyAccessToken{
ID: "tok-1",
AccountID: &otherAccount,
}, nil)
permsMgr := permissions.NewMockManager(ctrl)
permsMgr.EXPECT().ValidateUserPermissions(gomock.Any(), "acc-123", "user-1", modules.Services, operations.Delete).Return(true, nil)
h := &handler{
store: mockStore,
permissionsManager: permsMgr,
}
req := httptest.NewRequest("DELETE", "/reverse-proxies/proxy-tokens/tok-1", nil)
req = req.WithContext(authContext("acc-123", "user-1"))
req = mux.SetURLVars(req, map[string]string{"tokenId": "tok-1"})
w := httptest.NewRecorder()
h.revokeToken(w, req)
assert.Equal(t, http.StatusNotFound, w.Code)
}
func TestRevokeToken_ManagementWideToken(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockStore := store.NewMockStore(ctrl)
mockStore.EXPECT().GetProxyAccessTokenByID(gomock.Any(), store.LockingStrengthNone, "tok-1").Return(&types.ProxyAccessToken{
ID: "tok-1",
AccountID: nil,
}, nil)
permsMgr := permissions.NewMockManager(ctrl)
permsMgr.EXPECT().ValidateUserPermissions(gomock.Any(), "acc-123", "user-1", modules.Services, operations.Delete).Return(true, nil)
h := &handler{
store: mockStore,
permissionsManager: permsMgr,
}
req := httptest.NewRequest("DELETE", "/reverse-proxies/proxy-tokens/tok-1", nil)
req = req.WithContext(authContext("acc-123", "user-1"))
req = mux.SetURLVars(req, map[string]string{"tokenId": "tok-1"})
w := httptest.NewRecorder()
h.revokeToken(w, req)
assert.Equal(t, http.StatusNotFound, w.Code)
}

View File

@@ -28,4 +28,5 @@ type Manager interface {
RenewServiceFromPeer(ctx context.Context, accountID, peerID, serviceID string) error
StopServiceFromPeer(ctx context.Context, accountID, peerID, serviceID string) error
StartExposeReaper(ctx context.Context)
GetServiceByDomain(ctx context.Context, domain string) (*Service, error)
}

View File

@@ -138,6 +138,21 @@ func (mr *MockManagerMockRecorder) GetAllServices(ctx, accountID, userID interfa
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAllServices", reflect.TypeOf((*MockManager)(nil).GetAllServices), ctx, accountID, userID)
}
// GetServiceByDomain mocks base method.
func (m *MockManager) GetServiceByDomain(ctx context.Context, domain string) (*Service, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetServiceByDomain", ctx, domain)
ret0, _ := ret[0].(*Service)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetServiceByDomain indicates an expected call of GetServiceByDomain.
func (mr *MockManagerMockRecorder) GetServiceByDomain(ctx, domain interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetServiceByDomain", reflect.TypeOf((*MockManager)(nil).GetServiceByDomain), ctx, domain)
}
// GetGlobalServices mocks base method.
func (m *MockManager) GetGlobalServices(ctx context.Context) ([]*Service, error) {
m.ctrl.T.Helper()

View File

@@ -195,6 +195,7 @@ func (h *handler) getClusters(w http.ResponseWriter, r *http.Request) {
apiClusters := make([]api.ProxyCluster, 0, len(clusters))
for _, c := range clusters {
apiClusters = append(apiClusters, api.ProxyCluster{
Id: c.ID,
Address: c.Address,
ConnectedProxies: c.ConnectedProxies,
})

View File

@@ -984,6 +984,10 @@ func (m *Manager) GetAccountServices(ctx context.Context, accountID string) ([]*
return services, nil
}
func (m *Manager) GetServiceByDomain(ctx context.Context, domain string) (*service.Service, error) {
return m.store.GetServiceByDomain(ctx, domain)
}
func (m *Manager) GetServiceIDByTargetID(ctx context.Context, accountID string, resourceID string) (string, error) {
target, err := m.store.GetServiceTargetByTargetID(ctx, store.LockingStrengthNone, accountID, resourceID)
if err != nil {

View File

@@ -426,7 +426,7 @@ func TestDeletePeerService_SourcePeerValidation(t *testing.T) {
require.NoError(t, err)
pkceStore, err := nbgrpc.NewPKCEVerifierStore(context.Background(), 10*time.Minute, 10*time.Minute, 100)
require.NoError(t, err)
srv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil)
srv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil, nil)
return srv
}
@@ -707,7 +707,7 @@ func setupIntegrationTest(t *testing.T) (*Manager, store.Store) {
require.NoError(t, err)
pkceStore, err := nbgrpc.NewPKCEVerifierStore(ctx, 10*time.Minute, 10*time.Minute, 100)
require.NoError(t, err)
proxySrv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil)
proxySrv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil, nil)
proxyController, err := proxymanager.NewGRPCController(proxySrv, noop.NewMeterProvider().Meter(""))
require.NoError(t, err)
@@ -1132,7 +1132,7 @@ func TestDeleteService_DeletesTargets(t *testing.T) {
require.NoError(t, err)
pkceStore, err := nbgrpc.NewPKCEVerifierStore(ctx, 10*time.Minute, 10*time.Minute, 100)
require.NoError(t, err)
proxySrv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil)
proxySrv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil, nil)
proxyController, err := proxymanager.NewGRPCController(proxySrv, noop.NewMeterProvider().Meter(""))
require.NoError(t, err)

View File

@@ -169,7 +169,7 @@ func (s *BaseServer) GRPCServer() *grpc.Server {
func (s *BaseServer) ReverseProxyGRPCServer() *nbgrpc.ProxyServiceServer {
return Create(s, func() *nbgrpc.ProxyServiceServer {
proxyService := nbgrpc.NewProxyServiceServer(s.AccessLogsManager(), s.ProxyTokenStore(), s.PKCEVerifierStore(), s.proxyOIDCConfig(), s.PeersManager(), s.UsersManager(), s.ProxyManager())
proxyService := nbgrpc.NewProxyServiceServer(s.AccessLogsManager(), s.ProxyTokenStore(), s.PKCEVerifierStore(), s.proxyOIDCConfig(), s.PeersManager(), s.UsersManager(), s.ProxyManager(), s.Store())
s.AfterInit(func(s *BaseServer) {
proxyService.SetServiceManager(s.ServiceManager())
proxyService.SetProxyController(s.ServiceProxyController())

View File

@@ -9,6 +9,7 @@ import (
"encoding/hex"
"errors"
"fmt"
"net"
"net/http"
"net/url"
"strings"
@@ -47,6 +48,11 @@ type ProxyOIDCConfig struct {
KeysLocation string
}
// ProxyTokenChecker checks whether a proxy access token is still valid.
type ProxyTokenChecker interface {
IsProxyAccessTokenValid(ctx context.Context, tokenID string) (bool, error)
}
// ProxyServiceServer implements the ProxyService gRPC server
type ProxyServiceServer struct {
proto.UnimplementedProxyServiceServer
@@ -75,6 +81,9 @@ type ProxyServiceServer struct {
// Store for one-time authentication tokens
tokenStore *OneTimeTokenStore
// Checker for proxy access token validity
tokenChecker ProxyTokenChecker
// OIDC configuration for proxy authentication
oidcConfig ProxyOIDCConfig
@@ -90,6 +99,8 @@ const pkceVerifierTTL = 10 * time.Minute
type proxyConnection struct {
proxyID string
address string
accountID *string
tokenID string
capabilities *proto.ProxyCapabilities
stream proto.ProxyService_GetMappingUpdateServer
sendChan chan *proto.GetMappingUpdateResponse
@@ -97,8 +108,19 @@ type proxyConnection struct {
cancel context.CancelFunc
}
func enforceAccountScope(ctx context.Context, requestAccountID string) error {
token := GetProxyTokenFromContext(ctx)
if token == nil || token.AccountID == nil {
return nil
}
if requestAccountID == "" || *token.AccountID != requestAccountID {
return status.Errorf(codes.PermissionDenied, "account-scoped token cannot access account %s", requestAccountID)
}
return nil
}
// NewProxyServiceServer creates a new proxy service server.
func NewProxyServiceServer(accessLogMgr accesslogs.Manager, tokenStore *OneTimeTokenStore, pkceStore *PKCEVerifierStore, oidcConfig ProxyOIDCConfig, peersManager peers.Manager, usersManager users.Manager, proxyMgr proxy.Manager) *ProxyServiceServer {
func NewProxyServiceServer(accessLogMgr accesslogs.Manager, tokenStore *OneTimeTokenStore, pkceStore *PKCEVerifierStore, oidcConfig ProxyOIDCConfig, peersManager peers.Manager, usersManager users.Manager, proxyMgr proxy.Manager, tokenChecker ProxyTokenChecker) *ProxyServiceServer {
ctx, cancel := context.WithCancel(context.Background())
s := &ProxyServiceServer{
accessLogManager: accessLogMgr,
@@ -108,6 +130,7 @@ func NewProxyServiceServer(accessLogMgr accesslogs.Manager, tokenStore *OneTimeT
peersManager: peersManager,
usersManager: usersManager,
proxyManager: proxyMgr,
tokenChecker: tokenChecker,
cancel: cancel,
}
go s.cleanupStaleProxies(ctx)
@@ -166,10 +189,48 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest
return status.Errorf(codes.InvalidArgument, "proxy address is invalid")
}
var accountID *string
token := GetProxyTokenFromContext(ctx)
if token != nil && token.AccountID != nil {
accountID = token.AccountID
existingProxy, err := s.proxyManager.GetAccountProxy(ctx, *accountID)
if err != nil {
if s, ok := nbstatus.FromError(err); ok && s.ErrorType == nbstatus.NotFound {
log.WithContext(ctx).Debugf("no existing BYOP proxy for account %s", *accountID)
} else {
return status.Errorf(codes.Internal, "failed to check existing proxy: %v", err)
}
}
if existingProxy != nil && existingProxy.ID != proxyID {
if existingProxy.Status == proxy.StatusConnected {
return status.Errorf(codes.ResourceExhausted, "limit of 1 self-hosted proxy per account")
}
if err := s.proxyManager.DeleteProxy(ctx, existingProxy.ID); err != nil {
log.WithContext(ctx).Warnf("failed to cleanup disconnected proxy %s: %v", existingProxy.ID, err)
}
}
available, err := s.proxyManager.IsClusterAddressAvailable(ctx, proxyAddress, *accountID)
if err != nil {
return status.Errorf(codes.Internal, "check cluster address: %v", err)
}
if !available {
return status.Errorf(codes.AlreadyExists, "cluster address %s is already in use", proxyAddress)
}
}
var tokenID string
if token != nil {
tokenID = token.ID
}
connCtx, cancel := context.WithCancel(ctx)
conn := &proxyConnection{
proxyID: proxyID,
address: proxyAddress,
accountID: accountID,
tokenID: tokenID,
capabilities: req.GetCapabilities(),
stream: stream,
sendChan: make(chan *proto.GetMappingUpdateResponse, 100),
@@ -177,12 +238,6 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest
cancel: cancel,
}
s.connectedProxies.Store(proxyID, conn)
if err := s.proxyController.RegisterProxyToCluster(ctx, conn.address, proxyID); err != nil {
log.WithContext(ctx).Warnf("Failed to register proxy %s in cluster: %v", proxyID, err)
}
// Register proxy in database with capabilities
var caps *proxy.Capabilities
if c := req.GetCapabilities(); c != nil {
caps = &proxy.Capabilities{
@@ -190,19 +245,28 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest
RequireSubdomain: c.RequireSubdomain,
}
}
if err := s.proxyManager.Connect(ctx, proxyID, proxyAddress, peerInfo, caps); err != nil {
log.WithContext(ctx).Warnf("failed to register proxy %s in database: %v", proxyID, err)
s.connectedProxies.Delete(proxyID)
if unregErr := s.proxyController.UnregisterProxyFromCluster(ctx, conn.address, proxyID); unregErr != nil {
log.WithContext(ctx).Debugf("cleanup after Connect failure for proxy %s: %v", proxyID, unregErr)
if err := s.proxyManager.Connect(ctx, proxyID, proxyAddress, peerInfo, accountID, caps); err != nil {
if accountID != nil {
cancel()
if errors.Is(err, proxy.ErrAccountProxyAlreadyExists) {
return status.Errorf(codes.ResourceExhausted, "limit of 1 self-hosted proxy per account")
}
return status.Errorf(codes.Internal, "failed to register BYOP proxy: %v", err)
}
return status.Errorf(codes.Internal, "register proxy in database: %v", err)
log.WithContext(ctx).Warnf("Failed to register proxy %s in database: %v", proxyID, err)
}
s.connectedProxies.Store(proxyID, conn)
if err := s.proxyController.RegisterProxyToCluster(ctx, conn.address, proxyID); err != nil {
log.WithContext(ctx).Warnf("Failed to register proxy %s in cluster: %v", proxyID, err)
}
log.WithFields(log.Fields{
"proxy_id": proxyID,
"address": proxyAddress,
"cluster_addr": proxyAddress,
"account_id": accountID,
"total_proxies": len(s.GetConnectedProxies()),
}).Info("Proxy registered in cluster")
defer func() {
@@ -227,7 +291,7 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest
go s.sender(conn, errChan)
// Start heartbeat goroutine
go s.heartbeat(connCtx, proxyID, proxyAddress, peerInfo)
go s.heartbeat(connCtx, conn, peerInfo)
select {
case err := <-errChan:
@@ -237,16 +301,28 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest
}
}
// heartbeat updates the proxy's last_seen timestamp every minute
func (s *ProxyServiceServer) heartbeat(ctx context.Context, proxyID, clusterAddress, ipAddress string) {
func (s *ProxyServiceServer) heartbeat(ctx context.Context, conn *proxyConnection, ipAddress string) {
ticker := time.NewTicker(1 * time.Minute)
defer ticker.Stop()
for {
select {
case <-ticker.C:
if err := s.proxyManager.Heartbeat(ctx, proxyID, clusterAddress, ipAddress); err != nil {
log.WithContext(ctx).Debugf("Failed to update proxy %s heartbeat: %v", proxyID, err)
if err := s.proxyManager.Heartbeat(ctx, conn.proxyID, conn.address, ipAddress); err != nil {
log.WithContext(ctx).Debugf("Failed to update proxy %s heartbeat: %v", conn.proxyID, err)
}
if conn.tokenID != "" && s.tokenChecker != nil {
valid, err := s.tokenChecker.IsProxyAccessTokenValid(ctx, conn.tokenID)
if err != nil {
log.WithContext(ctx).Warnf("failed to check token validity for proxy %s: %v", conn.proxyID, err)
continue
}
if !valid {
log.WithContext(ctx).Warnf("proxy %s token revoked or expired, disconnecting", conn.proxyID)
conn.cancel()
return
}
}
case <-ctx.Done():
return
@@ -254,8 +330,6 @@ func (s *ProxyServiceServer) heartbeat(ctx context.Context, proxyID, clusterAddr
}
}
// sendSnapshot sends the initial snapshot of services to the connecting proxy.
// Only entries matching the proxy's cluster address are sent.
func (s *ProxyServiceServer) sendSnapshot(ctx context.Context, conn *proxyConnection) error {
if !isProxyAddressValid(conn.address) {
return fmt.Errorf("proxy address is invalid")
@@ -288,7 +362,13 @@ func (s *ProxyServiceServer) sendSnapshot(ctx context.Context, conn *proxyConnec
}
func (s *ProxyServiceServer) snapshotServiceMappings(ctx context.Context, conn *proxyConnection) ([]*proto.ProxyMapping, error) {
services, err := s.serviceManager.GetGlobalServices(ctx)
var services []*rpservice.Service
var err error
if conn.accountID != nil {
services, err = s.serviceManager.GetAccountServices(ctx, *conn.accountID)
} else {
services, err = s.serviceManager.GetGlobalServices(ctx)
}
if err != nil {
return nil, fmt.Errorf("get services from store: %w", err)
}
@@ -317,8 +397,14 @@ func (s *ProxyServiceServer) snapshotServiceMappings(ctx context.Context, conn *
return mappings, nil
}
// isProxyAddressValid validates a proxy address
// isProxyAddressValid validates a proxy address (domain name or IP address)
func isProxyAddressValid(addr string) bool {
if addr == "" {
return false
}
if net.ParseIP(addr) != nil {
return true
}
_, err := domain.ValidateDomains([]string{addr})
return err == nil
}
@@ -342,6 +428,10 @@ func (s *ProxyServiceServer) sender(conn *proxyConnection, errChan chan<- error)
func (s *ProxyServiceServer) SendAccessLog(ctx context.Context, req *proto.SendAccessLogRequest) (*proto.SendAccessLogResponse, error) {
accessLog := req.GetLog()
if err := enforceAccountScope(ctx, accessLog.GetAccountId()); err != nil {
return nil, err
}
fields := log.Fields{
"service_id": accessLog.GetServiceId(),
"account_id": accessLog.GetAccountId(),
@@ -379,11 +469,32 @@ func (s *ProxyServiceServer) SendAccessLog(ctx context.Context, req *proto.SendA
// Management should call this when services are created/updated/removed.
// For create/update operations a unique one-time auth token is generated per
// proxy so that every replica can independently authenticate with management.
// BYOP proxies only receive updates for their own account's services.
func (s *ProxyServiceServer) SendServiceUpdate(update *proto.GetMappingUpdateResponse) {
log.Debugf("Broadcasting service update to all connected proxy servers")
updateAccountIDs := make(map[string]struct{})
for _, m := range update.Mapping {
if m.AccountId != "" {
updateAccountIDs[m.AccountId] = struct{}{}
}
}
s.connectedProxies.Range(func(key, value interface{}) bool {
conn := value.(*proxyConnection)
resp := s.perProxyMessage(update, conn.proxyID)
connUpdate := update
if conn.accountID != nil && len(updateAccountIDs) > 0 {
if _, ok := updateAccountIDs[*conn.accountID]; !ok {
return true
}
filtered := filterMappingsForAccount(update.Mapping, *conn.accountID)
if len(filtered) == 0 {
return true
}
connUpdate = &proto.GetMappingUpdateResponse{
Mapping: filtered,
InitialSyncComplete: update.InitialSyncComplete,
}
}
resp := s.perProxyMessage(connUpdate, conn.proxyID)
if resp == nil {
return true
}
@@ -397,6 +508,26 @@ func (s *ProxyServiceServer) SendServiceUpdate(update *proto.GetMappingUpdateRes
})
}
// ForceDisconnect cancels the gRPC stream for a connected proxy, causing it to disconnect.
func (s *ProxyServiceServer) ForceDisconnect(proxyID string) {
if connVal, ok := s.connectedProxies.Load(proxyID); ok {
conn := connVal.(*proxyConnection)
conn.cancel()
s.connectedProxies.Delete(proxyID)
log.WithFields(log.Fields{"proxyID": proxyID}).Info("force disconnected proxy")
}
}
func filterMappingsForAccount(mappings []*proto.ProxyMapping, accountID string) []*proto.ProxyMapping {
var filtered []*proto.ProxyMapping
for _, m := range mappings {
if m.AccountId == accountID {
filtered = append(filtered, m)
}
}
return filtered
}
// GetConnectedProxies returns a list of connected proxy IDs
func (s *ProxyServiceServer) GetConnectedProxies() []string {
var proxies []string
@@ -465,6 +596,9 @@ func (s *ProxyServiceServer) SendServiceUpdateToCluster(ctx context.Context, upd
continue
}
conn := connVal.(*proxyConnection)
if conn.accountID != nil && update.AccountId != "" && *conn.accountID != update.AccountId {
continue
}
if !proxyAcceptsMapping(conn, update) {
log.WithContext(ctx).Debugf("Skipping proxy %s: does not support custom ports for mapping %s", proxyID, update.Id)
continue
@@ -548,6 +682,10 @@ func shallowCloneMapping(m *proto.ProxyMapping) *proto.ProxyMapping {
}
func (s *ProxyServiceServer) Authenticate(ctx context.Context, req *proto.AuthenticateRequest) (*proto.AuthenticateResponse, error) {
if err := enforceAccountScope(ctx, req.GetAccountId()); err != nil {
return nil, err
}
service, err := s.serviceManager.GetServiceByID(ctx, req.GetAccountId(), req.GetId())
if err != nil {
log.WithContext(ctx).Debugf("failed to get service from store: %v", err)
@@ -667,6 +805,10 @@ func (s *ProxyServiceServer) generateSessionToken(ctx context.Context, authentic
// SendStatusUpdate handles status updates from proxy clients.
func (s *ProxyServiceServer) SendStatusUpdate(ctx context.Context, req *proto.SendStatusUpdateRequest) (*proto.SendStatusUpdateResponse, error) {
if err := enforceAccountScope(ctx, req.GetAccountId()); err != nil {
return nil, err
}
accountID := req.GetAccountId()
serviceID := req.GetServiceId()
protoStatus := req.GetStatus()
@@ -737,6 +879,10 @@ func protoStatusToInternal(protoStatus proto.ProxyStatus) rpservice.Status {
// CreateProxyPeer handles proxy peer creation with one-time token authentication
func (s *ProxyServiceServer) CreateProxyPeer(ctx context.Context, req *proto.CreateProxyPeerRequest) (*proto.CreateProxyPeerResponse, error) {
if err := enforceAccountScope(ctx, req.GetAccountId()); err != nil {
return nil, err
}
serviceID := req.GetServiceId()
accountID := req.GetAccountId()
token := req.GetToken()
@@ -791,6 +937,10 @@ func strPtr(s string) *string {
}
func (s *ProxyServiceServer) GetOIDCURL(ctx context.Context, req *proto.GetOIDCURLRequest) (*proto.GetOIDCURLResponse, error) {
if err := enforceAccountScope(ctx, req.GetAccountId()); err != nil {
return nil, err
}
redirectURL, err := url.Parse(req.GetRedirectUrl())
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "parse redirect url: %v", err)
@@ -919,21 +1069,9 @@ func (s *ProxyServiceServer) ValidateState(state string) (verifier, redirectURL
// GenerateSessionToken creates a signed session JWT for the given domain and user.
func (s *ProxyServiceServer) GenerateSessionToken(ctx context.Context, domain, userID string, method proxyauth.Method) (string, error) {
// Find the service by domain to get its signing key
services, err := s.serviceManager.GetGlobalServices(ctx)
service, err := s.getServiceByDomain(ctx, domain)
if err != nil {
return "", fmt.Errorf("get services: %w", err)
}
var service *rpservice.Service
for _, svc := range services {
if svc.Domain == domain {
service = svc
break
}
}
if service == nil {
return "", fmt.Errorf("service not found for domain: %s", domain)
return "", fmt.Errorf("service not found for domain %s: %w", domain, err)
}
if service.SessionPrivateKey == "" {
@@ -1031,6 +1169,10 @@ func (s *ProxyServiceServer) ValidateSession(ctx context.Context, req *proto.Val
}, nil
}
if err := enforceAccountScope(ctx, service.AccountID); err != nil {
return nil, err
}
pubKeyBytes, err := base64.StdEncoding.DecodeString(service.SessionPublicKey)
if err != nil {
log.WithFields(log.Fields{
@@ -1114,18 +1256,7 @@ func (s *ProxyServiceServer) ValidateSession(ctx context.Context, req *proto.Val
}
func (s *ProxyServiceServer) getServiceByDomain(ctx context.Context, domain string) (*rpservice.Service, error) {
services, err := s.serviceManager.GetGlobalServices(ctx)
if err != nil {
return nil, fmt.Errorf("get services: %w", err)
}
for _, service := range services {
if service.Domain == domain {
return service, nil
}
}
return nil, fmt.Errorf("service not found for domain: %s", domain)
return s.serviceManager.GetServiceByDomain(ctx, domain)
}
func (s *ProxyServiceServer) checkGroupAccess(service *rpservice.Service, user *types.User) error {

View File

@@ -0,0 +1,29 @@
package grpc
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestIsProxyAddressValid(t *testing.T) {
tests := []struct {
name string
addr string
valid bool
}{
{name: "valid domain", addr: "eu.proxy.netbird.io", valid: true},
{name: "valid subdomain", addr: "byop.proxy.example.com", valid: true},
{name: "valid IPv4", addr: "10.0.0.1", valid: true},
{name: "valid IPv4 public", addr: "203.0.113.10", valid: true},
{name: "valid IPv6", addr: "::1", valid: true},
{name: "valid IPv6 full", addr: "2001:db8::1", valid: true},
{name: "empty string", addr: "", valid: false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
assert.Equal(t, tt.valid, isProxyAddressValid(tt.addr))
})
}
}

View File

@@ -153,9 +153,6 @@ func (i *proxyAuthInterceptor) doValidateProxyToken(ctx context.Context) (*types
return nil, status.Errorf(codes.Unauthenticated, "invalid token")
}
// TODO: Enforce AccountID scope for "bring your own proxy" feature.
// Currently tokens are management-wide; AccountID field is reserved for future use.
if !token.IsValid() {
return nil, status.Errorf(codes.Unauthenticated, "token expired or revoked")
}

View File

@@ -91,6 +91,20 @@ func (m *mockReverseProxyManager) StopServiceFromPeer(_ context.Context, _, _, _
func (m *mockReverseProxyManager) StartExposeReaper(_ context.Context) {}
func (m *mockReverseProxyManager) GetServiceByDomain(_ context.Context, domain string) (*service.Service, error) {
if m.err != nil {
return nil, m.err
}
for _, services := range m.proxiesByAccount {
for _, svc := range services {
if svc.Domain == domain {
return svc, nil
}
}
}
return nil, errors.New("service not found for domain: " + domain)
}
func (m *mockReverseProxyManager) GetActiveClusters(_ context.Context, _, _ string) ([]proxy.Cluster, error) {
return nil, nil
}

View File

@@ -11,8 +11,11 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/grpc/codes"
grpcstatus "google.golang.org/grpc/status"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/management/proto"
)
@@ -315,6 +318,58 @@ func TestValidateState_RejectsOldTwoPartFormat(t *testing.T) {
assert.Contains(t, err.Error(), "invalid state format")
}
func scopedCtx(accountID string) context.Context {
token := &types.ProxyAccessToken{
ID: "token-1",
AccountID: &accountID,
}
return context.WithValue(context.Background(), ProxyTokenContextKey, token)
}
func globalCtx() context.Context {
token := &types.ProxyAccessToken{
ID: "token-global",
}
return context.WithValue(context.Background(), ProxyTokenContextKey, token)
}
func TestEnforceAccountScope_AllowsMatchingAccount(t *testing.T) {
err := enforceAccountScope(scopedCtx("acc-1"), "acc-1")
assert.NoError(t, err)
}
func TestEnforceAccountScope_BlocksMismatchedAccount(t *testing.T) {
err := enforceAccountScope(scopedCtx("acc-1"), "acc-2")
require.Error(t, err)
st, ok := grpcstatus.FromError(err)
require.True(t, ok)
assert.Equal(t, codes.PermissionDenied, st.Code())
}
func TestEnforceAccountScope_BlocksEmptyRequestAccountID(t *testing.T) {
err := enforceAccountScope(scopedCtx("acc-1"), "")
require.Error(t, err)
st, ok := grpcstatus.FromError(err)
require.True(t, ok)
assert.Equal(t, codes.PermissionDenied, st.Code())
}
func TestEnforceAccountScope_AllowsGlobalToken(t *testing.T) {
err := enforceAccountScope(globalCtx(), "acc-1")
assert.NoError(t, err)
err = enforceAccountScope(globalCtx(), "acc-2")
assert.NoError(t, err)
err = enforceAccountScope(globalCtx(), "")
assert.NoError(t, err)
}
func TestEnforceAccountScope_AllowsNoTokenInContext(t *testing.T) {
err := enforceAccountScope(context.Background(), "acc-1")
assert.NoError(t, err)
}
func TestValidateState_RejectsInvalidHMAC(t *testing.T) {
ctx := context.Background()
pkceStore, err := NewPKCEVerifierStore(ctx, 10*time.Minute, 10*time.Minute, 100)

View File

@@ -45,7 +45,7 @@ func setupValidateSessionTest(t *testing.T) *validateSessionTestSetup {
pkceStore, err := NewPKCEVerifierStore(ctx, 10*time.Minute, 10*time.Minute, 100)
require.NoError(t, err)
proxyService := NewProxyServiceServer(nil, tokenStore, pkceStore, ProxyOIDCConfig{}, nil, usersManager, proxyManager)
proxyService := NewProxyServiceServer(nil, tokenStore, pkceStore, ProxyOIDCConfig{}, nil, usersManager, proxyManager, nil)
proxyService.SetServiceManager(serviceManager)
createTestProxies(t, ctx, testStore)
@@ -321,13 +321,17 @@ func (m *testValidateSessionServiceManager) StopServiceFromPeer(_ context.Contex
func (m *testValidateSessionServiceManager) StartExposeReaper(_ context.Context) {}
func (m *testValidateSessionServiceManager) GetServiceByDomain(ctx context.Context, domain string) (*service.Service, error) {
return m.store.GetServiceByDomain(ctx, domain)
}
func (m *testValidateSessionServiceManager) GetActiveClusters(_ context.Context, _, _ string) ([]proxy.Cluster, error) {
return nil, nil
}
type testValidateSessionProxyManager struct{}
func (m *testValidateSessionProxyManager) Connect(_ context.Context, _, _, _ string) error {
func (m *testValidateSessionProxyManager) Connect(_ context.Context, _, _, _ string, _ *string) error {
return nil
}
@@ -335,7 +339,7 @@ func (m *testValidateSessionProxyManager) Disconnect(_ context.Context, _ string
return nil
}
func (m *testValidateSessionProxyManager) Heartbeat(_ context.Context, _ string) error {
func (m *testValidateSessionProxyManager) Heartbeat(_ context.Context, _, _, _ string) error {
return nil
}
@@ -343,6 +347,10 @@ func (m *testValidateSessionProxyManager) GetActiveClusterAddresses(_ context.Co
return nil, nil
}
func (m *testValidateSessionProxyManager) GetActiveClusterAddressesForAccount(_ context.Context, _ string) ([]string, error) {
return nil, nil
}
func (m *testValidateSessionProxyManager) GetActiveClusters(_ context.Context) ([]proxy.Cluster, error) {
return nil, nil
}
@@ -351,6 +359,22 @@ func (m *testValidateSessionProxyManager) CleanupStale(_ context.Context, _ time
return nil
}
func (m *testValidateSessionProxyManager) GetAccountProxy(_ context.Context, _ string) (*proxy.Proxy, error) {
return nil, nil
}
func (m *testValidateSessionProxyManager) CountAccountProxies(_ context.Context, _ string) (int64, error) {
return 0, nil
}
func (m *testValidateSessionProxyManager) IsClusterAddressAvailable(_ context.Context, _, _ string) (bool, error) {
return true, nil
}
func (m *testValidateSessionProxyManager) DeleteProxy(_ context.Context, _ string) error {
return nil
}
type testValidateSessionUsersManager struct {
store store.Store
}

View File

@@ -3142,7 +3142,7 @@ func createManager(t testing.TB) (*DefaultAccountManager, *update_channel.PeersU
return nil, nil, err
}
proxyGrpcServer := nbgrpc.NewProxyServiceServer(nil, nil, nil, nbgrpc.ProxyOIDCConfig{}, peersManager, nil, proxyManager)
proxyGrpcServer := nbgrpc.NewProxyServiceServer(nil, nil, nil, nbgrpc.ProxyOIDCConfig{}, peersManager, nil, proxyManager, nil)
proxyController, err := proxymanager.NewGRPCController(proxyGrpcServer, noop.Meter{})
if err != nil {
return nil, nil, err

View File

@@ -19,6 +19,7 @@ import (
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxytoken"
reverseproxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service/manager"
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
@@ -176,6 +177,9 @@ func NewAPIHandler(ctx context.Context, accountManager account.Manager, networks
if serviceManager != nil && reverseProxyDomainManager != nil {
reverseproxymanager.RegisterEndpoints(serviceManager, *reverseProxyDomainManager, reverseProxyAccessLogsManager, permissionsManager, router)
}
proxytoken.RegisterEndpoints(accountManager.GetStore(), permissionsManager, router)
// Register OAuth callback handler for proxy authentication
if proxyGRPCServer != nil {
oauthHandler := proxy.NewAuthCallbackHandler(proxyGRPCServer, trustedHTTPProxies)

View File

@@ -215,6 +215,7 @@ func setupAuthCallbackTest(t *testing.T) *testSetup {
nil,
usersManager,
nil,
nil,
)
proxyService.SetServiceManager(&testServiceManager{store: testStore})
@@ -434,6 +435,10 @@ func (m *testServiceManager) StopServiceFromPeer(_ context.Context, _, _, _ stri
func (m *testServiceManager) StartExposeReaper(_ context.Context) {}
func (m *testServiceManager) GetServiceByDomain(ctx context.Context, domain string) (*service.Service, error) {
return m.store.GetServiceByDomain(ctx, domain)
}
func (m *testServiceManager) GetActiveClusters(_ context.Context, _, _ string) ([]nbproxy.Cluster, error) {
return nil, nil
}

View File

@@ -108,7 +108,7 @@ func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee
if err != nil {
t.Fatalf("Failed to create proxy manager: %v", err)
}
proxyServiceServer := nbgrpc.NewProxyServiceServer(accessLogsManager, proxyTokenStore, pkceverifierStore, nbgrpc.ProxyOIDCConfig{}, peersManager, userManager, proxyMgr)
proxyServiceServer := nbgrpc.NewProxyServiceServer(accessLogsManager, proxyTokenStore, pkceverifierStore, nbgrpc.ProxyOIDCConfig{}, peersManager, userManager, proxyMgr, nil)
domainManager := manager.NewManager(store, proxyMgr, permissionsManager, am)
serviceProxyController, err := proxymanager.NewGRPCController(proxyServiceServer, noopMeter)
if err != nil {
@@ -237,7 +237,7 @@ func BuildApiBlackBoxWithDBStateAndPeerChannel(t testing_tools.TB, sqlFile strin
if err != nil {
t.Fatalf("Failed to create proxy manager: %v", err)
}
proxyServiceServer := nbgrpc.NewProxyServiceServer(accessLogsManager, proxyTokenStore, pkceverifierStore, nbgrpc.ProxyOIDCConfig{}, peersManager, userManager, proxyMgr)
proxyServiceServer := nbgrpc.NewProxyServiceServer(accessLogsManager, proxyTokenStore, pkceverifierStore, nbgrpc.ProxyOIDCConfig{}, peersManager, userManager, proxyMgr, nil)
domainManager := manager.NewManager(store, proxyMgr, permissionsManager, am)
serviceProxyController, err := proxymanager.NewGRPCController(proxyServiceServer, noopMeter)
if err != nil {

View File

@@ -17,6 +17,7 @@ import (
"time"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgconn"
"github.com/jackc/pgx/v5/pgxpool"
"github.com/rs/xid"
log "github.com/sirupsen/logrus"
@@ -4497,6 +4498,47 @@ func (s *SqlStore) RevokeProxyAccessToken(ctx context.Context, tokenID string) e
return nil
}
func (s *SqlStore) GetProxyAccessTokensByAccountID(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.ProxyAccessToken, error) {
tx := s.db.WithContext(ctx)
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var tokens []*types.ProxyAccessToken
result := tx.Where("account_id = ?", accountID).Find(&tokens)
if result.Error != nil {
return nil, status.Errorf(status.Internal, "get proxy access tokens by account: %v", result.Error)
}
return tokens, nil
}
func (s *SqlStore) IsProxyAccessTokenValid(ctx context.Context, tokenID string) (bool, error) {
token, err := s.GetProxyAccessTokenByID(ctx, LockingStrengthNone, tokenID)
if err != nil {
return false, err
}
return token.IsValid(), nil
}
func (s *SqlStore) GetProxyAccessTokenByID(ctx context.Context, lockStrength LockingStrength, tokenID string) (*types.ProxyAccessToken, error) {
tx := s.db.WithContext(ctx)
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var token types.ProxyAccessToken
result := tx.Take(&token, idQueryCondition, tokenID)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "proxy access token not found")
}
return nil, status.Errorf(status.Internal, "get proxy access token by ID: %v", result.Error)
}
return &token, nil
}
// MarkProxyAccessTokenUsed updates the last used timestamp for a proxy access token.
func (s *SqlStore) MarkProxyAccessTokenUsed(ctx context.Context, tokenID string) error {
result := s.db.WithContext(ctx).Model(&types.ProxyAccessToken{}).
@@ -5434,18 +5476,49 @@ func (s *SqlStore) SaveProxy(ctx context.Context, p *proxy.Proxy) error {
result := s.db.WithContext(ctx).Save(p)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to save proxy: %v", result.Error)
if isUniqueConstraintError(result.Error) {
return proxy.ErrAccountProxyAlreadyExists
}
return status.Errorf(status.Internal, "failed to save proxy")
}
return nil
}
// UpdateProxyHeartbeat updates the last_seen timestamp for a proxy or creates a new entry if it doesn't exist
func isUniqueConstraintError(err error) bool {
var pgErr *pgconn.PgError
if errors.As(err, &pgErr) && pgErr.Code == "23505" {
return true
}
errStr := err.Error()
return strings.Contains(errStr, "UNIQUE constraint") ||
strings.Contains(errStr, "duplicate key") ||
strings.Contains(errStr, "Duplicate entry") ||
strings.Contains(errStr, "Error 1062")
}
func (s *SqlStore) DisconnectProxy(ctx context.Context, proxyID string) error {
now := time.Now()
result := s.db.WithContext(ctx).
Model(&proxy.Proxy{}).
Where("id = ?", proxyID).
Updates(map[string]interface{}{
"status": proxy.StatusDisconnected,
"disconnected_at": now,
"last_seen": now,
})
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to disconnect proxy: %v", result.Error)
return status.Errorf(status.Internal, "failed to disconnect proxy")
}
return nil
}
func (s *SqlStore) UpdateProxyHeartbeat(ctx context.Context, proxyID, clusterAddress, ipAddress string) error {
now := time.Now()
result := s.db.WithContext(ctx).
Model(&proxy.Proxy{}).
Where("id = ? AND status = ?", proxyID, "connected").
Where("id = ? AND status = ?", proxyID, proxy.StatusConnected).
Update("last_seen", now)
if result.Error != nil {
@@ -5477,7 +5550,7 @@ func (s *SqlStore) GetActiveProxyClusterAddresses(ctx context.Context) ([]string
result := s.db.WithContext(ctx).
Model(&proxy.Proxy{}).
Where("status = ? AND last_seen > ?", "connected", time.Now().Add(-proxyActiveThreshold)).
Where("status = ? AND last_seen > ?", proxy.StatusConnected, time.Now().Add(-proxyActiveThreshold)).
Distinct("cluster_address").
Pluck("cluster_address", &addresses)
@@ -5489,13 +5562,72 @@ func (s *SqlStore) GetActiveProxyClusterAddresses(ctx context.Context) ([]string
return addresses, nil
}
// GetActiveProxyClusters returns all active proxy clusters with their connected proxy count.
func (s *SqlStore) GetActiveProxyClusterAddressesForAccount(ctx context.Context, accountID string) ([]string, error) {
var addresses []string
result := s.db.WithContext(ctx).
Model(&proxy.Proxy{}).
Where("account_id = ? AND status = ? AND last_seen > ?", accountID, proxy.StatusConnected, time.Now().Add(-2*time.Minute)).
Distinct("cluster_address").
Pluck("cluster_address", &addresses)
if result.Error != nil {
return nil, status.Errorf(status.Internal, "failed to get active proxy cluster addresses for account")
}
return addresses, nil
}
func (s *SqlStore) GetProxyByAccountID(ctx context.Context, accountID string) (*proxy.Proxy, error) {
var p proxy.Proxy
result := s.db.WithContext(ctx).Where("account_id = ?", accountID).Take(&p)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "proxy not found for account")
}
return nil, status.Errorf(status.Internal, "get proxy by account ID: %v", result.Error)
}
return &p, nil
}
func (s *SqlStore) CountProxiesByAccountID(ctx context.Context, accountID string) (int64, error) {
var count int64
result := s.db.WithContext(ctx).Model(&proxy.Proxy{}).Where("account_id = ?", accountID).Count(&count)
if result.Error != nil {
return 0, status.Errorf(status.Internal, "count proxies by account ID: %v", result.Error)
}
return count, nil
}
func (s *SqlStore) IsClusterAddressConflicting(ctx context.Context, clusterAddress, accountID string) (bool, error) {
var count int64
result := s.db.WithContext(ctx).
Model(&proxy.Proxy{}).
Where("cluster_address = ? AND (account_id IS NULL OR account_id != ?)", clusterAddress, accountID).
Count(&count)
if result.Error != nil {
return false, status.Errorf(status.Internal, "check cluster address conflict: %v", result.Error)
}
return count > 0, nil
}
func (s *SqlStore) DeleteProxy(ctx context.Context, proxyID string) error {
result := s.db.WithContext(ctx).Where(idQueryCondition, proxyID).Delete(&proxy.Proxy{})
if result.Error != nil {
return status.Errorf(status.Internal, "delete proxy: %v", result.Error)
}
if result.RowsAffected == 0 {
return status.Errorf(status.NotFound, "proxy not found")
}
return nil
}
func (s *SqlStore) GetActiveProxyClusters(ctx context.Context) ([]proxy.Cluster, error) {
var clusters []proxy.Cluster
result := s.db.Model(&proxy.Proxy{}).
Select("cluster_address as address, COUNT(*) as connected_proxies").
Where("status = ? AND last_seen > ?", "connected", time.Now().Add(-proxyActiveThreshold)).
Select("MIN(id) as id, cluster_address as address, COUNT(*) as connected_proxies").
Where("status = ? AND last_seen > ?", proxy.StatusConnected, time.Now().Add(-proxyActiveThreshold)).
Group("cluster_address").
Scan(&clusters)

View File

@@ -114,6 +114,9 @@ type Store interface {
GetProxyAccessTokenByHashedToken(ctx context.Context, lockStrength LockingStrength, hashedToken types.HashedProxyToken) (*types.ProxyAccessToken, error)
GetAllProxyAccessTokens(ctx context.Context, lockStrength LockingStrength) ([]*types.ProxyAccessToken, error)
GetProxyAccessTokensByAccountID(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.ProxyAccessToken, error)
GetProxyAccessTokenByID(ctx context.Context, lockStrength LockingStrength, tokenID string) (*types.ProxyAccessToken, error)
IsProxyAccessTokenValid(ctx context.Context, tokenID string) (bool, error)
SaveProxyAccessToken(ctx context.Context, token *types.ProxyAccessToken) error
RevokeProxyAccessToken(ctx context.Context, tokenID string) error
MarkProxyAccessTokenUsed(ctx context.Context, tokenID string) error
@@ -284,12 +287,18 @@ type Store interface {
DeleteServiceTargets(ctx context.Context, accountID string, serviceID string) error
SaveProxy(ctx context.Context, proxy *proxy.Proxy) error
DisconnectProxy(ctx context.Context, proxyID string) error
UpdateProxyHeartbeat(ctx context.Context, proxyID, clusterAddress, ipAddress string) error
GetActiveProxyClusterAddresses(ctx context.Context) ([]string, error)
GetActiveProxyClusterAddressesForAccount(ctx context.Context, accountID string) ([]string, error)
GetActiveProxyClusters(ctx context.Context) ([]proxy.Cluster, error)
GetClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool
GetClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool
CleanupStaleProxies(ctx context.Context, inactivityDuration time.Duration) error
GetProxyByAccountID(ctx context.Context, accountID string) (*proxy.Proxy, error)
CountProxiesByAccountID(ctx context.Context, accountID string) (int64, error)
IsClusterAddressConflicting(ctx context.Context, clusterAddress, accountID string) (bool, error)
DeleteProxy(ctx context.Context, proxyID string) error
GetCustomDomainsCounts(ctx context.Context) (total int64, validated int64, err error)

View File

@@ -165,6 +165,21 @@ func (mr *MockStoreMockRecorder) CleanupStaleProxies(ctx, inactivityDuration int
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CleanupStaleProxies", reflect.TypeOf((*MockStore)(nil).CleanupStaleProxies), ctx, inactivityDuration)
}
// CountProxiesByAccountID mocks base method.
func (m *MockStore) CountProxiesByAccountID(ctx context.Context, accountID string) (int64, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "CountProxiesByAccountID", ctx, accountID)
ret0, _ := ret[0].(int64)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// CountProxiesByAccountID indicates an expected call of CountProxiesByAccountID.
func (mr *MockStoreMockRecorder) CountProxiesByAccountID(ctx, accountID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CountProxiesByAccountID", reflect.TypeOf((*MockStore)(nil).CountProxiesByAccountID), ctx, accountID)
}
// Close mocks base method.
func (m *MockStore) Close(ctx context.Context) error {
m.ctrl.T.Helper()
@@ -1287,7 +1302,19 @@ func (mr *MockStoreMockRecorder) GetActiveProxyClusterAddresses(ctx interface{})
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetActiveProxyClusterAddresses", reflect.TypeOf((*MockStore)(nil).GetActiveProxyClusterAddresses), ctx)
}
// GetActiveProxyClusters mocks base method.
func (m *MockStore) GetActiveProxyClusterAddressesForAccount(ctx context.Context, accountID string) ([]string, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetActiveProxyClusterAddressesForAccount", ctx, accountID)
ret0, _ := ret[0].([]string)
ret1, _ := ret[1].(error)
return ret0, ret1
}
func (mr *MockStoreMockRecorder) GetActiveProxyClusterAddressesForAccount(ctx, accountID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetActiveProxyClusterAddressesForAccount", reflect.TypeOf((*MockStore)(nil).GetActiveProxyClusterAddressesForAccount), ctx, accountID)
}
func (m *MockStore) GetActiveProxyClusters(ctx context.Context) ([]proxy.Cluster, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetActiveProxyClusters", ctx)
@@ -1296,7 +1323,6 @@ func (m *MockStore) GetActiveProxyClusters(ctx context.Context) ([]proxy.Cluster
return ret0, ret1
}
// GetActiveProxyClusters indicates an expected call of GetActiveProxyClusters.
func (mr *MockStoreMockRecorder) GetActiveProxyClusters(ctx interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetActiveProxyClusters", reflect.TypeOf((*MockStore)(nil).GetActiveProxyClusters), ctx)
@@ -1346,6 +1372,51 @@ func (mr *MockStoreMockRecorder) GetAllProxyAccessTokens(ctx, lockStrength inter
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAllProxyAccessTokens", reflect.TypeOf((*MockStore)(nil).GetAllProxyAccessTokens), ctx, lockStrength)
}
// GetProxyAccessTokensByAccountID mocks base method.
func (m *MockStore) GetProxyAccessTokensByAccountID(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types2.ProxyAccessToken, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetProxyAccessTokensByAccountID", ctx, lockStrength, accountID)
ret0, _ := ret[0].([]*types2.ProxyAccessToken)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetProxyAccessTokensByAccountID indicates an expected call of GetProxyAccessTokensByAccountID.
func (mr *MockStoreMockRecorder) GetProxyAccessTokensByAccountID(ctx, lockStrength, accountID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetProxyAccessTokensByAccountID", reflect.TypeOf((*MockStore)(nil).GetProxyAccessTokensByAccountID), ctx, lockStrength, accountID)
}
// GetProxyAccessTokenByID mocks base method.
func (m *MockStore) GetProxyAccessTokenByID(ctx context.Context, lockStrength LockingStrength, tokenID string) (*types2.ProxyAccessToken, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetProxyAccessTokenByID", ctx, lockStrength, tokenID)
ret0, _ := ret[0].(*types2.ProxyAccessToken)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetProxyAccessTokenByID indicates an expected call of GetProxyAccessTokenByID.
func (mr *MockStoreMockRecorder) GetProxyAccessTokenByID(ctx, lockStrength, tokenID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetProxyAccessTokenByID", reflect.TypeOf((*MockStore)(nil).GetProxyAccessTokenByID), ctx, lockStrength, tokenID)
}
// IsProxyAccessTokenValid mocks base method.
func (m *MockStore) IsProxyAccessTokenValid(ctx context.Context, tokenID string) (bool, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "IsProxyAccessTokenValid", ctx, tokenID)
ret0, _ := ret[0].(bool)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// IsProxyAccessTokenValid indicates an expected call of IsProxyAccessTokenValid.
func (mr *MockStoreMockRecorder) IsProxyAccessTokenValid(ctx, tokenID any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsProxyAccessTokenValid", reflect.TypeOf((*MockStore)(nil).IsProxyAccessTokenValid), ctx, tokenID)
}
// GetAnyAccountID mocks base method.
func (m *MockStore) GetAnyAccountID(ctx context.Context) (string, error) {
m.ctrl.T.Helper()
@@ -1944,6 +2015,50 @@ func (mr *MockStoreMockRecorder) GetProxyAccessTokenByHashedToken(ctx, lockStren
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetProxyAccessTokenByHashedToken", reflect.TypeOf((*MockStore)(nil).GetProxyAccessTokenByHashedToken), ctx, lockStrength, hashedToken)
}
// GetProxyByAccountID mocks base method.
func (m *MockStore) GetProxyByAccountID(ctx context.Context, accountID string) (*proxy.Proxy, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetProxyByAccountID", ctx, accountID)
ret0, _ := ret[0].(*proxy.Proxy)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetProxyByAccountID indicates an expected call of GetProxyByAccountID.
func (mr *MockStoreMockRecorder) GetProxyByAccountID(ctx, accountID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetProxyByAccountID", reflect.TypeOf((*MockStore)(nil).GetProxyByAccountID), ctx, accountID)
}
// IsClusterAddressConflicting mocks base method.
func (m *MockStore) IsClusterAddressConflicting(ctx context.Context, clusterAddress, accountID string) (bool, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "IsClusterAddressConflicting", ctx, clusterAddress, accountID)
ret0, _ := ret[0].(bool)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// IsClusterAddressConflicting indicates an expected call of IsClusterAddressConflicting.
func (mr *MockStoreMockRecorder) IsClusterAddressConflicting(ctx, clusterAddress, accountID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsClusterAddressConflicting", reflect.TypeOf((*MockStore)(nil).IsClusterAddressConflicting), ctx, clusterAddress, accountID)
}
// DeleteProxy mocks base method.
func (m *MockStore) DeleteProxy(ctx context.Context, proxyID string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DeleteProxy", ctx, proxyID)
ret0, _ := ret[0].(error)
return ret0
}
// DeleteProxy indicates an expected call of DeleteProxy.
func (mr *MockStoreMockRecorder) DeleteProxy(ctx, proxyID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteProxy", reflect.TypeOf((*MockStore)(nil).DeleteProxy), ctx, proxyID)
}
// GetResourceGroups mocks base method.
func (m *MockStore) GetResourceGroups(ctx context.Context, lockStrength LockingStrength, accountID, resourceID string) ([]*types2.Group, error) {
m.ctrl.T.Helper()
@@ -2786,6 +2901,20 @@ func (mr *MockStoreMockRecorder) SaveProxy(ctx, proxy interface{}) *gomock.Call
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SaveProxy", reflect.TypeOf((*MockStore)(nil).SaveProxy), ctx, proxy)
}
// DisconnectProxy mocks base method.
func (m *MockStore) DisconnectProxy(ctx context.Context, proxyID string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DisconnectProxy", ctx, proxyID)
ret0, _ := ret[0].(error)
return ret0
}
// DisconnectProxy indicates an expected call of DisconnectProxy.
func (mr *MockStoreMockRecorder) DisconnectProxy(ctx, proxyID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DisconnectProxy", reflect.TypeOf((*MockStore)(nil).DisconnectProxy), ctx, proxyID)
}
// SaveProxyAccessToken mocks base method.
func (m *MockStore) SaveProxyAccessToken(ctx context.Context, token *types2.ProxyAccessToken) error {
m.ctrl.T.Helper()

View File

@@ -0,0 +1,408 @@
package proxy
import (
"context"
"crypto/ed25519"
"crypto/rand"
"encoding/base64"
"net"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.opentelemetry.io/otel/metric/noop"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/metadata"
grpcstatus "google.golang.org/grpc/status"
proxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy/manager"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/management/server/users"
"github.com/netbirdio/netbird/shared/management/proto"
)
type byopTestSetup struct {
store store.Store
proxyService *nbgrpc.ProxyServiceServer
grpcServer *grpc.Server
grpcAddr string
cleanup func()
accountA string
accountB string
accountAToken types.PlainProxyToken
accountBToken types.PlainProxyToken
accountACluster string
accountBCluster string
}
func setupBYOPIntegrationTest(t *testing.T) *byopTestSetup {
t.Helper()
ctx := context.Background()
testStore, storeCleanup, err := store.NewTestStoreFromSQL(ctx, "", t.TempDir())
require.NoError(t, err)
accountAID := "byop-account-a"
accountBID := "byop-account-b"
for _, acc := range []*types.Account{
{Id: accountAID, Domain: "a.test.com", DomainCategory: "private", IsDomainPrimaryAccount: true, CreatedAt: time.Now()},
{Id: accountBID, Domain: "b.test.com", DomainCategory: "private", IsDomainPrimaryAccount: true, CreatedAt: time.Now()},
} {
require.NoError(t, testStore.SaveAccount(ctx, acc))
}
pub, priv, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
pubKey := base64.StdEncoding.EncodeToString(pub)
privKey := base64.StdEncoding.EncodeToString(priv)
clusterA := "byop-a.proxy.test"
clusterB := "byop-b.proxy.test"
services := []*service.Service{
{
ID: "svc-a1", AccountID: accountAID, Name: "App A1",
Domain: "app1." + clusterA, ProxyCluster: clusterA, Enabled: true,
SessionPrivateKey: privKey, SessionPublicKey: pubKey,
Targets: []*service.Target{{Path: strPtr("/"), Host: "10.0.0.1", Port: 8080, Protocol: "http", TargetId: "peer-a1", TargetType: "peer", Enabled: true}},
},
{
ID: "svc-a2", AccountID: accountAID, Name: "App A2",
Domain: "app2." + clusterA, ProxyCluster: clusterA, Enabled: true,
SessionPrivateKey: privKey, SessionPublicKey: pubKey,
Targets: []*service.Target{{Path: strPtr("/"), Host: "10.0.0.2", Port: 8080, Protocol: "http", TargetId: "peer-a2", TargetType: "peer", Enabled: true}},
},
{
ID: "svc-b1", AccountID: accountBID, Name: "App B1",
Domain: "app1." + clusterB, ProxyCluster: clusterB, Enabled: true,
SessionPrivateKey: privKey, SessionPublicKey: pubKey,
Targets: []*service.Target{{Path: strPtr("/"), Host: "10.0.0.3", Port: 8080, Protocol: "http", TargetId: "peer-b1", TargetType: "peer", Enabled: true}},
},
}
for _, svc := range services {
require.NoError(t, testStore.CreateService(ctx, svc))
}
tokenA, err := types.CreateNewProxyAccessToken("byop-token-a", 0, &accountAID, "admin-a")
require.NoError(t, err)
require.NoError(t, testStore.SaveProxyAccessToken(ctx, &tokenA.ProxyAccessToken))
tokenB, err := types.CreateNewProxyAccessToken("byop-token-b", 0, &accountBID, "admin-b")
require.NoError(t, err)
require.NoError(t, testStore.SaveProxyAccessToken(ctx, &tokenB.ProxyAccessToken))
tokenStore, err := nbgrpc.NewOneTimeTokenStore(ctx, 5*time.Minute, 10*time.Minute, 100)
require.NoError(t, err)
pkceStore, err := nbgrpc.NewPKCEVerifierStore(ctx, 10*time.Minute, 10*time.Minute, 100)
require.NoError(t, err)
meter := noop.NewMeterProvider().Meter("test")
realProxyManager, err := proxymanager.NewManager(testStore, meter)
require.NoError(t, err)
oidcConfig := nbgrpc.ProxyOIDCConfig{
Issuer: "https://fake-issuer.example.com",
ClientID: "test-client",
HMACKey: []byte("test-hmac-key"),
}
usersManager := users.NewManager(testStore)
proxyService := nbgrpc.NewProxyServiceServer(
&testAccessLogManager{},
tokenStore,
pkceStore,
oidcConfig,
nil,
usersManager,
realProxyManager,
nil,
)
svcMgr := &storeBackedServiceManager{store: testStore, tokenStore: tokenStore}
proxyService.SetServiceManager(svcMgr)
proxyController := &testProxyController{}
proxyService.SetProxyController(proxyController)
_, streamInterceptor, authClose := nbgrpc.NewProxyAuthInterceptors(testStore)
lis, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
grpcServer := grpc.NewServer(grpc.StreamInterceptor(streamInterceptor))
proto.RegisterProxyServiceServer(grpcServer, proxyService)
go func() {
if err := grpcServer.Serve(lis); err != nil {
t.Logf("gRPC server error: %v", err)
}
}()
return &byopTestSetup{
store: testStore,
proxyService: proxyService,
grpcServer: grpcServer,
grpcAddr: lis.Addr().String(),
cleanup: func() {
grpcServer.GracefulStop()
authClose()
storeCleanup()
},
accountA: accountAID,
accountB: accountBID,
accountAToken: tokenA.PlainToken,
accountBToken: tokenB.PlainToken,
accountACluster: clusterA,
accountBCluster: clusterB,
}
}
func byopContext(ctx context.Context, token types.PlainProxyToken) context.Context {
md := metadata.Pairs("authorization", "Bearer "+string(token))
return metadata.NewOutgoingContext(ctx, md)
}
func receiveBYOPMappings(t *testing.T, stream proto.ProxyService_GetMappingUpdateClient) []*proto.ProxyMapping {
t.Helper()
var mappings []*proto.ProxyMapping
for {
msg, err := stream.Recv()
require.NoError(t, err)
mappings = append(mappings, msg.GetMapping()...)
if msg.GetInitialSyncComplete() {
break
}
}
return mappings
}
func TestIntegration_BYOPProxy_ReceivesOnlyAccountServices(t *testing.T) {
setup := setupBYOPIntegrationTest(t)
defer setup.cleanup()
conn, err := grpc.NewClient(setup.grpcAddr, grpc.WithTransportCredentials(insecure.NewCredentials()))
require.NoError(t, err)
defer conn.Close()
client := proto.NewProxyServiceClient(conn)
ctx, cancel := context.WithTimeout(byopContext(context.Background(), setup.accountAToken), 5*time.Second)
defer cancel()
stream, err := client.GetMappingUpdate(ctx, &proto.GetMappingUpdateRequest{
ProxyId: "byop-proxy-a",
Version: "test-v1",
Address: setup.accountACluster,
})
require.NoError(t, err)
mappings := receiveBYOPMappings(t, stream)
assert.Len(t, mappings, 2, "BYOP proxy should receive only account A's 2 services")
for _, m := range mappings {
assert.Equal(t, setup.accountA, m.GetAccountId(), "all mappings should belong to account A")
t.Logf("received mapping: id=%s domain=%s account=%s", m.GetId(), m.GetDomain(), m.GetAccountId())
}
ids := map[string]bool{}
for _, m := range mappings {
ids[m.GetId()] = true
}
assert.True(t, ids["svc-a1"], "should contain svc-a1")
assert.True(t, ids["svc-a2"], "should contain svc-a2")
assert.False(t, ids["svc-b1"], "should NOT contain account B's svc-b1")
}
func TestIntegration_BYOPProxy_AccountBReceivesOnlyItsServices(t *testing.T) {
setup := setupBYOPIntegrationTest(t)
defer setup.cleanup()
conn, err := grpc.NewClient(setup.grpcAddr, grpc.WithTransportCredentials(insecure.NewCredentials()))
require.NoError(t, err)
defer conn.Close()
client := proto.NewProxyServiceClient(conn)
ctx, cancel := context.WithTimeout(byopContext(context.Background(), setup.accountBToken), 5*time.Second)
defer cancel()
stream, err := client.GetMappingUpdate(ctx, &proto.GetMappingUpdateRequest{
ProxyId: "byop-proxy-b",
Version: "test-v1",
Address: setup.accountBCluster,
})
require.NoError(t, err)
mappings := receiveBYOPMappings(t, stream)
assert.Len(t, mappings, 1, "BYOP proxy B should receive only 1 service")
assert.Equal(t, "svc-b1", mappings[0].GetId())
assert.Equal(t, setup.accountB, mappings[0].GetAccountId())
}
func TestIntegration_BYOPProxy_LimitOnePerAccount(t *testing.T) {
setup := setupBYOPIntegrationTest(t)
defer setup.cleanup()
conn, err := grpc.NewClient(setup.grpcAddr, grpc.WithTransportCredentials(insecure.NewCredentials()))
require.NoError(t, err)
defer conn.Close()
client := proto.NewProxyServiceClient(conn)
ctx1, cancel1 := context.WithTimeout(byopContext(context.Background(), setup.accountAToken), 5*time.Second)
defer cancel1()
stream1, err := client.GetMappingUpdate(ctx1, &proto.GetMappingUpdateRequest{
ProxyId: "byop-proxy-a-first",
Version: "test-v1",
Address: setup.accountACluster,
})
require.NoError(t, err)
_ = receiveBYOPMappings(t, stream1)
ctx2, cancel2 := context.WithTimeout(byopContext(context.Background(), setup.accountAToken), 5*time.Second)
defer cancel2()
stream2, err := client.GetMappingUpdate(ctx2, &proto.GetMappingUpdateRequest{
ProxyId: "byop-proxy-a-second",
Version: "test-v1",
Address: setup.accountACluster,
})
require.NoError(t, err)
_, err = stream2.Recv()
require.Error(t, err)
st, ok := grpcstatus.FromError(err)
require.True(t, ok)
assert.Equal(t, codes.ResourceExhausted, st.Code(), "second BYOP proxy should be rejected with ResourceExhausted")
t.Logf("expected rejection: %s", st.Message())
}
func TestIntegration_BYOPProxy_ClusterAddressConflict(t *testing.T) {
setup := setupBYOPIntegrationTest(t)
defer setup.cleanup()
conn, err := grpc.NewClient(setup.grpcAddr, grpc.WithTransportCredentials(insecure.NewCredentials()))
require.NoError(t, err)
defer conn.Close()
client := proto.NewProxyServiceClient(conn)
ctx1, cancel1 := context.WithTimeout(byopContext(context.Background(), setup.accountAToken), 5*time.Second)
defer cancel1()
stream1, err := client.GetMappingUpdate(ctx1, &proto.GetMappingUpdateRequest{
ProxyId: "byop-proxy-a-cluster",
Version: "test-v1",
Address: setup.accountACluster,
})
require.NoError(t, err)
_ = receiveBYOPMappings(t, stream1)
ctx2, cancel2 := context.WithTimeout(byopContext(context.Background(), setup.accountBToken), 5*time.Second)
defer cancel2()
stream2, err := client.GetMappingUpdate(ctx2, &proto.GetMappingUpdateRequest{
ProxyId: "byop-proxy-b-conflict",
Version: "test-v1",
Address: setup.accountACluster,
})
require.NoError(t, err)
_, err = stream2.Recv()
require.Error(t, err)
st, ok := grpcstatus.FromError(err)
require.True(t, ok)
assert.Equal(t, codes.AlreadyExists, st.Code(), "cluster address conflict should return AlreadyExists")
t.Logf("expected rejection: %s", st.Message())
}
func TestIntegration_BYOPProxy_SameProxyReconnects(t *testing.T) {
setup := setupBYOPIntegrationTest(t)
defer setup.cleanup()
conn, err := grpc.NewClient(setup.grpcAddr, grpc.WithTransportCredentials(insecure.NewCredentials()))
require.NoError(t, err)
defer conn.Close()
client := proto.NewProxyServiceClient(conn)
proxyID := "byop-proxy-reconnect"
ctx1, cancel1 := context.WithTimeout(byopContext(context.Background(), setup.accountAToken), 5*time.Second)
stream1, err := client.GetMappingUpdate(ctx1, &proto.GetMappingUpdateRequest{
ProxyId: proxyID,
Version: "test-v1",
Address: setup.accountACluster,
})
require.NoError(t, err)
firstMappings := receiveBYOPMappings(t, stream1)
cancel1()
time.Sleep(200 * time.Millisecond)
ctx2, cancel2 := context.WithTimeout(byopContext(context.Background(), setup.accountAToken), 5*time.Second)
defer cancel2()
stream2, err := client.GetMappingUpdate(ctx2, &proto.GetMappingUpdateRequest{
ProxyId: proxyID,
Version: "test-v1",
Address: setup.accountACluster,
})
require.NoError(t, err)
secondMappings := receiveBYOPMappings(t, stream2)
assert.Equal(t, len(firstMappings), len(secondMappings), "reconnect should receive same mappings")
firstIDs := map[string]bool{}
for _, m := range firstMappings {
firstIDs[m.GetId()] = true
}
for _, m := range secondMappings {
assert.True(t, firstIDs[m.GetId()], "mapping %s should be present on reconnect", m.GetId())
}
}
func TestIntegration_BYOPProxy_UnauthenticatedRejected(t *testing.T) {
setup := setupBYOPIntegrationTest(t)
defer setup.cleanup()
conn, err := grpc.NewClient(setup.grpcAddr, grpc.WithTransportCredentials(insecure.NewCredentials()))
require.NoError(t, err)
defer conn.Close()
client := proto.NewProxyServiceClient(conn)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
stream, err := client.GetMappingUpdate(ctx, &proto.GetMappingUpdateRequest{
ProxyId: "no-auth-proxy",
Version: "test-v1",
Address: "some.cluster.io",
})
require.NoError(t, err)
_, err = stream.Recv()
require.Error(t, err)
st, ok := grpcstatus.FromError(err)
require.True(t, ok)
assert.Equal(t, codes.Unauthenticated, st.Code())
}

View File

@@ -6,6 +6,7 @@ import (
"crypto/rand"
"encoding/base64"
"errors"
"fmt"
"net"
"sync"
"sync/atomic"
@@ -139,6 +140,7 @@ func setupIntegrationTest(t *testing.T) *integrationTestSetup {
nil,
usersManager,
proxyManager,
nil,
)
// Use store-backed service manager
@@ -200,7 +202,7 @@ func (m *testAccessLogManager) GetAllAccessLogs(_ context.Context, _, _ string,
// testProxyManager is a mock implementation of proxy.Manager for testing.
type testProxyManager struct{}
func (m *testProxyManager) Connect(_ context.Context, _, _, _ string, _ *nbproxy.Capabilities) error {
func (m *testProxyManager) Connect(_ context.Context, _, _, _ string, _ *string, _ *nbproxy.Capabilities) error {
return nil
}
@@ -216,6 +218,10 @@ func (m *testProxyManager) GetActiveClusterAddresses(_ context.Context) ([]strin
return nil, nil
}
func (m *testProxyManager) GetActiveClusterAddressesForAccount(_ context.Context, _ string) ([]string, error) {
return nil, nil
}
func (m *testProxyManager) GetActiveClusters(_ context.Context) ([]nbproxy.Cluster, error) {
return nil, nil
}
@@ -232,6 +238,22 @@ func (m *testProxyManager) CleanupStale(_ context.Context, _ time.Duration) erro
return nil
}
func (m *testProxyManager) GetAccountProxy(_ context.Context, accountID string) (*nbproxy.Proxy, error) {
return nil, fmt.Errorf("proxy not found for account %s", accountID)
}
func (m *testProxyManager) CountAccountProxies(_ context.Context, _ string) (int64, error) {
return 0, nil
}
func (m *testProxyManager) IsClusterAddressAvailable(_ context.Context, _, _ string) (bool, error) {
return true, nil
}
func (m *testProxyManager) DeleteProxy(_ context.Context, _ string) error {
return nil
}
// testProxyController is a mock implementation of rpservice.ProxyController for testing.
type testProxyController struct{}
@@ -331,6 +353,10 @@ func (m *storeBackedServiceManager) StopServiceFromPeer(_ context.Context, _, _,
func (m *storeBackedServiceManager) StartExposeReaper(_ context.Context) {}
func (m *storeBackedServiceManager) GetServiceByDomain(ctx context.Context, domain string) (*service.Service, error) {
return m.store.GetServiceByDomain(ctx, domain)
}
func (m *storeBackedServiceManager) GetActiveClusters(_ context.Context, _, _ string) ([]nbproxy.Cluster, error) {
return nil, nil
}

View File

@@ -3310,10 +3310,64 @@ components:
example: false
required:
- enabled
ProxyTokenRequest:
type: object
properties:
name:
type: string
description: Human-readable token name
example: "my-proxy-token"
expires_in:
type: integer
minimum: 0
description: Token expiration in seconds (0 = never expires)
example: 0
required:
- name
ProxyToken:
type: object
properties:
id:
type: string
name:
type: string
expires_at:
type: string
format: date-time
created_at:
type: string
format: date-time
last_used:
type: string
format: date-time
revoked:
type: boolean
required:
- id
- name
- created_at
- revoked
ProxyTokenCreated:
type: object
description: Returned on creation — plain_token is shown only once
allOf:
- $ref: '#/components/schemas/ProxyToken'
- type: object
properties:
plain_token:
type: string
description: The plain text token (shown only once)
example: "nbx_abc123..."
required:
- plain_token
ProxyCluster:
type: object
description: A proxy cluster represents a group of proxy nodes serving the same address
properties:
id:
type: string
description: Unique identifier of a proxy in this cluster
example: "chlfq4q5r8kc73b0qjpg"
address:
type: string
description: Cluster address used for CNAME targets
@@ -3322,9 +3376,15 @@ components:
type: integer
description: Number of proxy nodes connected in this cluster
example: 3
self_hosted:
type: boolean
description: Whether this cluster is a self-hosted (BYOP) proxy managed by the account owner
example: false
required:
- id
- address
- connected_proxies
- self_hosted
ReverseProxyDomainType:
type: string
description: Type of Reverse Proxy Domain
@@ -11300,6 +11360,111 @@ paths:
"$ref": "#/components/responses/forbidden"
'500':
"$ref": "#/components/responses/internal_error"
/api/reverse-proxies/clusters/{clusterId}:
delete:
summary: Delete a self-hosted proxy cluster
description: Removes a self-hosted (BYOP) proxy cluster and disconnects it. Only self-hosted clusters can be deleted.
tags: [ Services ]
security:
- BearerAuth: [ ]
- TokenAuth: [ ]
parameters:
- in: path
name: clusterId
required: true
schema:
type: string
description: The unique identifier of the proxy cluster
responses:
'200':
description: Proxy cluster deleted successfully
content: { }
'400':
"$ref": "#/components/responses/bad_request"
'401':
"$ref": "#/components/responses/requires_authentication"
'403':
"$ref": "#/components/responses/forbidden"
'404':
"$ref": "#/components/responses/not_found"
'500':
"$ref": "#/components/responses/internal_error"
/api/reverse-proxies/proxy-tokens:
get:
summary: List Proxy Tokens
description: Returns all proxy access tokens for the account
tags: [ Self-Hosted Proxies ]
security:
- BearerAuth: [ ]
- TokenAuth: [ ]
responses:
'200':
description: A JSON Array of proxy tokens
content:
application/json:
schema:
type: array
items:
$ref: '#/components/schemas/ProxyToken'
'401':
"$ref": "#/components/responses/requires_authentication"
'403':
"$ref": "#/components/responses/forbidden"
'500':
"$ref": "#/components/responses/internal_error"
post:
summary: Create a Proxy Token
description: Generate an account-scoped proxy access token for self-hosted proxy registration
tags: [ Self-Hosted Proxies ]
security:
- BearerAuth: [ ]
- TokenAuth: [ ]
requestBody:
content:
application/json:
schema:
$ref: '#/components/schemas/ProxyTokenRequest'
responses:
'200':
description: Proxy token created (plain token shown once)
content:
application/json:
schema:
$ref: '#/components/schemas/ProxyTokenCreated'
'400':
"$ref": "#/components/responses/bad_request"
'401':
"$ref": "#/components/responses/requires_authentication"
'403':
"$ref": "#/components/responses/forbidden"
'500':
"$ref": "#/components/responses/internal_error"
/api/reverse-proxies/proxy-tokens/{tokenId}:
delete:
summary: Revoke a Proxy Token
description: Revoke an account-scoped proxy access token
tags: [ Self-Hosted Proxies ]
security:
- BearerAuth: [ ]
- TokenAuth: [ ]
parameters:
- in: path
name: tokenId
required: true
schema:
type: string
description: The unique identifier of the proxy token
responses:
'200':
description: Token revoked
'401':
"$ref": "#/components/responses/requires_authentication"
'403':
"$ref": "#/components/responses/forbidden"
'404':
"$ref": "#/components/responses/not_found"
'500':
"$ref": "#/components/responses/internal_error"
/api/reverse-proxies/services:
get:
summary: List all Services

View File

@@ -3731,11 +3731,49 @@ type ProxyAccessLogsResponse struct {
// ProxyCluster A proxy cluster represents a group of proxy nodes serving the same address
type ProxyCluster struct {
// Id Unique identifier of a proxy in this cluster
Id string `json:"id"`
// Address Cluster address used for CNAME targets
Address string `json:"address"`
// ConnectedProxies Number of proxy nodes connected in this cluster
ConnectedProxies int `json:"connected_proxies"`
// SelfHosted Whether this cluster is a self-hosted (BYOP) proxy managed by the account owner
SelfHosted bool `json:"self_hosted"`
}
// ProxyToken defines model for ProxyToken.
type ProxyToken struct {
CreatedAt time.Time `json:"created_at"`
ExpiresAt *time.Time `json:"expires_at,omitempty"`
Id string `json:"id"`
LastUsed *time.Time `json:"last_used,omitempty"`
Name string `json:"name"`
Revoked bool `json:"revoked"`
}
// ProxyTokenCreated defines model for ProxyTokenCreated.
type ProxyTokenCreated struct {
CreatedAt time.Time `json:"created_at"`
ExpiresAt *time.Time `json:"expires_at,omitempty"`
Id string `json:"id"`
LastUsed *time.Time `json:"last_used,omitempty"`
Name string `json:"name"`
// PlainToken The plain text token (shown only once)
PlainToken string `json:"plain_token"`
Revoked bool `json:"revoked"`
}
// ProxyTokenRequest defines model for ProxyTokenRequest.
type ProxyTokenRequest struct {
// ExpiresIn Token expiration in seconds (0 = never expires)
ExpiresIn *int `json:"expires_in,omitempty"`
// Name Human-readable token name
Name string `json:"name"`
}
// Resource defines model for Resource.
@@ -5094,6 +5132,9 @@ type PutApiPostureChecksPostureCheckIdJSONRequestBody = PostureCheckUpdate
// PostApiReverseProxiesDomainsJSONRequestBody defines body for PostApiReverseProxiesDomains for application/json ContentType.
type PostApiReverseProxiesDomainsJSONRequestBody = ReverseProxyDomainRequest
// PostApiReverseProxiesProxyTokensJSONRequestBody defines body for PostApiReverseProxiesProxyTokens for application/json ContentType.
type PostApiReverseProxiesProxyTokensJSONRequestBody = ProxyTokenRequest
// PostApiReverseProxiesServicesJSONRequestBody defines body for PostApiReverseProxiesServices for application/json ContentType.
type PostApiReverseProxiesServicesJSONRequestBody = ServiceRequest