mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-25 11:46:40 +00:00
Compare commits
54 Commits
deploy/pro
...
test/proxy
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
373f014aea | ||
|
|
2df3fb959b | ||
|
|
14b0e9462b | ||
|
|
8db71b545e | ||
|
|
7c3532d8e5 | ||
|
|
1aa1eef2c5 | ||
|
|
d9418ddc1e | ||
|
|
1b4c831976 | ||
|
|
a19611d8e0 | ||
|
|
9ab6138040 | ||
|
|
30c02ab78c | ||
|
|
3acd86e346 | ||
|
|
c2fec57c0f | ||
|
|
5c20f13c48 | ||
|
|
e6587b071d | ||
|
|
85451ab4cd | ||
|
|
a7f3ba03eb | ||
|
|
4f0a3a77ad | ||
|
|
44655ca9b5 | ||
|
|
e601278117 | ||
|
|
8e7b016be2 | ||
|
|
9e01ea7aae | ||
|
|
cfc7ec8bb9 | ||
|
|
b3bbc0e5c6 | ||
|
|
d7c8e37ff4 | ||
|
|
05b66e73bc | ||
|
|
01ceedac89 | ||
|
|
403babd433 | ||
|
|
47133031e5 | ||
|
|
82da606886 | ||
|
|
bbe5ae2145 | ||
|
|
0b21498b39 | ||
|
|
0ca59535f1 | ||
|
|
59c77d0658 | ||
|
|
333e045099 | ||
|
|
c2c4d9d336 | ||
|
|
9a6a72e88e | ||
|
|
afe6d9fca4 | ||
|
|
ef82905526 | ||
|
|
d18747e846 | ||
|
|
f341d69314 | ||
|
|
327142837c | ||
|
|
f8c0321aee | ||
|
|
89115ff76a | ||
|
|
63c83aa8d2 | ||
|
|
37f025c966 | ||
|
|
4a54f0d670 | ||
|
|
98890a29e3 | ||
|
|
9d123ec059 | ||
|
|
5d171f181a | ||
|
|
22f878b3b7 | ||
|
|
44ef1a18dd | ||
|
|
2b98dc4e52 | ||
|
|
2a26cb4567 |
14
.github/ISSUE_TEMPLATE/config.yml
vendored
Normal file
14
.github/ISSUE_TEMPLATE/config.yml
vendored
Normal file
@@ -0,0 +1,14 @@
|
|||||||
|
blank_issues_enabled: true
|
||||||
|
contact_links:
|
||||||
|
- name: Community Support
|
||||||
|
url: https://forum.netbird.io/
|
||||||
|
about: Community support forum
|
||||||
|
- name: Cloud Support
|
||||||
|
url: https://docs.netbird.io/help/report-bug-issues
|
||||||
|
about: Contact us for support
|
||||||
|
- name: Client/Connection Troubleshooting
|
||||||
|
url: https://docs.netbird.io/help/troubleshooting-client
|
||||||
|
about: See our client troubleshooting guide for help addressing common issues
|
||||||
|
- name: Self-host Troubleshooting
|
||||||
|
url: https://docs.netbird.io/selfhosted/troubleshooting
|
||||||
|
about: See our self-host troubleshooting guide for help addressing common issues
|
||||||
2
.github/workflows/golangci-lint.yml
vendored
2
.github/workflows/golangci-lint.yml
vendored
@@ -19,7 +19,7 @@ jobs:
|
|||||||
- name: codespell
|
- name: codespell
|
||||||
uses: codespell-project/actions-codespell@v2
|
uses: codespell-project/actions-codespell@v2
|
||||||
with:
|
with:
|
||||||
ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe,cros,ans,deriver
|
ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe,cros,ans,deriver,te
|
||||||
skip: go.mod,go.sum,**/proxy/web/**
|
skip: go.mod,go.sum,**/proxy/web/**
|
||||||
golangci:
|
golangci:
|
||||||
strategy:
|
strategy:
|
||||||
|
|||||||
51
.github/workflows/pr-title-check.yml
vendored
Normal file
51
.github/workflows/pr-title-check.yml
vendored
Normal file
@@ -0,0 +1,51 @@
|
|||||||
|
name: PR Title Check
|
||||||
|
|
||||||
|
on:
|
||||||
|
pull_request:
|
||||||
|
types: [opened, edited, synchronize, reopened]
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
check-title:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- name: Validate PR title prefix
|
||||||
|
uses: actions/github-script@v7
|
||||||
|
with:
|
||||||
|
script: |
|
||||||
|
const title = context.payload.pull_request.title;
|
||||||
|
const allowedTags = [
|
||||||
|
'management',
|
||||||
|
'client',
|
||||||
|
'signal',
|
||||||
|
'proxy',
|
||||||
|
'relay',
|
||||||
|
'misc',
|
||||||
|
'infrastructure',
|
||||||
|
'self-hosted',
|
||||||
|
'doc',
|
||||||
|
];
|
||||||
|
|
||||||
|
const pattern = /^\[([^\]]+)\]\s+.+/;
|
||||||
|
const match = title.match(pattern);
|
||||||
|
|
||||||
|
if (!match) {
|
||||||
|
core.setFailed(
|
||||||
|
`PR title must start with a tag in brackets.\n` +
|
||||||
|
`Example: [client] fix something\n` +
|
||||||
|
`Allowed tags: ${allowedTags.join(', ')}`
|
||||||
|
);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const tags = match[1].split(',').map(t => t.trim().toLowerCase());
|
||||||
|
|
||||||
|
const invalid = tags.filter(t => !allowedTags.includes(t));
|
||||||
|
if (invalid.length > 0) {
|
||||||
|
core.setFailed(
|
||||||
|
`Invalid tag(s): ${invalid.join(', ')}\n` +
|
||||||
|
`Allowed tags: ${allowedTags.join(', ')}`
|
||||||
|
);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
console.log(`Valid PR title tags: [${tags.join(', ')}]`);
|
||||||
@@ -4,7 +4,7 @@
|
|||||||
# sudo podman build -t localhost/netbird:latest -f client/Dockerfile --ignorefile .dockerignore-client .
|
# sudo podman build -t localhost/netbird:latest -f client/Dockerfile --ignorefile .dockerignore-client .
|
||||||
# sudo podman run --rm -it --cap-add={BPF,NET_ADMIN,NET_RAW} localhost/netbird:latest
|
# sudo podman run --rm -it --cap-add={BPF,NET_ADMIN,NET_RAW} localhost/netbird:latest
|
||||||
|
|
||||||
FROM alpine:3.23.2
|
FROM alpine:3.23.3
|
||||||
# iproute2: busybox doesn't display ip rules properly
|
# iproute2: busybox doesn't display ip rules properly
|
||||||
RUN apk add --no-cache \
|
RUN apk add --no-cache \
|
||||||
bash \
|
bash \
|
||||||
|
|||||||
194
client/cmd/expose.go
Normal file
194
client/cmd/expose.go
Normal file
@@ -0,0 +1,194 @@
|
|||||||
|
package cmd
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"os"
|
||||||
|
"os/signal"
|
||||||
|
"regexp"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"syscall"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"github.com/spf13/cobra"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/proto"
|
||||||
|
"github.com/netbirdio/netbird/util"
|
||||||
|
)
|
||||||
|
|
||||||
|
var pinRegexp = regexp.MustCompile(`^\d{6}$`)
|
||||||
|
|
||||||
|
var (
|
||||||
|
exposePin string
|
||||||
|
exposePassword string
|
||||||
|
exposeUserGroups []string
|
||||||
|
exposeDomain string
|
||||||
|
exposeNamePrefix string
|
||||||
|
exposeProtocol string
|
||||||
|
)
|
||||||
|
|
||||||
|
var exposeCmd = &cobra.Command{
|
||||||
|
Use: "expose <port>",
|
||||||
|
Short: "Expose a local port via the NetBird reverse proxy",
|
||||||
|
Args: cobra.ExactArgs(1),
|
||||||
|
Example: "netbird expose --with-password safe-pass 8080",
|
||||||
|
RunE: exposeFn,
|
||||||
|
}
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
exposeCmd.Flags().StringVar(&exposePin, "with-pin", "", "Protect the exposed service with a 6-digit PIN (e.g. --with-pin 123456)")
|
||||||
|
exposeCmd.Flags().StringVar(&exposePassword, "with-password", "", "Protect the exposed service with a password (e.g. --with-password my-secret)")
|
||||||
|
exposeCmd.Flags().StringSliceVar(&exposeUserGroups, "with-user-groups", nil, "Restrict access to specific user groups with SSO (e.g. --with-user-groups devops,Backend)")
|
||||||
|
exposeCmd.Flags().StringVar(&exposeDomain, "with-custom-domain", "", "Custom domain for the exposed service, must be configured to your account (e.g. --with-custom-domain myapp.example.com)")
|
||||||
|
exposeCmd.Flags().StringVar(&exposeNamePrefix, "with-name-prefix", "", "Prefix for the generated service name (e.g. --with-name-prefix my-app)")
|
||||||
|
exposeCmd.Flags().StringVar(&exposeProtocol, "protocol", "http", "Protocol to use, http/https is supported (e.g. --protocol http)")
|
||||||
|
}
|
||||||
|
|
||||||
|
func validateExposeFlags(cmd *cobra.Command, portStr string) (uint64, error) {
|
||||||
|
port, err := strconv.ParseUint(portStr, 10, 32)
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("invalid port number: %s", portStr)
|
||||||
|
}
|
||||||
|
if port == 0 || port > 65535 {
|
||||||
|
return 0, fmt.Errorf("invalid port number: must be between 1 and 65535")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !isProtocolValid(exposeProtocol) {
|
||||||
|
return 0, fmt.Errorf("unsupported protocol %q: only 'http' or 'https' are supported", exposeProtocol)
|
||||||
|
}
|
||||||
|
|
||||||
|
if exposePin != "" && !pinRegexp.MatchString(exposePin) {
|
||||||
|
return 0, fmt.Errorf("invalid pin: must be exactly 6 digits")
|
||||||
|
}
|
||||||
|
|
||||||
|
if cmd.Flags().Changed("with-password") && exposePassword == "" {
|
||||||
|
return 0, fmt.Errorf("password cannot be empty")
|
||||||
|
}
|
||||||
|
|
||||||
|
if cmd.Flags().Changed("with-user-groups") && len(exposeUserGroups) == 0 {
|
||||||
|
return 0, fmt.Errorf("user groups cannot be empty")
|
||||||
|
}
|
||||||
|
|
||||||
|
return port, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func isProtocolValid(exposeProtocol string) bool {
|
||||||
|
return strings.ToLower(exposeProtocol) == "http" || strings.ToLower(exposeProtocol) == "https"
|
||||||
|
}
|
||||||
|
|
||||||
|
func exposeFn(cmd *cobra.Command, args []string) error {
|
||||||
|
SetFlagsFromEnvVars(rootCmd)
|
||||||
|
|
||||||
|
if err := util.InitLog(logLevel, util.LogConsole); err != nil {
|
||||||
|
log.Errorf("failed initializing log %v", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd.Root().SilenceUsage = false
|
||||||
|
|
||||||
|
port, err := validateExposeFlags(cmd, args[0])
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd.Root().SilenceUsage = true
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(cmd.Context())
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
sigCh := make(chan os.Signal, 1)
|
||||||
|
signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
|
||||||
|
go func() {
|
||||||
|
<-sigCh
|
||||||
|
cancel()
|
||||||
|
}()
|
||||||
|
|
||||||
|
conn, err := DialClientGRPCServer(ctx, daemonAddr)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("connect to daemon: %w", err)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if err := conn.Close(); err != nil {
|
||||||
|
log.Debugf("failed to close daemon connection: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
client := proto.NewDaemonServiceClient(conn)
|
||||||
|
|
||||||
|
protocol, err := toExposeProtocol(exposeProtocol)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
stream, err := client.ExposeService(ctx, &proto.ExposeServiceRequest{
|
||||||
|
Port: uint32(port),
|
||||||
|
Protocol: protocol,
|
||||||
|
Pin: exposePin,
|
||||||
|
Password: exposePassword,
|
||||||
|
UserGroups: exposeUserGroups,
|
||||||
|
Domain: exposeDomain,
|
||||||
|
NamePrefix: exposeNamePrefix,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("expose service: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := handleExposeReady(cmd, stream, port); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return waitForExposeEvents(cmd, ctx, stream)
|
||||||
|
}
|
||||||
|
|
||||||
|
func toExposeProtocol(exposeProtocol string) (proto.ExposeProtocol, error) {
|
||||||
|
switch strings.ToLower(exposeProtocol) {
|
||||||
|
case "http":
|
||||||
|
return proto.ExposeProtocol_EXPOSE_HTTP, nil
|
||||||
|
case "https":
|
||||||
|
return proto.ExposeProtocol_EXPOSE_HTTPS, nil
|
||||||
|
default:
|
||||||
|
return 0, fmt.Errorf("unsupported protocol %q: only 'http' or 'https' are supported", exposeProtocol)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func handleExposeReady(cmd *cobra.Command, stream proto.DaemonService_ExposeServiceClient, port uint64) error {
|
||||||
|
event, err := stream.Recv()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("receive expose event: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
switch e := event.Event.(type) {
|
||||||
|
case *proto.ExposeServiceEvent_Ready:
|
||||||
|
cmd.Println("Service exposed successfully!")
|
||||||
|
cmd.Printf(" Name: %s\n", e.Ready.ServiceName)
|
||||||
|
cmd.Printf(" URL: %s\n", e.Ready.ServiceUrl)
|
||||||
|
cmd.Printf(" Domain: %s\n", e.Ready.Domain)
|
||||||
|
cmd.Printf(" Protocol: %s\n", exposeProtocol)
|
||||||
|
cmd.Printf(" Port: %d\n", port)
|
||||||
|
cmd.Println()
|
||||||
|
cmd.Println("Press Ctrl+C to stop exposing.")
|
||||||
|
return nil
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("unexpected expose event: %T", event.Event)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func waitForExposeEvents(cmd *cobra.Command, ctx context.Context, stream proto.DaemonService_ExposeServiceClient) error {
|
||||||
|
for {
|
||||||
|
_, err := stream.Recv()
|
||||||
|
if err != nil {
|
||||||
|
if ctx.Err() != nil {
|
||||||
|
cmd.Println("\nService stopped.")
|
||||||
|
//nolint:nilerr
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if errors.Is(err, io.EOF) {
|
||||||
|
return fmt.Errorf("connection to daemon closed unexpectedly")
|
||||||
|
}
|
||||||
|
return fmt.Errorf("stream error: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -22,6 +22,7 @@ import (
|
|||||||
"google.golang.org/grpc"
|
"google.golang.org/grpc"
|
||||||
"google.golang.org/grpc/credentials/insecure"
|
"google.golang.org/grpc/credentials/insecure"
|
||||||
|
|
||||||
|
daddr "github.com/netbirdio/netbird/client/internal/daemonaddr"
|
||||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -80,6 +81,15 @@ var (
|
|||||||
Short: "",
|
Short: "",
|
||||||
Long: "",
|
Long: "",
|
||||||
SilenceUsage: true,
|
SilenceUsage: true,
|
||||||
|
PersistentPreRunE: func(cmd *cobra.Command, args []string) error {
|
||||||
|
SetFlagsFromEnvVars(cmd.Root())
|
||||||
|
|
||||||
|
// Don't resolve for service commands — they create the socket, not connect to it.
|
||||||
|
if !isServiceCmd(cmd) {
|
||||||
|
daemonAddr = daddr.ResolveUnixDaemonAddr(daemonAddr)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
},
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -144,6 +154,7 @@ func init() {
|
|||||||
rootCmd.AddCommand(forwardingRulesCmd)
|
rootCmd.AddCommand(forwardingRulesCmd)
|
||||||
rootCmd.AddCommand(debugCmd)
|
rootCmd.AddCommand(debugCmd)
|
||||||
rootCmd.AddCommand(profileCmd)
|
rootCmd.AddCommand(profileCmd)
|
||||||
|
rootCmd.AddCommand(exposeCmd)
|
||||||
|
|
||||||
networksCMD.AddCommand(routesListCmd)
|
networksCMD.AddCommand(routesListCmd)
|
||||||
networksCMD.AddCommand(routesSelectCmd, routesDeselectCmd)
|
networksCMD.AddCommand(routesSelectCmd, routesDeselectCmd)
|
||||||
@@ -385,7 +396,6 @@ func migrateToNetbird(oldPath, newPath string) bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func getClient(cmd *cobra.Command) (*grpc.ClientConn, error) {
|
func getClient(cmd *cobra.Command) (*grpc.ClientConn, error) {
|
||||||
SetFlagsFromEnvVars(rootCmd)
|
|
||||||
cmd.SetOut(cmd.OutOrStdout())
|
cmd.SetOut(cmd.OutOrStdout())
|
||||||
|
|
||||||
conn, err := DialClientGRPCServer(cmd.Context(), daemonAddr)
|
conn, err := DialClientGRPCServer(cmd.Context(), daemonAddr)
|
||||||
@@ -398,3 +408,13 @@ func getClient(cmd *cobra.Command) (*grpc.ClientConn, error) {
|
|||||||
|
|
||||||
return conn, nil
|
return conn, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// isServiceCmd returns true if cmd is the "service" command or a child of it.
|
||||||
|
func isServiceCmd(cmd *cobra.Command) bool {
|
||||||
|
for c := cmd; c != nil; c = c.Parent() {
|
||||||
|
if c.Name() == "service" {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|||||||
@@ -5,20 +5,18 @@ package configurer
|
|||||||
import (
|
import (
|
||||||
"net"
|
"net"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
"golang.zx2c4.com/wireguard/ipc"
|
"golang.zx2c4.com/wireguard/ipc"
|
||||||
)
|
)
|
||||||
|
|
||||||
func openUAPI(deviceName string) (net.Listener, error) {
|
func openUAPI(deviceName string) (net.Listener, error) {
|
||||||
uapiSock, err := ipc.UAPIOpen(deviceName)
|
uapiSock, err := ipc.UAPIOpen(deviceName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("failed to open uapi socket: %v", err)
|
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
listener, err := ipc.UAPIListen(deviceName, uapiSock)
|
listener, err := ipc.UAPIListen(deviceName, uapiSock)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("failed to listen on uapi socket: %v", err)
|
_ = uapiSock.Close()
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -54,6 +54,14 @@ func NewUSPConfigurer(device *device.Device, deviceName string, activityRecorder
|
|||||||
return wgCfg
|
return wgCfg
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func NewUSPConfigurerNoUAPI(device *device.Device, deviceName string, activityRecorder *bind.ActivityRecorder) *WGUSPConfigurer {
|
||||||
|
return &WGUSPConfigurer{
|
||||||
|
device: device,
|
||||||
|
deviceName: deviceName,
|
||||||
|
activityRecorder: activityRecorder,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (c *WGUSPConfigurer) ConfigureInterface(privateKey string, port int) error {
|
func (c *WGUSPConfigurer) ConfigureInterface(privateKey string, port int) error {
|
||||||
log.Debugf("adding Wireguard private key")
|
log.Debugf("adding Wireguard private key")
|
||||||
key, err := wgtypes.ParseKey(privateKey)
|
key, err := wgtypes.ParseKey(privateKey)
|
||||||
|
|||||||
@@ -79,7 +79,7 @@ func (t *TunNetstackDevice) create() (WGConfigurer, error) {
|
|||||||
device.NewLogger(wgLogLevel(), "[netbird] "),
|
device.NewLogger(wgLogLevel(), "[netbird] "),
|
||||||
)
|
)
|
||||||
|
|
||||||
t.configurer = configurer.NewUSPConfigurer(t.device, t.name, t.bind.ActivityRecorder())
|
t.configurer = configurer.NewUSPConfigurerNoUAPI(t.device, t.name, t.bind.ActivityRecorder())
|
||||||
err = t.configurer.ConfigureInterface(t.key, t.port)
|
err = t.configurer.ConfigureInterface(t.key, t.port)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if cErr := tunIface.Close(); cErr != nil {
|
if cErr := tunIface.Close(); cErr != nil {
|
||||||
|
|||||||
@@ -331,8 +331,11 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
|
|||||||
state.Set(StatusConnected)
|
state.Set(StatusConnected)
|
||||||
|
|
||||||
if runningChan != nil {
|
if runningChan != nil {
|
||||||
|
select {
|
||||||
|
case <-runningChan:
|
||||||
|
default:
|
||||||
close(runningChan)
|
close(runningChan)
|
||||||
runningChan = nil
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
<-engineCtx.Done()
|
<-engineCtx.Done()
|
||||||
|
|||||||
60
client/internal/daemonaddr/resolve.go
Normal file
60
client/internal/daemonaddr/resolve.go
Normal file
@@ -0,0 +1,60 @@
|
|||||||
|
//go:build !windows && !ios && !android
|
||||||
|
|
||||||
|
package daemonaddr
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
var scanDir = "/var/run/netbird"
|
||||||
|
|
||||||
|
// setScanDir overrides the scan directory (used by tests).
|
||||||
|
func setScanDir(dir string) {
|
||||||
|
scanDir = dir
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResolveUnixDaemonAddr checks whether the default Unix socket exists and, if not,
|
||||||
|
// scans /var/run/netbird/ for a single .sock file to use instead. This handles the
|
||||||
|
// mismatch between the netbird@.service template (which places the socket under
|
||||||
|
// /var/run/netbird/<instance>.sock) and the CLI default (/var/run/netbird.sock).
|
||||||
|
func ResolveUnixDaemonAddr(addr string) string {
|
||||||
|
if !strings.HasPrefix(addr, "unix://") {
|
||||||
|
return addr
|
||||||
|
}
|
||||||
|
|
||||||
|
sockPath := strings.TrimPrefix(addr, "unix://")
|
||||||
|
if _, err := os.Stat(sockPath); err == nil {
|
||||||
|
return addr
|
||||||
|
}
|
||||||
|
|
||||||
|
entries, err := os.ReadDir(scanDir)
|
||||||
|
if err != nil {
|
||||||
|
return addr
|
||||||
|
}
|
||||||
|
|
||||||
|
var found []string
|
||||||
|
for _, e := range entries {
|
||||||
|
if e.IsDir() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if strings.HasSuffix(e.Name(), ".sock") {
|
||||||
|
found = append(found, filepath.Join(scanDir, e.Name()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
switch len(found) {
|
||||||
|
case 1:
|
||||||
|
resolved := "unix://" + found[0]
|
||||||
|
log.Debugf("Default daemon socket not found, using discovered socket: %s", resolved)
|
||||||
|
return resolved
|
||||||
|
case 0:
|
||||||
|
return addr
|
||||||
|
default:
|
||||||
|
log.Warnf("Default daemon socket not found and multiple sockets discovered in %s; pass --daemon-addr explicitly", scanDir)
|
||||||
|
return addr
|
||||||
|
}
|
||||||
|
}
|
||||||
8
client/internal/daemonaddr/resolve_stub.go
Normal file
8
client/internal/daemonaddr/resolve_stub.go
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
//go:build windows || ios || android
|
||||||
|
|
||||||
|
package daemonaddr
|
||||||
|
|
||||||
|
// ResolveUnixDaemonAddr is a no-op on platforms that don't use Unix sockets.
|
||||||
|
func ResolveUnixDaemonAddr(addr string) string {
|
||||||
|
return addr
|
||||||
|
}
|
||||||
121
client/internal/daemonaddr/resolve_test.go
Normal file
121
client/internal/daemonaddr/resolve_test.go
Normal file
@@ -0,0 +1,121 @@
|
|||||||
|
//go:build !windows && !ios && !android
|
||||||
|
|
||||||
|
package daemonaddr
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
// createSockFile creates a regular file with a .sock extension.
|
||||||
|
// ResolveUnixDaemonAddr uses os.Stat (not net.Dial), so a regular file is
|
||||||
|
// sufficient and avoids Unix socket path-length limits on macOS.
|
||||||
|
func createSockFile(t *testing.T, path string) {
|
||||||
|
t.Helper()
|
||||||
|
if err := os.WriteFile(path, nil, 0o600); err != nil {
|
||||||
|
t.Fatalf("failed to create test sock file at %s: %v", path, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResolveUnixDaemonAddr_DefaultExists(t *testing.T) {
|
||||||
|
tmp := t.TempDir()
|
||||||
|
sock := filepath.Join(tmp, "netbird.sock")
|
||||||
|
createSockFile(t, sock)
|
||||||
|
|
||||||
|
addr := "unix://" + sock
|
||||||
|
got := ResolveUnixDaemonAddr(addr)
|
||||||
|
if got != addr {
|
||||||
|
t.Errorf("expected %s, got %s", addr, got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResolveUnixDaemonAddr_SingleDiscovered(t *testing.T) {
|
||||||
|
tmp := t.TempDir()
|
||||||
|
|
||||||
|
// Default socket does not exist
|
||||||
|
defaultAddr := "unix://" + filepath.Join(tmp, "netbird.sock")
|
||||||
|
|
||||||
|
// Create a scan dir with one socket
|
||||||
|
sd := filepath.Join(tmp, "netbird")
|
||||||
|
if err := os.MkdirAll(sd, 0o755); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
instanceSock := filepath.Join(sd, "main.sock")
|
||||||
|
createSockFile(t, instanceSock)
|
||||||
|
|
||||||
|
origScanDir := scanDir
|
||||||
|
setScanDir(sd)
|
||||||
|
t.Cleanup(func() { setScanDir(origScanDir) })
|
||||||
|
|
||||||
|
got := ResolveUnixDaemonAddr(defaultAddr)
|
||||||
|
expected := "unix://" + instanceSock
|
||||||
|
if got != expected {
|
||||||
|
t.Errorf("expected %s, got %s", expected, got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResolveUnixDaemonAddr_MultipleDiscovered(t *testing.T) {
|
||||||
|
tmp := t.TempDir()
|
||||||
|
|
||||||
|
defaultAddr := "unix://" + filepath.Join(tmp, "netbird.sock")
|
||||||
|
|
||||||
|
sd := filepath.Join(tmp, "netbird")
|
||||||
|
if err := os.MkdirAll(sd, 0o755); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
createSockFile(t, filepath.Join(sd, "main.sock"))
|
||||||
|
createSockFile(t, filepath.Join(sd, "other.sock"))
|
||||||
|
|
||||||
|
origScanDir := scanDir
|
||||||
|
setScanDir(sd)
|
||||||
|
t.Cleanup(func() { setScanDir(origScanDir) })
|
||||||
|
|
||||||
|
got := ResolveUnixDaemonAddr(defaultAddr)
|
||||||
|
if got != defaultAddr {
|
||||||
|
t.Errorf("expected original %s, got %s", defaultAddr, got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResolveUnixDaemonAddr_NoSocketsFound(t *testing.T) {
|
||||||
|
tmp := t.TempDir()
|
||||||
|
|
||||||
|
defaultAddr := "unix://" + filepath.Join(tmp, "netbird.sock")
|
||||||
|
|
||||||
|
sd := filepath.Join(tmp, "netbird")
|
||||||
|
if err := os.MkdirAll(sd, 0o755); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
origScanDir := scanDir
|
||||||
|
setScanDir(sd)
|
||||||
|
t.Cleanup(func() { setScanDir(origScanDir) })
|
||||||
|
|
||||||
|
got := ResolveUnixDaemonAddr(defaultAddr)
|
||||||
|
if got != defaultAddr {
|
||||||
|
t.Errorf("expected original %s, got %s", defaultAddr, got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResolveUnixDaemonAddr_NonUnixAddr(t *testing.T) {
|
||||||
|
addr := "tcp://127.0.0.1:41731"
|
||||||
|
got := ResolveUnixDaemonAddr(addr)
|
||||||
|
if got != addr {
|
||||||
|
t.Errorf("expected %s, got %s", addr, got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResolveUnixDaemonAddr_ScanDirMissing(t *testing.T) {
|
||||||
|
tmp := t.TempDir()
|
||||||
|
|
||||||
|
defaultAddr := "unix://" + filepath.Join(tmp, "netbird.sock")
|
||||||
|
|
||||||
|
origScanDir := scanDir
|
||||||
|
setScanDir(filepath.Join(tmp, "nonexistent"))
|
||||||
|
t.Cleanup(func() { setScanDir(origScanDir) })
|
||||||
|
|
||||||
|
got := ResolveUnixDaemonAddr(defaultAddr)
|
||||||
|
if got != defaultAddr {
|
||||||
|
t.Errorf("expected original %s, got %s", defaultAddr, got)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -277,7 +277,7 @@ func (r *registryConfigurator) addDNSMatchPolicy(domains []string, ip netip.Addr
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Infof("added %d NRPT rules for %d domains. Domain list: %v", ruleIndex, len(domains), domains)
|
log.Infof("added %d NRPT rules for %d domains", ruleIndex, len(domains))
|
||||||
return ruleIndex, nil
|
return ruleIndex, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -376,9 +376,9 @@ func (m *Resolver) extractDomainsFromServerDomains(serverDomains dnsconfig.Serve
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if serverDomains.Flow != "" {
|
// Flow receiver domain is intentionally excluded from caching.
|
||||||
domains = append(domains, serverDomains.Flow)
|
// Cloud providers may rotate the IP behind this domain; a stale cached record
|
||||||
}
|
// causes TLS certificate verification failures on reconnect.
|
||||||
|
|
||||||
for _, stun := range serverDomains.Stuns {
|
for _, stun := range serverDomains.Stuns {
|
||||||
if stun != "" {
|
if stun != "" {
|
||||||
|
|||||||
@@ -391,7 +391,8 @@ func TestResolver_PartialUpdateAddsNewTypePreservesExisting(t *testing.T) {
|
|||||||
}
|
}
|
||||||
assert.Len(t, resolver.GetCachedDomains(), 3)
|
assert.Len(t, resolver.GetCachedDomains(), 3)
|
||||||
|
|
||||||
// Update with partial ServerDomains (only flow domain - new type, should preserve all existing)
|
// Update with partial ServerDomains (only flow domain - flow is intentionally excluded from
|
||||||
|
// caching to prevent TLS failures from stale records, so all existing domains are preserved)
|
||||||
partialDomains := dnsconfig.ServerDomains{
|
partialDomains := dnsconfig.ServerDomains{
|
||||||
Flow: "github.com",
|
Flow: "github.com",
|
||||||
}
|
}
|
||||||
@@ -400,10 +401,10 @@ func TestResolver_PartialUpdateAddsNewTypePreservesExisting(t *testing.T) {
|
|||||||
t.Skipf("Skipping test due to DNS resolution failure: %v", err)
|
t.Skipf("Skipping test due to DNS resolution failure: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
assert.Len(t, removedDomains, 0, "Should not remove any domains when adding new type")
|
assert.Len(t, removedDomains, 0, "Should not remove any domains when only flow domain is provided")
|
||||||
|
|
||||||
finalDomains := resolver.GetCachedDomains()
|
finalDomains := resolver.GetCachedDomains()
|
||||||
assert.Len(t, finalDomains, 4, "Should have all original domains plus new flow domain")
|
assert.Len(t, finalDomains, 3, "Flow domain is not cached; all original domains should be preserved")
|
||||||
|
|
||||||
domainStrings := make([]string, len(finalDomains))
|
domainStrings := make([]string, len(finalDomains))
|
||||||
for i, d := range finalDomains {
|
for i, d := range finalDomains {
|
||||||
@@ -412,5 +413,5 @@ func TestResolver_PartialUpdateAddsNewTypePreservesExisting(t *testing.T) {
|
|||||||
assert.Contains(t, domainStrings, "example.org")
|
assert.Contains(t, domainStrings, "example.org")
|
||||||
assert.Contains(t, domainStrings, "google.com")
|
assert.Contains(t, domainStrings, "google.com")
|
||||||
assert.Contains(t, domainStrings, "cloudflare.com")
|
assert.Contains(t, domainStrings, "cloudflare.com")
|
||||||
assert.Contains(t, domainStrings, "github.com")
|
assert.NotContains(t, domainStrings, "github.com")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -351,9 +351,13 @@ func (u *upstreamResolverBase) waitUntilResponse() {
|
|||||||
return fmt.Errorf("upstream check call error")
|
return fmt.Errorf("upstream check call error")
|
||||||
}
|
}
|
||||||
|
|
||||||
err := backoff.Retry(operation, exponentialBackOff)
|
err := backoff.Retry(operation, backoff.WithContext(exponentialBackOff, u.ctx))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warn(err)
|
if errors.Is(err, context.Canceled) {
|
||||||
|
log.Debugf("upstream retry loop exited for upstreams %s", u.upstreamServersString())
|
||||||
|
} else {
|
||||||
|
log.Warnf("upstream retry loop exited for upstreams %s: %v", u.upstreamServersString(), err)
|
||||||
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -36,6 +36,7 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/internal/dns"
|
"github.com/netbirdio/netbird/client/internal/dns"
|
||||||
dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config"
|
dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config"
|
||||||
"github.com/netbirdio/netbird/client/internal/dnsfwd"
|
"github.com/netbirdio/netbird/client/internal/dnsfwd"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/expose"
|
||||||
"github.com/netbirdio/netbird/client/internal/ingressgw"
|
"github.com/netbirdio/netbird/client/internal/ingressgw"
|
||||||
"github.com/netbirdio/netbird/client/internal/netflow"
|
"github.com/netbirdio/netbird/client/internal/netflow"
|
||||||
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||||
@@ -53,13 +54,11 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/internal/updatemanager"
|
"github.com/netbirdio/netbird/client/internal/updatemanager"
|
||||||
"github.com/netbirdio/netbird/client/jobexec"
|
"github.com/netbirdio/netbird/client/jobexec"
|
||||||
cProto "github.com/netbirdio/netbird/client/proto"
|
cProto "github.com/netbirdio/netbird/client/proto"
|
||||||
"github.com/netbirdio/netbird/shared/management/domain"
|
|
||||||
semaphoregroup "github.com/netbirdio/netbird/util/semaphore-group"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/system"
|
"github.com/netbirdio/netbird/client/system"
|
||||||
nbdns "github.com/netbirdio/netbird/dns"
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
"github.com/netbirdio/netbird/route"
|
"github.com/netbirdio/netbird/route"
|
||||||
mgm "github.com/netbirdio/netbird/shared/management/client"
|
mgm "github.com/netbirdio/netbird/shared/management/client"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/domain"
|
||||||
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||||
auth "github.com/netbirdio/netbird/shared/relay/auth/hmac"
|
auth "github.com/netbirdio/netbird/shared/relay/auth/hmac"
|
||||||
relayClient "github.com/netbirdio/netbird/shared/relay/client"
|
relayClient "github.com/netbirdio/netbird/shared/relay/client"
|
||||||
@@ -75,7 +74,6 @@ import (
|
|||||||
const (
|
const (
|
||||||
PeerConnectionTimeoutMax = 45000 // ms
|
PeerConnectionTimeoutMax = 45000 // ms
|
||||||
PeerConnectionTimeoutMin = 30000 // ms
|
PeerConnectionTimeoutMin = 30000 // ms
|
||||||
connInitLimit = 200
|
|
||||||
disableAutoUpdate = "disabled"
|
disableAutoUpdate = "disabled"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -208,7 +206,6 @@ type Engine struct {
|
|||||||
syncRespMux sync.RWMutex
|
syncRespMux sync.RWMutex
|
||||||
persistSyncResponse bool
|
persistSyncResponse bool
|
||||||
latestSyncResponse *mgmProto.SyncResponse
|
latestSyncResponse *mgmProto.SyncResponse
|
||||||
connSemaphore *semaphoregroup.SemaphoreGroup
|
|
||||||
flowManager nftypes.FlowManager
|
flowManager nftypes.FlowManager
|
||||||
|
|
||||||
// auto-update
|
// auto-update
|
||||||
@@ -224,6 +221,8 @@ type Engine struct {
|
|||||||
|
|
||||||
jobExecutor *jobexec.Executor
|
jobExecutor *jobexec.Executor
|
||||||
jobExecutorWG sync.WaitGroup
|
jobExecutorWG sync.WaitGroup
|
||||||
|
|
||||||
|
exposeManager *expose.Manager
|
||||||
}
|
}
|
||||||
|
|
||||||
// Peer is an instance of the Connection Peer
|
// Peer is an instance of the Connection Peer
|
||||||
@@ -266,7 +265,6 @@ func NewEngine(
|
|||||||
statusRecorder: statusRecorder,
|
statusRecorder: statusRecorder,
|
||||||
stateManager: stateManager,
|
stateManager: stateManager,
|
||||||
checks: checks,
|
checks: checks,
|
||||||
connSemaphore: semaphoregroup.NewSemaphoreGroup(connInitLimit),
|
|
||||||
probeStunTurn: relay.NewStunTurnProbe(relay.DefaultCacheTTL),
|
probeStunTurn: relay.NewStunTurnProbe(relay.DefaultCacheTTL),
|
||||||
jobExecutor: jobexec.NewExecutor(),
|
jobExecutor: jobexec.NewExecutor(),
|
||||||
}
|
}
|
||||||
@@ -419,6 +417,7 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
|
|||||||
e.cancel()
|
e.cancel()
|
||||||
}
|
}
|
||||||
e.ctx, e.cancel = context.WithCancel(e.clientCtx)
|
e.ctx, e.cancel = context.WithCancel(e.clientCtx)
|
||||||
|
e.exposeManager = expose.NewManager(e.ctx, e.mgmClient)
|
||||||
|
|
||||||
wgIface, err := e.newWgIface()
|
wgIface, err := e.newWgIface()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -801,7 +800,7 @@ func (e *Engine) handleAutoUpdateVersion(autoUpdateSettings *mgmProto.AutoUpdate
|
|||||||
|
|
||||||
disabled := autoUpdateSettings.Version == disableAutoUpdate
|
disabled := autoUpdateSettings.Version == disableAutoUpdate
|
||||||
|
|
||||||
// Stop and cleanup if disabled
|
// stop and cleanup if disabled
|
||||||
if e.updateManager != nil && disabled {
|
if e.updateManager != nil && disabled {
|
||||||
log.Infof("auto-update is disabled, stopping update manager")
|
log.Infof("auto-update is disabled, stopping update manager")
|
||||||
e.updateManager.Stop()
|
e.updateManager.Stop()
|
||||||
@@ -1539,7 +1538,6 @@ func (e *Engine) createPeerConn(pubKey string, allowedIPs []netip.Prefix, agentV
|
|||||||
IFaceDiscover: e.mobileDep.IFaceDiscover,
|
IFaceDiscover: e.mobileDep.IFaceDiscover,
|
||||||
RelayManager: e.relayManager,
|
RelayManager: e.relayManager,
|
||||||
SrWatcher: e.srWatcher,
|
SrWatcher: e.srWatcher,
|
||||||
Semaphore: e.connSemaphore,
|
|
||||||
}
|
}
|
||||||
peerConn, err := peer.NewConn(config, serviceDependencies)
|
peerConn, err := peer.NewConn(config, serviceDependencies)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -1824,11 +1822,18 @@ func (e *Engine) GetRouteManager() routemanager.Manager {
|
|||||||
return e.routeManager
|
return e.routeManager
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetFirewallManager returns the firewall manager
|
// GetFirewallManager returns the firewall manager.
|
||||||
func (e *Engine) GetFirewallManager() firewallManager.Manager {
|
func (e *Engine) GetFirewallManager() firewallManager.Manager {
|
||||||
return e.firewall
|
return e.firewall
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetExposeManager returns the expose session manager.
|
||||||
|
func (e *Engine) GetExposeManager() *expose.Manager {
|
||||||
|
e.syncMsgMux.Lock()
|
||||||
|
defer e.syncMsgMux.Unlock()
|
||||||
|
return e.exposeManager
|
||||||
|
}
|
||||||
|
|
||||||
func findIPFromInterfaceName(ifaceName string) (net.IP, error) {
|
func findIPFromInterfaceName(ifaceName string) (net.IP, error) {
|
||||||
iface, err := net.InterfaceByName(ifaceName)
|
iface, err := net.InterfaceByName(ifaceName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
95
client/internal/expose/manager.go
Normal file
95
client/internal/expose/manager.go
Normal file
@@ -0,0 +1,95 @@
|
|||||||
|
package expose
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
mgm "github.com/netbirdio/netbird/shared/management/client"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
const renewTimeout = 10 * time.Second
|
||||||
|
|
||||||
|
// Response holds the response from exposing a service.
|
||||||
|
type Response struct {
|
||||||
|
ServiceName string
|
||||||
|
ServiceURL string
|
||||||
|
Domain string
|
||||||
|
}
|
||||||
|
|
||||||
|
type Request struct {
|
||||||
|
NamePrefix string
|
||||||
|
Domain string
|
||||||
|
Port uint16
|
||||||
|
Protocol int
|
||||||
|
Pin string
|
||||||
|
Password string
|
||||||
|
UserGroups []string
|
||||||
|
}
|
||||||
|
|
||||||
|
type ManagementClient interface {
|
||||||
|
CreateExpose(ctx context.Context, req mgm.ExposeRequest) (*mgm.ExposeResponse, error)
|
||||||
|
RenewExpose(ctx context.Context, domain string) error
|
||||||
|
StopExpose(ctx context.Context, domain string) error
|
||||||
|
}
|
||||||
|
|
||||||
|
// Manager handles expose session lifecycle via the management client.
|
||||||
|
type Manager struct {
|
||||||
|
mgmClient ManagementClient
|
||||||
|
ctx context.Context
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewManager creates a new expose Manager using the given management client.
|
||||||
|
func NewManager(ctx context.Context, mgmClient ManagementClient) *Manager {
|
||||||
|
return &Manager{mgmClient: mgmClient, ctx: ctx}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Expose creates a new expose session via the management server.
|
||||||
|
func (m *Manager) Expose(ctx context.Context, req Request) (*Response, error) {
|
||||||
|
log.Infof("exposing service on port %d", req.Port)
|
||||||
|
resp, err := m.mgmClient.CreateExpose(ctx, toClientExposeRequest(req))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Infof("expose session created for %s", resp.Domain)
|
||||||
|
|
||||||
|
return fromClientExposeResponse(resp), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) KeepAlive(ctx context.Context, domain string) error {
|
||||||
|
ticker := time.NewTicker(30 * time.Second)
|
||||||
|
defer ticker.Stop()
|
||||||
|
defer m.stop(domain)
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
log.Infof("context canceled, stopping keep alive for %s", domain)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
case <-ticker.C:
|
||||||
|
if err := m.renew(ctx, domain); err != nil {
|
||||||
|
log.Errorf("renewing expose session for %s: %v", domain, err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// renew extends the TTL of an active expose session.
|
||||||
|
func (m *Manager) renew(ctx context.Context, domain string) error {
|
||||||
|
renewCtx, cancel := context.WithTimeout(ctx, renewTimeout)
|
||||||
|
defer cancel()
|
||||||
|
return m.mgmClient.RenewExpose(renewCtx, domain)
|
||||||
|
}
|
||||||
|
|
||||||
|
// stop terminates an active expose session.
|
||||||
|
func (m *Manager) stop(domain string) {
|
||||||
|
stopCtx, cancel := context.WithTimeout(m.ctx, renewTimeout)
|
||||||
|
defer cancel()
|
||||||
|
err := m.mgmClient.StopExpose(stopCtx, domain)
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("Failed stopping expose session for %s: %v", domain, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
95
client/internal/expose/manager_test.go
Normal file
95
client/internal/expose/manager_test.go
Normal file
@@ -0,0 +1,95 @@
|
|||||||
|
package expose
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
daemonProto "github.com/netbirdio/netbird/client/proto"
|
||||||
|
mgm "github.com/netbirdio/netbird/shared/management/client"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestManager_Expose_Success(t *testing.T) {
|
||||||
|
mock := &mgm.MockClient{
|
||||||
|
CreateExposeFunc: func(ctx context.Context, req mgm.ExposeRequest) (*mgm.ExposeResponse, error) {
|
||||||
|
return &mgm.ExposeResponse{
|
||||||
|
ServiceName: "my-service",
|
||||||
|
ServiceURL: "https://my-service.example.com",
|
||||||
|
Domain: "my-service.example.com",
|
||||||
|
}, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
m := NewManager(context.Background(), mock)
|
||||||
|
result, err := m.Expose(context.Background(), Request{Port: 8080})
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, "my-service", result.ServiceName, "service name should match")
|
||||||
|
assert.Equal(t, "https://my-service.example.com", result.ServiceURL, "service URL should match")
|
||||||
|
assert.Equal(t, "my-service.example.com", result.Domain, "domain should match")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestManager_Expose_Error(t *testing.T) {
|
||||||
|
mock := &mgm.MockClient{
|
||||||
|
CreateExposeFunc: func(ctx context.Context, req mgm.ExposeRequest) (*mgm.ExposeResponse, error) {
|
||||||
|
return nil, errors.New("permission denied")
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
m := NewManager(context.Background(), mock)
|
||||||
|
_, err := m.Expose(context.Background(), Request{Port: 8080})
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "permission denied", "error should propagate")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestManager_Renew_Success(t *testing.T) {
|
||||||
|
mock := &mgm.MockClient{
|
||||||
|
RenewExposeFunc: func(ctx context.Context, domain string) error {
|
||||||
|
assert.Equal(t, "my-service.example.com", domain, "domain should be passed through")
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
m := NewManager(context.Background(), mock)
|
||||||
|
err := m.renew(context.Background(), "my-service.example.com")
|
||||||
|
require.NoError(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestManager_Renew_Timeout(t *testing.T) {
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
cancel()
|
||||||
|
|
||||||
|
mock := &mgm.MockClient{
|
||||||
|
RenewExposeFunc: func(ctx context.Context, domain string) error {
|
||||||
|
return ctx.Err()
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
m := NewManager(ctx, mock)
|
||||||
|
err := m.renew(ctx, "my-service.example.com")
|
||||||
|
require.Error(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewRequest(t *testing.T) {
|
||||||
|
req := &daemonProto.ExposeServiceRequest{
|
||||||
|
Port: 8080,
|
||||||
|
Protocol: daemonProto.ExposeProtocol_EXPOSE_HTTPS,
|
||||||
|
Pin: "123456",
|
||||||
|
Password: "secret",
|
||||||
|
UserGroups: []string{"group1", "group2"},
|
||||||
|
Domain: "custom.example.com",
|
||||||
|
NamePrefix: "my-prefix",
|
||||||
|
}
|
||||||
|
|
||||||
|
exposeReq := NewRequest(req)
|
||||||
|
|
||||||
|
assert.Equal(t, uint16(8080), exposeReq.Port, "port should match")
|
||||||
|
assert.Equal(t, int(daemonProto.ExposeProtocol_EXPOSE_HTTPS), exposeReq.Protocol, "protocol should match")
|
||||||
|
assert.Equal(t, "123456", exposeReq.Pin, "pin should match")
|
||||||
|
assert.Equal(t, "secret", exposeReq.Password, "password should match")
|
||||||
|
assert.Equal(t, []string{"group1", "group2"}, exposeReq.UserGroups, "user groups should match")
|
||||||
|
assert.Equal(t, "custom.example.com", exposeReq.Domain, "domain should match")
|
||||||
|
assert.Equal(t, "my-prefix", exposeReq.NamePrefix, "name prefix should match")
|
||||||
|
}
|
||||||
39
client/internal/expose/request.go
Normal file
39
client/internal/expose/request.go
Normal file
@@ -0,0 +1,39 @@
|
|||||||
|
package expose
|
||||||
|
|
||||||
|
import (
|
||||||
|
daemonProto "github.com/netbirdio/netbird/client/proto"
|
||||||
|
mgm "github.com/netbirdio/netbird/shared/management/client"
|
||||||
|
)
|
||||||
|
|
||||||
|
// NewRequest converts a daemon ExposeServiceRequest to a management ExposeServiceRequest.
|
||||||
|
func NewRequest(req *daemonProto.ExposeServiceRequest) *Request {
|
||||||
|
return &Request{
|
||||||
|
Port: uint16(req.Port),
|
||||||
|
Protocol: int(req.Protocol),
|
||||||
|
Pin: req.Pin,
|
||||||
|
Password: req.Password,
|
||||||
|
UserGroups: req.UserGroups,
|
||||||
|
Domain: req.Domain,
|
||||||
|
NamePrefix: req.NamePrefix,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func toClientExposeRequest(req Request) mgm.ExposeRequest {
|
||||||
|
return mgm.ExposeRequest{
|
||||||
|
NamePrefix: req.NamePrefix,
|
||||||
|
Domain: req.Domain,
|
||||||
|
Port: req.Port,
|
||||||
|
Protocol: req.Protocol,
|
||||||
|
Pin: req.Pin,
|
||||||
|
Password: req.Password,
|
||||||
|
UserGroups: req.UserGroups,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func fromClientExposeResponse(response *mgm.ExposeResponse) *Response {
|
||||||
|
return &Response{
|
||||||
|
ServiceName: response.ServiceName,
|
||||||
|
Domain: response.Domain,
|
||||||
|
ServiceURL: response.ServiceURL,
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -22,18 +22,24 @@ func prepareFd() (int, error) {
|
|||||||
|
|
||||||
func routeCheck(ctx context.Context, fd int, nexthopv4, nexthopv6 systemops.Nexthop) error {
|
func routeCheck(ctx context.Context, fd int, nexthopv4, nexthopv6 systemops.Nexthop) error {
|
||||||
for {
|
for {
|
||||||
select {
|
// Wait until fd is readable or context is cancelled, to avoid a busy-loop
|
||||||
case <-ctx.Done():
|
// when the routing socket returns EAGAIN (e.g. immediately after wakeup).
|
||||||
return ctx.Err()
|
if err := waitReadable(ctx, fd); err != nil {
|
||||||
default:
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
buf := make([]byte, 2048)
|
buf := make([]byte, 2048)
|
||||||
n, err := unix.Read(fd, buf)
|
n, err := unix.Read(fd, buf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if !errors.Is(err, unix.EBADF) && !errors.Is(err, unix.EINVAL) {
|
if errors.Is(err, unix.EAGAIN) || errors.Is(err, unix.EINTR) {
|
||||||
log.Warnf("Network monitor: failed to read from routing socket: %v", err)
|
|
||||||
}
|
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
if errors.Is(err, unix.EBADF) || errors.Is(err, unix.EINVAL) {
|
||||||
|
return fmt.Errorf("routing socket closed: %w", err)
|
||||||
|
}
|
||||||
|
return fmt.Errorf("read routing socket: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
if n < unix.SizeofRtMsghdr {
|
if n < unix.SizeofRtMsghdr {
|
||||||
log.Debugf("Network monitor: read from routing socket returned less than expected: %d bytes", n)
|
log.Debugf("Network monitor: read from routing socket returned less than expected: %d bytes", n)
|
||||||
continue
|
continue
|
||||||
@@ -70,7 +76,6 @@ func routeCheck(ctx context.Context, fd int, nexthopv4, nexthopv6 systemops.Next
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func parseRouteMessage(buf []byte) (*systemops.Route, error) {
|
func parseRouteMessage(buf []byte) (*systemops.Route, error) {
|
||||||
@@ -90,3 +95,33 @@ func parseRouteMessage(buf []byte) (*systemops.Route, error) {
|
|||||||
|
|
||||||
return systemops.MsgToRoute(msg)
|
return systemops.MsgToRoute(msg)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// waitReadable blocks until fd has data to read, or ctx is cancelled.
|
||||||
|
func waitReadable(ctx context.Context, fd int) error {
|
||||||
|
var fdset unix.FdSet
|
||||||
|
if fd < 0 || fd/unix.NFDBITS >= len(fdset.Bits) {
|
||||||
|
return fmt.Errorf("fd %d out of range for FdSet", fd)
|
||||||
|
}
|
||||||
|
|
||||||
|
for {
|
||||||
|
if err := ctx.Err(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
fdset = unix.FdSet{}
|
||||||
|
fdset.Set(fd)
|
||||||
|
// Use a 1-second timeout so we can re-check ctx periodically.
|
||||||
|
tv := unix.Timeval{Sec: 1}
|
||||||
|
n, err := unix.Select(fd+1, &fdset, nil, nil, &tv)
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, unix.EINTR) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
return fmt.Errorf("select on routing socket: %w", err)
|
||||||
|
}
|
||||||
|
if n > 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
// timeout — loop back and re-check ctx
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -3,7 +3,6 @@ package peer
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"math/rand"
|
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"runtime"
|
"runtime"
|
||||||
@@ -25,7 +24,6 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||||
"github.com/netbirdio/netbird/route"
|
"github.com/netbirdio/netbird/route"
|
||||||
relayClient "github.com/netbirdio/netbird/shared/relay/client"
|
relayClient "github.com/netbirdio/netbird/shared/relay/client"
|
||||||
semaphoregroup "github.com/netbirdio/netbird/util/semaphore-group"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type ServiceDependencies struct {
|
type ServiceDependencies struct {
|
||||||
@@ -34,7 +32,6 @@ type ServiceDependencies struct {
|
|||||||
IFaceDiscover stdnet.ExternalIFaceDiscover
|
IFaceDiscover stdnet.ExternalIFaceDiscover
|
||||||
RelayManager *relayClient.Manager
|
RelayManager *relayClient.Manager
|
||||||
SrWatcher *guard.SRWatcher
|
SrWatcher *guard.SRWatcher
|
||||||
Semaphore *semaphoregroup.SemaphoreGroup
|
|
||||||
PeerConnDispatcher *dispatcher.ConnectionDispatcher
|
PeerConnDispatcher *dispatcher.ConnectionDispatcher
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -112,7 +109,6 @@ type Conn struct {
|
|||||||
handshaker *Handshaker
|
handshaker *Handshaker
|
||||||
|
|
||||||
guard *guard.Guard
|
guard *guard.Guard
|
||||||
semaphore *semaphoregroup.SemaphoreGroup
|
|
||||||
wg sync.WaitGroup
|
wg sync.WaitGroup
|
||||||
|
|
||||||
// debug purpose
|
// debug purpose
|
||||||
@@ -139,7 +135,6 @@ func NewConn(config ConnConfig, services ServiceDependencies) (*Conn, error) {
|
|||||||
iFaceDiscover: services.IFaceDiscover,
|
iFaceDiscover: services.IFaceDiscover,
|
||||||
relayManager: services.RelayManager,
|
relayManager: services.RelayManager,
|
||||||
srWatcher: services.SrWatcher,
|
srWatcher: services.SrWatcher,
|
||||||
semaphore: services.Semaphore,
|
|
||||||
statusRelay: worker.NewAtomicStatus(),
|
statusRelay: worker.NewAtomicStatus(),
|
||||||
statusICE: worker.NewAtomicStatus(),
|
statusICE: worker.NewAtomicStatus(),
|
||||||
dumpState: dumpState,
|
dumpState: dumpState,
|
||||||
@@ -154,15 +149,10 @@ func NewConn(config ConnConfig, services ServiceDependencies) (*Conn, error) {
|
|||||||
// It will try to establish a connection using ICE and in parallel with relay. The higher priority connection type will
|
// It will try to establish a connection using ICE and in parallel with relay. The higher priority connection type will
|
||||||
// be used.
|
// be used.
|
||||||
func (conn *Conn) Open(engineCtx context.Context) error {
|
func (conn *Conn) Open(engineCtx context.Context) error {
|
||||||
if err := conn.semaphore.Add(engineCtx); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
conn.mu.Lock()
|
conn.mu.Lock()
|
||||||
defer conn.mu.Unlock()
|
defer conn.mu.Unlock()
|
||||||
|
|
||||||
if conn.opened {
|
if conn.opened {
|
||||||
conn.semaphore.Done()
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -173,7 +163,6 @@ func (conn *Conn) Open(engineCtx context.Context) error {
|
|||||||
relayIsSupportedLocally := conn.workerRelay.RelayIsSupportedLocally()
|
relayIsSupportedLocally := conn.workerRelay.RelayIsSupportedLocally()
|
||||||
workerICE, err := NewWorkerICE(conn.ctx, conn.Log, conn.config, conn, conn.signaler, conn.iFaceDiscover, conn.statusRecorder, relayIsSupportedLocally)
|
workerICE, err := NewWorkerICE(conn.ctx, conn.Log, conn.config, conn, conn.signaler, conn.iFaceDiscover, conn.statusRecorder, relayIsSupportedLocally)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
conn.semaphore.Done()
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
conn.workerICE = workerICE
|
conn.workerICE = workerICE
|
||||||
@@ -207,10 +196,6 @@ func (conn *Conn) Open(engineCtx context.Context) error {
|
|||||||
conn.wg.Add(1)
|
conn.wg.Add(1)
|
||||||
go func() {
|
go func() {
|
||||||
defer conn.wg.Done()
|
defer conn.wg.Done()
|
||||||
|
|
||||||
conn.waitInitialRandomSleepTime(conn.ctx)
|
|
||||||
conn.semaphore.Done()
|
|
||||||
|
|
||||||
conn.guard.Start(conn.ctx, conn.onGuardEvent)
|
conn.guard.Start(conn.ctx, conn.onGuardEvent)
|
||||||
}()
|
}()
|
||||||
conn.opened = true
|
conn.opened = true
|
||||||
@@ -670,19 +655,6 @@ func (conn *Conn) doOnConnected(remoteRosenpassPubKey []byte, remoteRosenpassAdd
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (conn *Conn) waitInitialRandomSleepTime(ctx context.Context) {
|
|
||||||
maxWait := 300
|
|
||||||
duration := time.Duration(rand.Intn(maxWait)) * time.Millisecond
|
|
||||||
|
|
||||||
timeout := time.NewTimer(duration)
|
|
||||||
defer timeout.Stop()
|
|
||||||
|
|
||||||
select {
|
|
||||||
case <-ctx.Done():
|
|
||||||
case <-timeout.C:
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (conn *Conn) isRelayed() bool {
|
func (conn *Conn) isRelayed() bool {
|
||||||
switch conn.currentConnPriority {
|
switch conn.currentConnPriority {
|
||||||
case conntype.Relay, conntype.ICETurn:
|
case conntype.Relay, conntype.ICETurn:
|
||||||
|
|||||||
@@ -15,7 +15,6 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/internal/peer/ice"
|
"github.com/netbirdio/netbird/client/internal/peer/ice"
|
||||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||||
"github.com/netbirdio/netbird/util"
|
"github.com/netbirdio/netbird/util"
|
||||||
semaphoregroup "github.com/netbirdio/netbird/util/semaphore-group"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
var testDispatcher = dispatcher.NewConnectionDispatcher()
|
var testDispatcher = dispatcher.NewConnectionDispatcher()
|
||||||
@@ -53,7 +52,6 @@ func TestConn_GetKey(t *testing.T) {
|
|||||||
|
|
||||||
sd := ServiceDependencies{
|
sd := ServiceDependencies{
|
||||||
SrWatcher: swWatcher,
|
SrWatcher: swWatcher,
|
||||||
Semaphore: semaphoregroup.NewSemaphoreGroup(1),
|
|
||||||
PeerConnDispatcher: testDispatcher,
|
PeerConnDispatcher: testDispatcher,
|
||||||
}
|
}
|
||||||
conn, err := NewConn(connConf, sd)
|
conn, err := NewConn(connConf, sd)
|
||||||
@@ -71,7 +69,6 @@ func TestConn_OnRemoteOffer(t *testing.T) {
|
|||||||
sd := ServiceDependencies{
|
sd := ServiceDependencies{
|
||||||
StatusRecorder: NewRecorder("https://mgm"),
|
StatusRecorder: NewRecorder("https://mgm"),
|
||||||
SrWatcher: swWatcher,
|
SrWatcher: swWatcher,
|
||||||
Semaphore: semaphoregroup.NewSemaphoreGroup(1),
|
|
||||||
PeerConnDispatcher: testDispatcher,
|
PeerConnDispatcher: testDispatcher,
|
||||||
}
|
}
|
||||||
conn, err := NewConn(connConf, sd)
|
conn, err := NewConn(connConf, sd)
|
||||||
@@ -110,7 +107,6 @@ func TestConn_OnRemoteAnswer(t *testing.T) {
|
|||||||
sd := ServiceDependencies{
|
sd := ServiceDependencies{
|
||||||
StatusRecorder: NewRecorder("https://mgm"),
|
StatusRecorder: NewRecorder("https://mgm"),
|
||||||
SrWatcher: swWatcher,
|
SrWatcher: swWatcher,
|
||||||
Semaphore: semaphoregroup.NewSemaphoreGroup(1),
|
|
||||||
PeerConnDispatcher: testDispatcher,
|
PeerConnDispatcher: testDispatcher,
|
||||||
}
|
}
|
||||||
conn, err := NewConn(connConf, sd)
|
conn, err := NewConn(connConf, sd)
|
||||||
|
|||||||
@@ -198,7 +198,7 @@ func getConfigDirForUser(username string) (string, error) {
|
|||||||
|
|
||||||
configDir := filepath.Join(DefaultConfigPathDir, username)
|
configDir := filepath.Join(DefaultConfigPathDir, username)
|
||||||
if _, err := os.Stat(configDir); os.IsNotExist(err) {
|
if _, err := os.Stat(configDir); os.IsNotExist(err) {
|
||||||
if err := os.MkdirAll(configDir, 0600); err != nil {
|
if err := os.MkdirAll(configDir, 0700); err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -206,9 +206,15 @@ func getConfigDirForUser(username string) (string, error) {
|
|||||||
return configDir, nil
|
return configDir, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func fileExists(path string) bool {
|
func fileExists(path string) (bool, error) {
|
||||||
_, err := os.Stat(path)
|
_, err := os.Stat(path)
|
||||||
return !os.IsNotExist(err)
|
if err == nil {
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
if os.IsNotExist(err) {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
return false, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// createNewConfig creates a new config generating a new Wireguard key and saving to file
|
// createNewConfig creates a new config generating a new Wireguard key and saving to file
|
||||||
@@ -635,7 +641,11 @@ func isPreSharedKeyHidden(preSharedKey *string) bool {
|
|||||||
|
|
||||||
// UpdateConfig update existing configuration according to input configuration and return with the configuration
|
// UpdateConfig update existing configuration according to input configuration and return with the configuration
|
||||||
func UpdateConfig(input ConfigInput) (*Config, error) {
|
func UpdateConfig(input ConfigInput) (*Config, error) {
|
||||||
if !fileExists(input.ConfigPath) {
|
configExists, err := fileExists(input.ConfigPath)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to check if config file exists: %w", err)
|
||||||
|
}
|
||||||
|
if !configExists {
|
||||||
return nil, fmt.Errorf("config file %s does not exist", input.ConfigPath)
|
return nil, fmt.Errorf("config file %s does not exist", input.ConfigPath)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -644,7 +654,11 @@ func UpdateConfig(input ConfigInput) (*Config, error) {
|
|||||||
|
|
||||||
// UpdateOrCreateConfig reads existing config or generates a new one
|
// UpdateOrCreateConfig reads existing config or generates a new one
|
||||||
func UpdateOrCreateConfig(input ConfigInput) (*Config, error) {
|
func UpdateOrCreateConfig(input ConfigInput) (*Config, error) {
|
||||||
if !fileExists(input.ConfigPath) {
|
configExists, err := fileExists(input.ConfigPath)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to check if config file exists: %w", err)
|
||||||
|
}
|
||||||
|
if !configExists {
|
||||||
log.Infof("generating new config %s", input.ConfigPath)
|
log.Infof("generating new config %s", input.ConfigPath)
|
||||||
cfg, err := createNewConfig(input)
|
cfg, err := createNewConfig(input)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -657,7 +671,7 @@ func UpdateOrCreateConfig(input ConfigInput) (*Config, error) {
|
|||||||
if isPreSharedKeyHidden(input.PreSharedKey) {
|
if isPreSharedKeyHidden(input.PreSharedKey) {
|
||||||
input.PreSharedKey = nil
|
input.PreSharedKey = nil
|
||||||
}
|
}
|
||||||
err := util.EnforcePermission(input.ConfigPath)
|
err = util.EnforcePermission(input.ConfigPath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("failed to enforce permission on config dir: %v", err)
|
log.Errorf("failed to enforce permission on config dir: %v", err)
|
||||||
}
|
}
|
||||||
@@ -784,7 +798,12 @@ func ReadConfig(configPath string) (*Config, error) {
|
|||||||
|
|
||||||
// ReadConfig read config file and return with Config. If it is not exists create a new with default values
|
// ReadConfig read config file and return with Config. If it is not exists create a new with default values
|
||||||
func readConfig(configPath string, createIfMissing bool) (*Config, error) {
|
func readConfig(configPath string, createIfMissing bool) (*Config, error) {
|
||||||
if fileExists(configPath) {
|
configExists, err := fileExists(configPath)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to check if config file exists: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if configExists {
|
||||||
err := util.EnforcePermission(configPath)
|
err := util.EnforcePermission(configPath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("failed to enforce permission on config dir: %v", err)
|
log.Errorf("failed to enforce permission on config dir: %v", err)
|
||||||
@@ -831,7 +850,11 @@ func DirectWriteOutConfig(path string, config *Config) error {
|
|||||||
// DirectUpdateOrCreateConfig is like UpdateOrCreateConfig but uses direct (non-atomic) writes.
|
// DirectUpdateOrCreateConfig is like UpdateOrCreateConfig but uses direct (non-atomic) writes.
|
||||||
// Use this on platforms where atomic writes are blocked (e.g., tvOS sandbox).
|
// Use this on platforms where atomic writes are blocked (e.g., tvOS sandbox).
|
||||||
func DirectUpdateOrCreateConfig(input ConfigInput) (*Config, error) {
|
func DirectUpdateOrCreateConfig(input ConfigInput) (*Config, error) {
|
||||||
if !fileExists(input.ConfigPath) {
|
configExists, err := fileExists(input.ConfigPath)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to check if config file exists: %w", err)
|
||||||
|
}
|
||||||
|
if !configExists {
|
||||||
log.Infof("generating new config %s", input.ConfigPath)
|
log.Infof("generating new config %s", input.ConfigPath)
|
||||||
cfg, err := createNewConfig(input)
|
cfg, err := createNewConfig(input)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -256,7 +256,11 @@ func (s *ServiceManager) AddProfile(profileName, username string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
profPath := filepath.Join(configDir, profileName+".json")
|
profPath := filepath.Join(configDir, profileName+".json")
|
||||||
if fileExists(profPath) {
|
profileExists, err := fileExists(profPath)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to check if profile exists: %w", err)
|
||||||
|
}
|
||||||
|
if profileExists {
|
||||||
return ErrProfileAlreadyExists
|
return ErrProfileAlreadyExists
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -285,7 +289,11 @@ func (s *ServiceManager) RemoveProfile(profileName, username string) error {
|
|||||||
return fmt.Errorf("cannot remove profile with reserved name: %s", defaultProfileName)
|
return fmt.Errorf("cannot remove profile with reserved name: %s", defaultProfileName)
|
||||||
}
|
}
|
||||||
profPath := filepath.Join(configDir, profileName+".json")
|
profPath := filepath.Join(configDir, profileName+".json")
|
||||||
if !fileExists(profPath) {
|
profileExists, err := fileExists(profPath)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to check if profile exists: %w", err)
|
||||||
|
}
|
||||||
|
if !profileExists {
|
||||||
return ErrProfileNotFound
|
return ErrProfileNotFound
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -20,7 +20,11 @@ func (pm *ProfileManager) GetProfileState(profileName string) (*ProfileState, er
|
|||||||
}
|
}
|
||||||
|
|
||||||
stateFile := filepath.Join(configDir, profileName+".state.json")
|
stateFile := filepath.Join(configDir, profileName+".state.json")
|
||||||
if !fileExists(stateFile) {
|
stateFileExists, err := fileExists(stateFile)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to check if profile state file exists: %w", err)
|
||||||
|
}
|
||||||
|
if !stateFileExists {
|
||||||
return nil, errors.New("profile state file does not exist")
|
return nil, errors.New("profile state file does not exist")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -263,8 +263,14 @@ func (w *Watcher) watchPeerStatusChanges(ctx context.Context, peerKey string, pe
|
|||||||
case <-closer:
|
case <-closer:
|
||||||
return
|
return
|
||||||
case routerStates := <-subscription.Events():
|
case routerStates := <-subscription.Events():
|
||||||
peerStateUpdate <- routerStates
|
select {
|
||||||
|
case peerStateUpdate <- routerStates:
|
||||||
log.Debugf("triggered route state update for Peer: %s", peerKey)
|
log.Debugf("triggered route state update for Peer: %s", peerKey)
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
case <-closer:
|
||||||
|
return
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
80
client/internal/sleep/handler/handler.go
Normal file
80
client/internal/sleep/handler/handler.go
Normal file
@@ -0,0 +1,80 @@
|
|||||||
|
package handler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Agent interface {
|
||||||
|
Up(ctx context.Context) error
|
||||||
|
Down(ctx context.Context) error
|
||||||
|
Status() (internal.StatusType, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
type SleepHandler struct {
|
||||||
|
agent Agent
|
||||||
|
|
||||||
|
mu sync.Mutex
|
||||||
|
// sleepTriggeredDown indicates whether the sleep handler triggered the last client down, to avoid unnecessary up on wake
|
||||||
|
sleepTriggeredDown bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func New(agent Agent) *SleepHandler {
|
||||||
|
return &SleepHandler{
|
||||||
|
agent: agent,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SleepHandler) HandleWakeUp(ctx context.Context) error {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
|
if !s.sleepTriggeredDown {
|
||||||
|
log.Info("skipping up because wasn't sleep down")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// avoid other wakeup runs if sleep didn't make the computer sleep
|
||||||
|
s.sleepTriggeredDown = false
|
||||||
|
|
||||||
|
log.Info("running up after wake up")
|
||||||
|
err := s.agent.Up(ctx)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("running up failed: %v", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Info("running up command executed successfully")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SleepHandler) HandleSleep(ctx context.Context) error {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
|
status, err := s.agent.Status()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if status != internal.StatusConnecting && status != internal.StatusConnected {
|
||||||
|
log.Infof("skipping setting the agent down because status is %s", status)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Info("running down after system started sleeping")
|
||||||
|
|
||||||
|
if err = s.agent.Down(ctx); err != nil {
|
||||||
|
log.Errorf("running down failed: %v", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
s.sleepTriggeredDown = true
|
||||||
|
|
||||||
|
log.Info("running down executed successfully")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
153
client/internal/sleep/handler/handler_test.go
Normal file
153
client/internal/sleep/handler/handler_test.go
Normal file
@@ -0,0 +1,153 @@
|
|||||||
|
package handler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal"
|
||||||
|
)
|
||||||
|
|
||||||
|
type mockAgent struct {
|
||||||
|
upErr error
|
||||||
|
downErr error
|
||||||
|
statusErr error
|
||||||
|
status internal.StatusType
|
||||||
|
upCalls int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockAgent) Up(_ context.Context) error {
|
||||||
|
m.upCalls++
|
||||||
|
return m.upErr
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockAgent) Down(_ context.Context) error {
|
||||||
|
return m.downErr
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockAgent) Status() (internal.StatusType, error) {
|
||||||
|
return m.status, m.statusErr
|
||||||
|
}
|
||||||
|
|
||||||
|
func newHandler(status internal.StatusType) (*SleepHandler, *mockAgent) {
|
||||||
|
agent := &mockAgent{status: status}
|
||||||
|
return New(agent), agent
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHandleWakeUp_SkipsWhenFlagFalse(t *testing.T) {
|
||||||
|
h, agent := newHandler(internal.StatusIdle)
|
||||||
|
|
||||||
|
err := h.HandleWakeUp(context.Background())
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, 0, agent.upCalls, "Up should not be called when flag is false")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHandleWakeUp_ResetsFlagBeforeUp(t *testing.T) {
|
||||||
|
h, _ := newHandler(internal.StatusIdle)
|
||||||
|
h.sleepTriggeredDown = true
|
||||||
|
|
||||||
|
// Even if Up fails, flag should be reset
|
||||||
|
_ = h.HandleWakeUp(context.Background())
|
||||||
|
|
||||||
|
assert.False(t, h.sleepTriggeredDown, "flag must be reset before calling Up")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHandleWakeUp_CallsUpWhenFlagSet(t *testing.T) {
|
||||||
|
h, agent := newHandler(internal.StatusIdle)
|
||||||
|
h.sleepTriggeredDown = true
|
||||||
|
|
||||||
|
err := h.HandleWakeUp(context.Background())
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, 1, agent.upCalls)
|
||||||
|
assert.False(t, h.sleepTriggeredDown)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHandleWakeUp_ReturnsErrorFromUp(t *testing.T) {
|
||||||
|
h, agent := newHandler(internal.StatusIdle)
|
||||||
|
h.sleepTriggeredDown = true
|
||||||
|
agent.upErr = errors.New("up failed")
|
||||||
|
|
||||||
|
err := h.HandleWakeUp(context.Background())
|
||||||
|
|
||||||
|
assert.ErrorIs(t, err, agent.upErr)
|
||||||
|
assert.False(t, h.sleepTriggeredDown, "flag should still be reset even when Up fails")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHandleWakeUp_SecondCallIsNoOp(t *testing.T) {
|
||||||
|
h, agent := newHandler(internal.StatusIdle)
|
||||||
|
h.sleepTriggeredDown = true
|
||||||
|
|
||||||
|
_ = h.HandleWakeUp(context.Background())
|
||||||
|
err := h.HandleWakeUp(context.Background())
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, 1, agent.upCalls, "second wakeup should be no-op")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHandleSleep_SkipsForNonActiveStates(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
status internal.StatusType
|
||||||
|
}{
|
||||||
|
{"Idle", internal.StatusIdle},
|
||||||
|
{"NeedsLogin", internal.StatusNeedsLogin},
|
||||||
|
{"LoginFailed", internal.StatusLoginFailed},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
h, _ := newHandler(tt.status)
|
||||||
|
|
||||||
|
err := h.HandleSleep(context.Background())
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.False(t, h.sleepTriggeredDown)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHandleSleep_ProceedsForActiveStates(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
status internal.StatusType
|
||||||
|
}{
|
||||||
|
{"Connecting", internal.StatusConnecting},
|
||||||
|
{"Connected", internal.StatusConnected},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
h, _ := newHandler(tt.status)
|
||||||
|
|
||||||
|
err := h.HandleSleep(context.Background())
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.True(t, h.sleepTriggeredDown)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHandleSleep_ReturnsErrorFromStatus(t *testing.T) {
|
||||||
|
agent := &mockAgent{statusErr: errors.New("status error")}
|
||||||
|
h := New(agent)
|
||||||
|
|
||||||
|
err := h.HandleSleep(context.Background())
|
||||||
|
|
||||||
|
assert.ErrorIs(t, err, agent.statusErr)
|
||||||
|
assert.False(t, h.sleepTriggeredDown)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHandleSleep_ReturnsErrorFromDown(t *testing.T) {
|
||||||
|
agent := &mockAgent{status: internal.StatusConnected, downErr: errors.New("down failed")}
|
||||||
|
h := New(agent)
|
||||||
|
|
||||||
|
err := h.HandleSleep(context.Background())
|
||||||
|
|
||||||
|
assert.ErrorIs(t, err, agent.downErr)
|
||||||
|
assert.False(t, h.sleepTriggeredDown, "flag should not be set when Down fails")
|
||||||
|
}
|
||||||
@@ -1,7 +1,7 @@
|
|||||||
// Code generated by protoc-gen-go. DO NOT EDIT.
|
// Code generated by protoc-gen-go. DO NOT EDIT.
|
||||||
// versions:
|
// versions:
|
||||||
// protoc-gen-go v1.36.6
|
// protoc-gen-go v1.36.6
|
||||||
// protoc v6.32.1
|
// protoc v6.33.3
|
||||||
// source: daemon.proto
|
// source: daemon.proto
|
||||||
|
|
||||||
package proto
|
package proto
|
||||||
@@ -88,6 +88,58 @@ func (LogLevel) EnumDescriptor() ([]byte, []int) {
|
|||||||
return file_daemon_proto_rawDescGZIP(), []int{0}
|
return file_daemon_proto_rawDescGZIP(), []int{0}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type ExposeProtocol int32
|
||||||
|
|
||||||
|
const (
|
||||||
|
ExposeProtocol_EXPOSE_HTTP ExposeProtocol = 0
|
||||||
|
ExposeProtocol_EXPOSE_HTTPS ExposeProtocol = 1
|
||||||
|
ExposeProtocol_EXPOSE_TCP ExposeProtocol = 2
|
||||||
|
ExposeProtocol_EXPOSE_UDP ExposeProtocol = 3
|
||||||
|
)
|
||||||
|
|
||||||
|
// Enum value maps for ExposeProtocol.
|
||||||
|
var (
|
||||||
|
ExposeProtocol_name = map[int32]string{
|
||||||
|
0: "EXPOSE_HTTP",
|
||||||
|
1: "EXPOSE_HTTPS",
|
||||||
|
2: "EXPOSE_TCP",
|
||||||
|
3: "EXPOSE_UDP",
|
||||||
|
}
|
||||||
|
ExposeProtocol_value = map[string]int32{
|
||||||
|
"EXPOSE_HTTP": 0,
|
||||||
|
"EXPOSE_HTTPS": 1,
|
||||||
|
"EXPOSE_TCP": 2,
|
||||||
|
"EXPOSE_UDP": 3,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
func (x ExposeProtocol) Enum() *ExposeProtocol {
|
||||||
|
p := new(ExposeProtocol)
|
||||||
|
*p = x
|
||||||
|
return p
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x ExposeProtocol) String() string {
|
||||||
|
return protoimpl.X.EnumStringOf(x.Descriptor(), protoreflect.EnumNumber(x))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ExposeProtocol) Descriptor() protoreflect.EnumDescriptor {
|
||||||
|
return file_daemon_proto_enumTypes[1].Descriptor()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ExposeProtocol) Type() protoreflect.EnumType {
|
||||||
|
return &file_daemon_proto_enumTypes[1]
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x ExposeProtocol) Number() protoreflect.EnumNumber {
|
||||||
|
return protoreflect.EnumNumber(x)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Deprecated: Use ExposeProtocol.Descriptor instead.
|
||||||
|
func (ExposeProtocol) EnumDescriptor() ([]byte, []int) {
|
||||||
|
return file_daemon_proto_rawDescGZIP(), []int{1}
|
||||||
|
}
|
||||||
|
|
||||||
// avoid collision with loglevel enum
|
// avoid collision with loglevel enum
|
||||||
type OSLifecycleRequest_CycleType int32
|
type OSLifecycleRequest_CycleType int32
|
||||||
|
|
||||||
@@ -122,11 +174,11 @@ func (x OSLifecycleRequest_CycleType) String() string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (OSLifecycleRequest_CycleType) Descriptor() protoreflect.EnumDescriptor {
|
func (OSLifecycleRequest_CycleType) Descriptor() protoreflect.EnumDescriptor {
|
||||||
return file_daemon_proto_enumTypes[1].Descriptor()
|
return file_daemon_proto_enumTypes[2].Descriptor()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (OSLifecycleRequest_CycleType) Type() protoreflect.EnumType {
|
func (OSLifecycleRequest_CycleType) Type() protoreflect.EnumType {
|
||||||
return &file_daemon_proto_enumTypes[1]
|
return &file_daemon_proto_enumTypes[2]
|
||||||
}
|
}
|
||||||
|
|
||||||
func (x OSLifecycleRequest_CycleType) Number() protoreflect.EnumNumber {
|
func (x OSLifecycleRequest_CycleType) Number() protoreflect.EnumNumber {
|
||||||
@@ -174,11 +226,11 @@ func (x SystemEvent_Severity) String() string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (SystemEvent_Severity) Descriptor() protoreflect.EnumDescriptor {
|
func (SystemEvent_Severity) Descriptor() protoreflect.EnumDescriptor {
|
||||||
return file_daemon_proto_enumTypes[2].Descriptor()
|
return file_daemon_proto_enumTypes[3].Descriptor()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (SystemEvent_Severity) Type() protoreflect.EnumType {
|
func (SystemEvent_Severity) Type() protoreflect.EnumType {
|
||||||
return &file_daemon_proto_enumTypes[2]
|
return &file_daemon_proto_enumTypes[3]
|
||||||
}
|
}
|
||||||
|
|
||||||
func (x SystemEvent_Severity) Number() protoreflect.EnumNumber {
|
func (x SystemEvent_Severity) Number() protoreflect.EnumNumber {
|
||||||
@@ -229,11 +281,11 @@ func (x SystemEvent_Category) String() string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (SystemEvent_Category) Descriptor() protoreflect.EnumDescriptor {
|
func (SystemEvent_Category) Descriptor() protoreflect.EnumDescriptor {
|
||||||
return file_daemon_proto_enumTypes[3].Descriptor()
|
return file_daemon_proto_enumTypes[4].Descriptor()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (SystemEvent_Category) Type() protoreflect.EnumType {
|
func (SystemEvent_Category) Type() protoreflect.EnumType {
|
||||||
return &file_daemon_proto_enumTypes[3]
|
return &file_daemon_proto_enumTypes[4]
|
||||||
}
|
}
|
||||||
|
|
||||||
func (x SystemEvent_Category) Number() protoreflect.EnumNumber {
|
func (x SystemEvent_Category) Number() protoreflect.EnumNumber {
|
||||||
@@ -5600,6 +5652,224 @@ func (x *InstallerResultResponse) GetErrorMsg() string {
|
|||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type ExposeServiceRequest struct {
|
||||||
|
state protoimpl.MessageState `protogen:"open.v1"`
|
||||||
|
Port uint32 `protobuf:"varint,1,opt,name=port,proto3" json:"port,omitempty"`
|
||||||
|
Protocol ExposeProtocol `protobuf:"varint,2,opt,name=protocol,proto3,enum=daemon.ExposeProtocol" json:"protocol,omitempty"`
|
||||||
|
Pin string `protobuf:"bytes,3,opt,name=pin,proto3" json:"pin,omitempty"`
|
||||||
|
Password string `protobuf:"bytes,4,opt,name=password,proto3" json:"password,omitempty"`
|
||||||
|
UserGroups []string `protobuf:"bytes,5,rep,name=user_groups,json=userGroups,proto3" json:"user_groups,omitempty"`
|
||||||
|
Domain string `protobuf:"bytes,6,opt,name=domain,proto3" json:"domain,omitempty"`
|
||||||
|
NamePrefix string `protobuf:"bytes,7,opt,name=name_prefix,json=namePrefix,proto3" json:"name_prefix,omitempty"`
|
||||||
|
unknownFields protoimpl.UnknownFields
|
||||||
|
sizeCache protoimpl.SizeCache
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *ExposeServiceRequest) Reset() {
|
||||||
|
*x = ExposeServiceRequest{}
|
||||||
|
mi := &file_daemon_proto_msgTypes[85]
|
||||||
|
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||||
|
ms.StoreMessageInfo(mi)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *ExposeServiceRequest) String() string {
|
||||||
|
return protoimpl.X.MessageStringOf(x)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (*ExposeServiceRequest) ProtoMessage() {}
|
||||||
|
|
||||||
|
func (x *ExposeServiceRequest) ProtoReflect() protoreflect.Message {
|
||||||
|
mi := &file_daemon_proto_msgTypes[85]
|
||||||
|
if x != nil {
|
||||||
|
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||||
|
if ms.LoadMessageInfo() == nil {
|
||||||
|
ms.StoreMessageInfo(mi)
|
||||||
|
}
|
||||||
|
return ms
|
||||||
|
}
|
||||||
|
return mi.MessageOf(x)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Deprecated: Use ExposeServiceRequest.ProtoReflect.Descriptor instead.
|
||||||
|
func (*ExposeServiceRequest) Descriptor() ([]byte, []int) {
|
||||||
|
return file_daemon_proto_rawDescGZIP(), []int{85}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *ExposeServiceRequest) GetPort() uint32 {
|
||||||
|
if x != nil {
|
||||||
|
return x.Port
|
||||||
|
}
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *ExposeServiceRequest) GetProtocol() ExposeProtocol {
|
||||||
|
if x != nil {
|
||||||
|
return x.Protocol
|
||||||
|
}
|
||||||
|
return ExposeProtocol_EXPOSE_HTTP
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *ExposeServiceRequest) GetPin() string {
|
||||||
|
if x != nil {
|
||||||
|
return x.Pin
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *ExposeServiceRequest) GetPassword() string {
|
||||||
|
if x != nil {
|
||||||
|
return x.Password
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *ExposeServiceRequest) GetUserGroups() []string {
|
||||||
|
if x != nil {
|
||||||
|
return x.UserGroups
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *ExposeServiceRequest) GetDomain() string {
|
||||||
|
if x != nil {
|
||||||
|
return x.Domain
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *ExposeServiceRequest) GetNamePrefix() string {
|
||||||
|
if x != nil {
|
||||||
|
return x.NamePrefix
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
type ExposeServiceEvent struct {
|
||||||
|
state protoimpl.MessageState `protogen:"open.v1"`
|
||||||
|
// Types that are valid to be assigned to Event:
|
||||||
|
//
|
||||||
|
// *ExposeServiceEvent_Ready
|
||||||
|
Event isExposeServiceEvent_Event `protobuf_oneof:"event"`
|
||||||
|
unknownFields protoimpl.UnknownFields
|
||||||
|
sizeCache protoimpl.SizeCache
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *ExposeServiceEvent) Reset() {
|
||||||
|
*x = ExposeServiceEvent{}
|
||||||
|
mi := &file_daemon_proto_msgTypes[86]
|
||||||
|
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||||
|
ms.StoreMessageInfo(mi)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *ExposeServiceEvent) String() string {
|
||||||
|
return protoimpl.X.MessageStringOf(x)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (*ExposeServiceEvent) ProtoMessage() {}
|
||||||
|
|
||||||
|
func (x *ExposeServiceEvent) ProtoReflect() protoreflect.Message {
|
||||||
|
mi := &file_daemon_proto_msgTypes[86]
|
||||||
|
if x != nil {
|
||||||
|
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||||
|
if ms.LoadMessageInfo() == nil {
|
||||||
|
ms.StoreMessageInfo(mi)
|
||||||
|
}
|
||||||
|
return ms
|
||||||
|
}
|
||||||
|
return mi.MessageOf(x)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Deprecated: Use ExposeServiceEvent.ProtoReflect.Descriptor instead.
|
||||||
|
func (*ExposeServiceEvent) Descriptor() ([]byte, []int) {
|
||||||
|
return file_daemon_proto_rawDescGZIP(), []int{86}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *ExposeServiceEvent) GetEvent() isExposeServiceEvent_Event {
|
||||||
|
if x != nil {
|
||||||
|
return x.Event
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *ExposeServiceEvent) GetReady() *ExposeServiceReady {
|
||||||
|
if x != nil {
|
||||||
|
if x, ok := x.Event.(*ExposeServiceEvent_Ready); ok {
|
||||||
|
return x.Ready
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type isExposeServiceEvent_Event interface {
|
||||||
|
isExposeServiceEvent_Event()
|
||||||
|
}
|
||||||
|
|
||||||
|
type ExposeServiceEvent_Ready struct {
|
||||||
|
Ready *ExposeServiceReady `protobuf:"bytes,1,opt,name=ready,proto3,oneof"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (*ExposeServiceEvent_Ready) isExposeServiceEvent_Event() {}
|
||||||
|
|
||||||
|
type ExposeServiceReady struct {
|
||||||
|
state protoimpl.MessageState `protogen:"open.v1"`
|
||||||
|
ServiceName string `protobuf:"bytes,1,opt,name=service_name,json=serviceName,proto3" json:"service_name,omitempty"`
|
||||||
|
ServiceUrl string `protobuf:"bytes,2,opt,name=service_url,json=serviceUrl,proto3" json:"service_url,omitempty"`
|
||||||
|
Domain string `protobuf:"bytes,3,opt,name=domain,proto3" json:"domain,omitempty"`
|
||||||
|
unknownFields protoimpl.UnknownFields
|
||||||
|
sizeCache protoimpl.SizeCache
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *ExposeServiceReady) Reset() {
|
||||||
|
*x = ExposeServiceReady{}
|
||||||
|
mi := &file_daemon_proto_msgTypes[87]
|
||||||
|
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||||
|
ms.StoreMessageInfo(mi)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *ExposeServiceReady) String() string {
|
||||||
|
return protoimpl.X.MessageStringOf(x)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (*ExposeServiceReady) ProtoMessage() {}
|
||||||
|
|
||||||
|
func (x *ExposeServiceReady) ProtoReflect() protoreflect.Message {
|
||||||
|
mi := &file_daemon_proto_msgTypes[87]
|
||||||
|
if x != nil {
|
||||||
|
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||||
|
if ms.LoadMessageInfo() == nil {
|
||||||
|
ms.StoreMessageInfo(mi)
|
||||||
|
}
|
||||||
|
return ms
|
||||||
|
}
|
||||||
|
return mi.MessageOf(x)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Deprecated: Use ExposeServiceReady.ProtoReflect.Descriptor instead.
|
||||||
|
func (*ExposeServiceReady) Descriptor() ([]byte, []int) {
|
||||||
|
return file_daemon_proto_rawDescGZIP(), []int{87}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *ExposeServiceReady) GetServiceName() string {
|
||||||
|
if x != nil {
|
||||||
|
return x.ServiceName
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *ExposeServiceReady) GetServiceUrl() string {
|
||||||
|
if x != nil {
|
||||||
|
return x.ServiceUrl
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *ExposeServiceReady) GetDomain() string {
|
||||||
|
if x != nil {
|
||||||
|
return x.Domain
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
type PortInfo_Range struct {
|
type PortInfo_Range struct {
|
||||||
state protoimpl.MessageState `protogen:"open.v1"`
|
state protoimpl.MessageState `protogen:"open.v1"`
|
||||||
Start uint32 `protobuf:"varint,1,opt,name=start,proto3" json:"start,omitempty"`
|
Start uint32 `protobuf:"varint,1,opt,name=start,proto3" json:"start,omitempty"`
|
||||||
@@ -5610,7 +5880,7 @@ type PortInfo_Range struct {
|
|||||||
|
|
||||||
func (x *PortInfo_Range) Reset() {
|
func (x *PortInfo_Range) Reset() {
|
||||||
*x = PortInfo_Range{}
|
*x = PortInfo_Range{}
|
||||||
mi := &file_daemon_proto_msgTypes[86]
|
mi := &file_daemon_proto_msgTypes[89]
|
||||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||||
ms.StoreMessageInfo(mi)
|
ms.StoreMessageInfo(mi)
|
||||||
}
|
}
|
||||||
@@ -5622,7 +5892,7 @@ func (x *PortInfo_Range) String() string {
|
|||||||
func (*PortInfo_Range) ProtoMessage() {}
|
func (*PortInfo_Range) ProtoMessage() {}
|
||||||
|
|
||||||
func (x *PortInfo_Range) ProtoReflect() protoreflect.Message {
|
func (x *PortInfo_Range) ProtoReflect() protoreflect.Message {
|
||||||
mi := &file_daemon_proto_msgTypes[86]
|
mi := &file_daemon_proto_msgTypes[89]
|
||||||
if x != nil {
|
if x != nil {
|
||||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||||
if ms.LoadMessageInfo() == nil {
|
if ms.LoadMessageInfo() == nil {
|
||||||
@@ -6149,7 +6419,25 @@ const file_daemon_proto_rawDesc = "" +
|
|||||||
"\x16InstallerResultRequest\"O\n" +
|
"\x16InstallerResultRequest\"O\n" +
|
||||||
"\x17InstallerResultResponse\x12\x18\n" +
|
"\x17InstallerResultResponse\x12\x18\n" +
|
||||||
"\asuccess\x18\x01 \x01(\bR\asuccess\x12\x1a\n" +
|
"\asuccess\x18\x01 \x01(\bR\asuccess\x12\x1a\n" +
|
||||||
"\berrorMsg\x18\x02 \x01(\tR\berrorMsg*b\n" +
|
"\berrorMsg\x18\x02 \x01(\tR\berrorMsg\"\xe6\x01\n" +
|
||||||
|
"\x14ExposeServiceRequest\x12\x12\n" +
|
||||||
|
"\x04port\x18\x01 \x01(\rR\x04port\x122\n" +
|
||||||
|
"\bprotocol\x18\x02 \x01(\x0e2\x16.daemon.ExposeProtocolR\bprotocol\x12\x10\n" +
|
||||||
|
"\x03pin\x18\x03 \x01(\tR\x03pin\x12\x1a\n" +
|
||||||
|
"\bpassword\x18\x04 \x01(\tR\bpassword\x12\x1f\n" +
|
||||||
|
"\vuser_groups\x18\x05 \x03(\tR\n" +
|
||||||
|
"userGroups\x12\x16\n" +
|
||||||
|
"\x06domain\x18\x06 \x01(\tR\x06domain\x12\x1f\n" +
|
||||||
|
"\vname_prefix\x18\a \x01(\tR\n" +
|
||||||
|
"namePrefix\"Q\n" +
|
||||||
|
"\x12ExposeServiceEvent\x122\n" +
|
||||||
|
"\x05ready\x18\x01 \x01(\v2\x1a.daemon.ExposeServiceReadyH\x00R\x05readyB\a\n" +
|
||||||
|
"\x05event\"p\n" +
|
||||||
|
"\x12ExposeServiceReady\x12!\n" +
|
||||||
|
"\fservice_name\x18\x01 \x01(\tR\vserviceName\x12\x1f\n" +
|
||||||
|
"\vservice_url\x18\x02 \x01(\tR\n" +
|
||||||
|
"serviceUrl\x12\x16\n" +
|
||||||
|
"\x06domain\x18\x03 \x01(\tR\x06domain*b\n" +
|
||||||
"\bLogLevel\x12\v\n" +
|
"\bLogLevel\x12\v\n" +
|
||||||
"\aUNKNOWN\x10\x00\x12\t\n" +
|
"\aUNKNOWN\x10\x00\x12\t\n" +
|
||||||
"\x05PANIC\x10\x01\x12\t\n" +
|
"\x05PANIC\x10\x01\x12\t\n" +
|
||||||
@@ -6158,7 +6446,14 @@ const file_daemon_proto_rawDesc = "" +
|
|||||||
"\x04WARN\x10\x04\x12\b\n" +
|
"\x04WARN\x10\x04\x12\b\n" +
|
||||||
"\x04INFO\x10\x05\x12\t\n" +
|
"\x04INFO\x10\x05\x12\t\n" +
|
||||||
"\x05DEBUG\x10\x06\x12\t\n" +
|
"\x05DEBUG\x10\x06\x12\t\n" +
|
||||||
"\x05TRACE\x10\a2\xdd\x14\n" +
|
"\x05TRACE\x10\a*S\n" +
|
||||||
|
"\x0eExposeProtocol\x12\x0f\n" +
|
||||||
|
"\vEXPOSE_HTTP\x10\x00\x12\x10\n" +
|
||||||
|
"\fEXPOSE_HTTPS\x10\x01\x12\x0e\n" +
|
||||||
|
"\n" +
|
||||||
|
"EXPOSE_TCP\x10\x02\x12\x0e\n" +
|
||||||
|
"\n" +
|
||||||
|
"EXPOSE_UDP\x10\x032\xac\x15\n" +
|
||||||
"\rDaemonService\x126\n" +
|
"\rDaemonService\x126\n" +
|
||||||
"\x05Login\x12\x14.daemon.LoginRequest\x1a\x15.daemon.LoginResponse\"\x00\x12K\n" +
|
"\x05Login\x12\x14.daemon.LoginRequest\x1a\x15.daemon.LoginResponse\"\x00\x12K\n" +
|
||||||
"\fWaitSSOLogin\x12\x1b.daemon.WaitSSOLoginRequest\x1a\x1c.daemon.WaitSSOLoginResponse\"\x00\x12-\n" +
|
"\fWaitSSOLogin\x12\x1b.daemon.WaitSSOLoginRequest\x1a\x1c.daemon.WaitSSOLoginResponse\"\x00\x12-\n" +
|
||||||
@@ -6197,7 +6492,8 @@ const file_daemon_proto_rawDesc = "" +
|
|||||||
"\x0fStartCPUProfile\x12\x1e.daemon.StartCPUProfileRequest\x1a\x1f.daemon.StartCPUProfileResponse\"\x00\x12Q\n" +
|
"\x0fStartCPUProfile\x12\x1e.daemon.StartCPUProfileRequest\x1a\x1f.daemon.StartCPUProfileResponse\"\x00\x12Q\n" +
|
||||||
"\x0eStopCPUProfile\x12\x1d.daemon.StopCPUProfileRequest\x1a\x1e.daemon.StopCPUProfileResponse\"\x00\x12N\n" +
|
"\x0eStopCPUProfile\x12\x1d.daemon.StopCPUProfileRequest\x1a\x1e.daemon.StopCPUProfileResponse\"\x00\x12N\n" +
|
||||||
"\x11NotifyOSLifecycle\x12\x1a.daemon.OSLifecycleRequest\x1a\x1b.daemon.OSLifecycleResponse\"\x00\x12W\n" +
|
"\x11NotifyOSLifecycle\x12\x1a.daemon.OSLifecycleRequest\x1a\x1b.daemon.OSLifecycleResponse\"\x00\x12W\n" +
|
||||||
"\x12GetInstallerResult\x12\x1e.daemon.InstallerResultRequest\x1a\x1f.daemon.InstallerResultResponse\"\x00B\bZ\x06/protob\x06proto3"
|
"\x12GetInstallerResult\x12\x1e.daemon.InstallerResultRequest\x1a\x1f.daemon.InstallerResultResponse\"\x00\x12M\n" +
|
||||||
|
"\rExposeService\x12\x1c.daemon.ExposeServiceRequest\x1a\x1a.daemon.ExposeServiceEvent\"\x000\x01B\bZ\x06/protob\x06proto3"
|
||||||
|
|
||||||
var (
|
var (
|
||||||
file_daemon_proto_rawDescOnce sync.Once
|
file_daemon_proto_rawDescOnce sync.Once
|
||||||
@@ -6211,214 +6507,222 @@ func file_daemon_proto_rawDescGZIP() []byte {
|
|||||||
return file_daemon_proto_rawDescData
|
return file_daemon_proto_rawDescData
|
||||||
}
|
}
|
||||||
|
|
||||||
var file_daemon_proto_enumTypes = make([]protoimpl.EnumInfo, 4)
|
var file_daemon_proto_enumTypes = make([]protoimpl.EnumInfo, 5)
|
||||||
var file_daemon_proto_msgTypes = make([]protoimpl.MessageInfo, 88)
|
var file_daemon_proto_msgTypes = make([]protoimpl.MessageInfo, 91)
|
||||||
var file_daemon_proto_goTypes = []any{
|
var file_daemon_proto_goTypes = []any{
|
||||||
(LogLevel)(0), // 0: daemon.LogLevel
|
(LogLevel)(0), // 0: daemon.LogLevel
|
||||||
(OSLifecycleRequest_CycleType)(0), // 1: daemon.OSLifecycleRequest.CycleType
|
(ExposeProtocol)(0), // 1: daemon.ExposeProtocol
|
||||||
(SystemEvent_Severity)(0), // 2: daemon.SystemEvent.Severity
|
(OSLifecycleRequest_CycleType)(0), // 2: daemon.OSLifecycleRequest.CycleType
|
||||||
(SystemEvent_Category)(0), // 3: daemon.SystemEvent.Category
|
(SystemEvent_Severity)(0), // 3: daemon.SystemEvent.Severity
|
||||||
(*EmptyRequest)(nil), // 4: daemon.EmptyRequest
|
(SystemEvent_Category)(0), // 4: daemon.SystemEvent.Category
|
||||||
(*OSLifecycleRequest)(nil), // 5: daemon.OSLifecycleRequest
|
(*EmptyRequest)(nil), // 5: daemon.EmptyRequest
|
||||||
(*OSLifecycleResponse)(nil), // 6: daemon.OSLifecycleResponse
|
(*OSLifecycleRequest)(nil), // 6: daemon.OSLifecycleRequest
|
||||||
(*LoginRequest)(nil), // 7: daemon.LoginRequest
|
(*OSLifecycleResponse)(nil), // 7: daemon.OSLifecycleResponse
|
||||||
(*LoginResponse)(nil), // 8: daemon.LoginResponse
|
(*LoginRequest)(nil), // 8: daemon.LoginRequest
|
||||||
(*WaitSSOLoginRequest)(nil), // 9: daemon.WaitSSOLoginRequest
|
(*LoginResponse)(nil), // 9: daemon.LoginResponse
|
||||||
(*WaitSSOLoginResponse)(nil), // 10: daemon.WaitSSOLoginResponse
|
(*WaitSSOLoginRequest)(nil), // 10: daemon.WaitSSOLoginRequest
|
||||||
(*UpRequest)(nil), // 11: daemon.UpRequest
|
(*WaitSSOLoginResponse)(nil), // 11: daemon.WaitSSOLoginResponse
|
||||||
(*UpResponse)(nil), // 12: daemon.UpResponse
|
(*UpRequest)(nil), // 12: daemon.UpRequest
|
||||||
(*StatusRequest)(nil), // 13: daemon.StatusRequest
|
(*UpResponse)(nil), // 13: daemon.UpResponse
|
||||||
(*StatusResponse)(nil), // 14: daemon.StatusResponse
|
(*StatusRequest)(nil), // 14: daemon.StatusRequest
|
||||||
(*DownRequest)(nil), // 15: daemon.DownRequest
|
(*StatusResponse)(nil), // 15: daemon.StatusResponse
|
||||||
(*DownResponse)(nil), // 16: daemon.DownResponse
|
(*DownRequest)(nil), // 16: daemon.DownRequest
|
||||||
(*GetConfigRequest)(nil), // 17: daemon.GetConfigRequest
|
(*DownResponse)(nil), // 17: daemon.DownResponse
|
||||||
(*GetConfigResponse)(nil), // 18: daemon.GetConfigResponse
|
(*GetConfigRequest)(nil), // 18: daemon.GetConfigRequest
|
||||||
(*PeerState)(nil), // 19: daemon.PeerState
|
(*GetConfigResponse)(nil), // 19: daemon.GetConfigResponse
|
||||||
(*LocalPeerState)(nil), // 20: daemon.LocalPeerState
|
(*PeerState)(nil), // 20: daemon.PeerState
|
||||||
(*SignalState)(nil), // 21: daemon.SignalState
|
(*LocalPeerState)(nil), // 21: daemon.LocalPeerState
|
||||||
(*ManagementState)(nil), // 22: daemon.ManagementState
|
(*SignalState)(nil), // 22: daemon.SignalState
|
||||||
(*RelayState)(nil), // 23: daemon.RelayState
|
(*ManagementState)(nil), // 23: daemon.ManagementState
|
||||||
(*NSGroupState)(nil), // 24: daemon.NSGroupState
|
(*RelayState)(nil), // 24: daemon.RelayState
|
||||||
(*SSHSessionInfo)(nil), // 25: daemon.SSHSessionInfo
|
(*NSGroupState)(nil), // 25: daemon.NSGroupState
|
||||||
(*SSHServerState)(nil), // 26: daemon.SSHServerState
|
(*SSHSessionInfo)(nil), // 26: daemon.SSHSessionInfo
|
||||||
(*FullStatus)(nil), // 27: daemon.FullStatus
|
(*SSHServerState)(nil), // 27: daemon.SSHServerState
|
||||||
(*ListNetworksRequest)(nil), // 28: daemon.ListNetworksRequest
|
(*FullStatus)(nil), // 28: daemon.FullStatus
|
||||||
(*ListNetworksResponse)(nil), // 29: daemon.ListNetworksResponse
|
(*ListNetworksRequest)(nil), // 29: daemon.ListNetworksRequest
|
||||||
(*SelectNetworksRequest)(nil), // 30: daemon.SelectNetworksRequest
|
(*ListNetworksResponse)(nil), // 30: daemon.ListNetworksResponse
|
||||||
(*SelectNetworksResponse)(nil), // 31: daemon.SelectNetworksResponse
|
(*SelectNetworksRequest)(nil), // 31: daemon.SelectNetworksRequest
|
||||||
(*IPList)(nil), // 32: daemon.IPList
|
(*SelectNetworksResponse)(nil), // 32: daemon.SelectNetworksResponse
|
||||||
(*Network)(nil), // 33: daemon.Network
|
(*IPList)(nil), // 33: daemon.IPList
|
||||||
(*PortInfo)(nil), // 34: daemon.PortInfo
|
(*Network)(nil), // 34: daemon.Network
|
||||||
(*ForwardingRule)(nil), // 35: daemon.ForwardingRule
|
(*PortInfo)(nil), // 35: daemon.PortInfo
|
||||||
(*ForwardingRulesResponse)(nil), // 36: daemon.ForwardingRulesResponse
|
(*ForwardingRule)(nil), // 36: daemon.ForwardingRule
|
||||||
(*DebugBundleRequest)(nil), // 37: daemon.DebugBundleRequest
|
(*ForwardingRulesResponse)(nil), // 37: daemon.ForwardingRulesResponse
|
||||||
(*DebugBundleResponse)(nil), // 38: daemon.DebugBundleResponse
|
(*DebugBundleRequest)(nil), // 38: daemon.DebugBundleRequest
|
||||||
(*GetLogLevelRequest)(nil), // 39: daemon.GetLogLevelRequest
|
(*DebugBundleResponse)(nil), // 39: daemon.DebugBundleResponse
|
||||||
(*GetLogLevelResponse)(nil), // 40: daemon.GetLogLevelResponse
|
(*GetLogLevelRequest)(nil), // 40: daemon.GetLogLevelRequest
|
||||||
(*SetLogLevelRequest)(nil), // 41: daemon.SetLogLevelRequest
|
(*GetLogLevelResponse)(nil), // 41: daemon.GetLogLevelResponse
|
||||||
(*SetLogLevelResponse)(nil), // 42: daemon.SetLogLevelResponse
|
(*SetLogLevelRequest)(nil), // 42: daemon.SetLogLevelRequest
|
||||||
(*State)(nil), // 43: daemon.State
|
(*SetLogLevelResponse)(nil), // 43: daemon.SetLogLevelResponse
|
||||||
(*ListStatesRequest)(nil), // 44: daemon.ListStatesRequest
|
(*State)(nil), // 44: daemon.State
|
||||||
(*ListStatesResponse)(nil), // 45: daemon.ListStatesResponse
|
(*ListStatesRequest)(nil), // 45: daemon.ListStatesRequest
|
||||||
(*CleanStateRequest)(nil), // 46: daemon.CleanStateRequest
|
(*ListStatesResponse)(nil), // 46: daemon.ListStatesResponse
|
||||||
(*CleanStateResponse)(nil), // 47: daemon.CleanStateResponse
|
(*CleanStateRequest)(nil), // 47: daemon.CleanStateRequest
|
||||||
(*DeleteStateRequest)(nil), // 48: daemon.DeleteStateRequest
|
(*CleanStateResponse)(nil), // 48: daemon.CleanStateResponse
|
||||||
(*DeleteStateResponse)(nil), // 49: daemon.DeleteStateResponse
|
(*DeleteStateRequest)(nil), // 49: daemon.DeleteStateRequest
|
||||||
(*SetSyncResponsePersistenceRequest)(nil), // 50: daemon.SetSyncResponsePersistenceRequest
|
(*DeleteStateResponse)(nil), // 50: daemon.DeleteStateResponse
|
||||||
(*SetSyncResponsePersistenceResponse)(nil), // 51: daemon.SetSyncResponsePersistenceResponse
|
(*SetSyncResponsePersistenceRequest)(nil), // 51: daemon.SetSyncResponsePersistenceRequest
|
||||||
(*TCPFlags)(nil), // 52: daemon.TCPFlags
|
(*SetSyncResponsePersistenceResponse)(nil), // 52: daemon.SetSyncResponsePersistenceResponse
|
||||||
(*TracePacketRequest)(nil), // 53: daemon.TracePacketRequest
|
(*TCPFlags)(nil), // 53: daemon.TCPFlags
|
||||||
(*TraceStage)(nil), // 54: daemon.TraceStage
|
(*TracePacketRequest)(nil), // 54: daemon.TracePacketRequest
|
||||||
(*TracePacketResponse)(nil), // 55: daemon.TracePacketResponse
|
(*TraceStage)(nil), // 55: daemon.TraceStage
|
||||||
(*SubscribeRequest)(nil), // 56: daemon.SubscribeRequest
|
(*TracePacketResponse)(nil), // 56: daemon.TracePacketResponse
|
||||||
(*SystemEvent)(nil), // 57: daemon.SystemEvent
|
(*SubscribeRequest)(nil), // 57: daemon.SubscribeRequest
|
||||||
(*GetEventsRequest)(nil), // 58: daemon.GetEventsRequest
|
(*SystemEvent)(nil), // 58: daemon.SystemEvent
|
||||||
(*GetEventsResponse)(nil), // 59: daemon.GetEventsResponse
|
(*GetEventsRequest)(nil), // 59: daemon.GetEventsRequest
|
||||||
(*SwitchProfileRequest)(nil), // 60: daemon.SwitchProfileRequest
|
(*GetEventsResponse)(nil), // 60: daemon.GetEventsResponse
|
||||||
(*SwitchProfileResponse)(nil), // 61: daemon.SwitchProfileResponse
|
(*SwitchProfileRequest)(nil), // 61: daemon.SwitchProfileRequest
|
||||||
(*SetConfigRequest)(nil), // 62: daemon.SetConfigRequest
|
(*SwitchProfileResponse)(nil), // 62: daemon.SwitchProfileResponse
|
||||||
(*SetConfigResponse)(nil), // 63: daemon.SetConfigResponse
|
(*SetConfigRequest)(nil), // 63: daemon.SetConfigRequest
|
||||||
(*AddProfileRequest)(nil), // 64: daemon.AddProfileRequest
|
(*SetConfigResponse)(nil), // 64: daemon.SetConfigResponse
|
||||||
(*AddProfileResponse)(nil), // 65: daemon.AddProfileResponse
|
(*AddProfileRequest)(nil), // 65: daemon.AddProfileRequest
|
||||||
(*RemoveProfileRequest)(nil), // 66: daemon.RemoveProfileRequest
|
(*AddProfileResponse)(nil), // 66: daemon.AddProfileResponse
|
||||||
(*RemoveProfileResponse)(nil), // 67: daemon.RemoveProfileResponse
|
(*RemoveProfileRequest)(nil), // 67: daemon.RemoveProfileRequest
|
||||||
(*ListProfilesRequest)(nil), // 68: daemon.ListProfilesRequest
|
(*RemoveProfileResponse)(nil), // 68: daemon.RemoveProfileResponse
|
||||||
(*ListProfilesResponse)(nil), // 69: daemon.ListProfilesResponse
|
(*ListProfilesRequest)(nil), // 69: daemon.ListProfilesRequest
|
||||||
(*Profile)(nil), // 70: daemon.Profile
|
(*ListProfilesResponse)(nil), // 70: daemon.ListProfilesResponse
|
||||||
(*GetActiveProfileRequest)(nil), // 71: daemon.GetActiveProfileRequest
|
(*Profile)(nil), // 71: daemon.Profile
|
||||||
(*GetActiveProfileResponse)(nil), // 72: daemon.GetActiveProfileResponse
|
(*GetActiveProfileRequest)(nil), // 72: daemon.GetActiveProfileRequest
|
||||||
(*LogoutRequest)(nil), // 73: daemon.LogoutRequest
|
(*GetActiveProfileResponse)(nil), // 73: daemon.GetActiveProfileResponse
|
||||||
(*LogoutResponse)(nil), // 74: daemon.LogoutResponse
|
(*LogoutRequest)(nil), // 74: daemon.LogoutRequest
|
||||||
(*GetFeaturesRequest)(nil), // 75: daemon.GetFeaturesRequest
|
(*LogoutResponse)(nil), // 75: daemon.LogoutResponse
|
||||||
(*GetFeaturesResponse)(nil), // 76: daemon.GetFeaturesResponse
|
(*GetFeaturesRequest)(nil), // 76: daemon.GetFeaturesRequest
|
||||||
(*GetPeerSSHHostKeyRequest)(nil), // 77: daemon.GetPeerSSHHostKeyRequest
|
(*GetFeaturesResponse)(nil), // 77: daemon.GetFeaturesResponse
|
||||||
(*GetPeerSSHHostKeyResponse)(nil), // 78: daemon.GetPeerSSHHostKeyResponse
|
(*GetPeerSSHHostKeyRequest)(nil), // 78: daemon.GetPeerSSHHostKeyRequest
|
||||||
(*RequestJWTAuthRequest)(nil), // 79: daemon.RequestJWTAuthRequest
|
(*GetPeerSSHHostKeyResponse)(nil), // 79: daemon.GetPeerSSHHostKeyResponse
|
||||||
(*RequestJWTAuthResponse)(nil), // 80: daemon.RequestJWTAuthResponse
|
(*RequestJWTAuthRequest)(nil), // 80: daemon.RequestJWTAuthRequest
|
||||||
(*WaitJWTTokenRequest)(nil), // 81: daemon.WaitJWTTokenRequest
|
(*RequestJWTAuthResponse)(nil), // 81: daemon.RequestJWTAuthResponse
|
||||||
(*WaitJWTTokenResponse)(nil), // 82: daemon.WaitJWTTokenResponse
|
(*WaitJWTTokenRequest)(nil), // 82: daemon.WaitJWTTokenRequest
|
||||||
(*StartCPUProfileRequest)(nil), // 83: daemon.StartCPUProfileRequest
|
(*WaitJWTTokenResponse)(nil), // 83: daemon.WaitJWTTokenResponse
|
||||||
(*StartCPUProfileResponse)(nil), // 84: daemon.StartCPUProfileResponse
|
(*StartCPUProfileRequest)(nil), // 84: daemon.StartCPUProfileRequest
|
||||||
(*StopCPUProfileRequest)(nil), // 85: daemon.StopCPUProfileRequest
|
(*StartCPUProfileResponse)(nil), // 85: daemon.StartCPUProfileResponse
|
||||||
(*StopCPUProfileResponse)(nil), // 86: daemon.StopCPUProfileResponse
|
(*StopCPUProfileRequest)(nil), // 86: daemon.StopCPUProfileRequest
|
||||||
(*InstallerResultRequest)(nil), // 87: daemon.InstallerResultRequest
|
(*StopCPUProfileResponse)(nil), // 87: daemon.StopCPUProfileResponse
|
||||||
(*InstallerResultResponse)(nil), // 88: daemon.InstallerResultResponse
|
(*InstallerResultRequest)(nil), // 88: daemon.InstallerResultRequest
|
||||||
nil, // 89: daemon.Network.ResolvedIPsEntry
|
(*InstallerResultResponse)(nil), // 89: daemon.InstallerResultResponse
|
||||||
(*PortInfo_Range)(nil), // 90: daemon.PortInfo.Range
|
(*ExposeServiceRequest)(nil), // 90: daemon.ExposeServiceRequest
|
||||||
nil, // 91: daemon.SystemEvent.MetadataEntry
|
(*ExposeServiceEvent)(nil), // 91: daemon.ExposeServiceEvent
|
||||||
(*durationpb.Duration)(nil), // 92: google.protobuf.Duration
|
(*ExposeServiceReady)(nil), // 92: daemon.ExposeServiceReady
|
||||||
(*timestamppb.Timestamp)(nil), // 93: google.protobuf.Timestamp
|
nil, // 93: daemon.Network.ResolvedIPsEntry
|
||||||
|
(*PortInfo_Range)(nil), // 94: daemon.PortInfo.Range
|
||||||
|
nil, // 95: daemon.SystemEvent.MetadataEntry
|
||||||
|
(*durationpb.Duration)(nil), // 96: google.protobuf.Duration
|
||||||
|
(*timestamppb.Timestamp)(nil), // 97: google.protobuf.Timestamp
|
||||||
}
|
}
|
||||||
var file_daemon_proto_depIdxs = []int32{
|
var file_daemon_proto_depIdxs = []int32{
|
||||||
1, // 0: daemon.OSLifecycleRequest.type:type_name -> daemon.OSLifecycleRequest.CycleType
|
2, // 0: daemon.OSLifecycleRequest.type:type_name -> daemon.OSLifecycleRequest.CycleType
|
||||||
92, // 1: daemon.LoginRequest.dnsRouteInterval:type_name -> google.protobuf.Duration
|
96, // 1: daemon.LoginRequest.dnsRouteInterval:type_name -> google.protobuf.Duration
|
||||||
27, // 2: daemon.StatusResponse.fullStatus:type_name -> daemon.FullStatus
|
28, // 2: daemon.StatusResponse.fullStatus:type_name -> daemon.FullStatus
|
||||||
93, // 3: daemon.PeerState.connStatusUpdate:type_name -> google.protobuf.Timestamp
|
97, // 3: daemon.PeerState.connStatusUpdate:type_name -> google.protobuf.Timestamp
|
||||||
93, // 4: daemon.PeerState.lastWireguardHandshake:type_name -> google.protobuf.Timestamp
|
97, // 4: daemon.PeerState.lastWireguardHandshake:type_name -> google.protobuf.Timestamp
|
||||||
92, // 5: daemon.PeerState.latency:type_name -> google.protobuf.Duration
|
96, // 5: daemon.PeerState.latency:type_name -> google.protobuf.Duration
|
||||||
25, // 6: daemon.SSHServerState.sessions:type_name -> daemon.SSHSessionInfo
|
26, // 6: daemon.SSHServerState.sessions:type_name -> daemon.SSHSessionInfo
|
||||||
22, // 7: daemon.FullStatus.managementState:type_name -> daemon.ManagementState
|
23, // 7: daemon.FullStatus.managementState:type_name -> daemon.ManagementState
|
||||||
21, // 8: daemon.FullStatus.signalState:type_name -> daemon.SignalState
|
22, // 8: daemon.FullStatus.signalState:type_name -> daemon.SignalState
|
||||||
20, // 9: daemon.FullStatus.localPeerState:type_name -> daemon.LocalPeerState
|
21, // 9: daemon.FullStatus.localPeerState:type_name -> daemon.LocalPeerState
|
||||||
19, // 10: daemon.FullStatus.peers:type_name -> daemon.PeerState
|
20, // 10: daemon.FullStatus.peers:type_name -> daemon.PeerState
|
||||||
23, // 11: daemon.FullStatus.relays:type_name -> daemon.RelayState
|
24, // 11: daemon.FullStatus.relays:type_name -> daemon.RelayState
|
||||||
24, // 12: daemon.FullStatus.dns_servers:type_name -> daemon.NSGroupState
|
25, // 12: daemon.FullStatus.dns_servers:type_name -> daemon.NSGroupState
|
||||||
57, // 13: daemon.FullStatus.events:type_name -> daemon.SystemEvent
|
58, // 13: daemon.FullStatus.events:type_name -> daemon.SystemEvent
|
||||||
26, // 14: daemon.FullStatus.sshServerState:type_name -> daemon.SSHServerState
|
27, // 14: daemon.FullStatus.sshServerState:type_name -> daemon.SSHServerState
|
||||||
33, // 15: daemon.ListNetworksResponse.routes:type_name -> daemon.Network
|
34, // 15: daemon.ListNetworksResponse.routes:type_name -> daemon.Network
|
||||||
89, // 16: daemon.Network.resolvedIPs:type_name -> daemon.Network.ResolvedIPsEntry
|
93, // 16: daemon.Network.resolvedIPs:type_name -> daemon.Network.ResolvedIPsEntry
|
||||||
90, // 17: daemon.PortInfo.range:type_name -> daemon.PortInfo.Range
|
94, // 17: daemon.PortInfo.range:type_name -> daemon.PortInfo.Range
|
||||||
34, // 18: daemon.ForwardingRule.destinationPort:type_name -> daemon.PortInfo
|
35, // 18: daemon.ForwardingRule.destinationPort:type_name -> daemon.PortInfo
|
||||||
34, // 19: daemon.ForwardingRule.translatedPort:type_name -> daemon.PortInfo
|
35, // 19: daemon.ForwardingRule.translatedPort:type_name -> daemon.PortInfo
|
||||||
35, // 20: daemon.ForwardingRulesResponse.rules:type_name -> daemon.ForwardingRule
|
36, // 20: daemon.ForwardingRulesResponse.rules:type_name -> daemon.ForwardingRule
|
||||||
0, // 21: daemon.GetLogLevelResponse.level:type_name -> daemon.LogLevel
|
0, // 21: daemon.GetLogLevelResponse.level:type_name -> daemon.LogLevel
|
||||||
0, // 22: daemon.SetLogLevelRequest.level:type_name -> daemon.LogLevel
|
0, // 22: daemon.SetLogLevelRequest.level:type_name -> daemon.LogLevel
|
||||||
43, // 23: daemon.ListStatesResponse.states:type_name -> daemon.State
|
44, // 23: daemon.ListStatesResponse.states:type_name -> daemon.State
|
||||||
52, // 24: daemon.TracePacketRequest.tcp_flags:type_name -> daemon.TCPFlags
|
53, // 24: daemon.TracePacketRequest.tcp_flags:type_name -> daemon.TCPFlags
|
||||||
54, // 25: daemon.TracePacketResponse.stages:type_name -> daemon.TraceStage
|
55, // 25: daemon.TracePacketResponse.stages:type_name -> daemon.TraceStage
|
||||||
2, // 26: daemon.SystemEvent.severity:type_name -> daemon.SystemEvent.Severity
|
3, // 26: daemon.SystemEvent.severity:type_name -> daemon.SystemEvent.Severity
|
||||||
3, // 27: daemon.SystemEvent.category:type_name -> daemon.SystemEvent.Category
|
4, // 27: daemon.SystemEvent.category:type_name -> daemon.SystemEvent.Category
|
||||||
93, // 28: daemon.SystemEvent.timestamp:type_name -> google.protobuf.Timestamp
|
97, // 28: daemon.SystemEvent.timestamp:type_name -> google.protobuf.Timestamp
|
||||||
91, // 29: daemon.SystemEvent.metadata:type_name -> daemon.SystemEvent.MetadataEntry
|
95, // 29: daemon.SystemEvent.metadata:type_name -> daemon.SystemEvent.MetadataEntry
|
||||||
57, // 30: daemon.GetEventsResponse.events:type_name -> daemon.SystemEvent
|
58, // 30: daemon.GetEventsResponse.events:type_name -> daemon.SystemEvent
|
||||||
92, // 31: daemon.SetConfigRequest.dnsRouteInterval:type_name -> google.protobuf.Duration
|
96, // 31: daemon.SetConfigRequest.dnsRouteInterval:type_name -> google.protobuf.Duration
|
||||||
70, // 32: daemon.ListProfilesResponse.profiles:type_name -> daemon.Profile
|
71, // 32: daemon.ListProfilesResponse.profiles:type_name -> daemon.Profile
|
||||||
32, // 33: daemon.Network.ResolvedIPsEntry.value:type_name -> daemon.IPList
|
1, // 33: daemon.ExposeServiceRequest.protocol:type_name -> daemon.ExposeProtocol
|
||||||
7, // 34: daemon.DaemonService.Login:input_type -> daemon.LoginRequest
|
92, // 34: daemon.ExposeServiceEvent.ready:type_name -> daemon.ExposeServiceReady
|
||||||
9, // 35: daemon.DaemonService.WaitSSOLogin:input_type -> daemon.WaitSSOLoginRequest
|
33, // 35: daemon.Network.ResolvedIPsEntry.value:type_name -> daemon.IPList
|
||||||
11, // 36: daemon.DaemonService.Up:input_type -> daemon.UpRequest
|
8, // 36: daemon.DaemonService.Login:input_type -> daemon.LoginRequest
|
||||||
13, // 37: daemon.DaemonService.Status:input_type -> daemon.StatusRequest
|
10, // 37: daemon.DaemonService.WaitSSOLogin:input_type -> daemon.WaitSSOLoginRequest
|
||||||
15, // 38: daemon.DaemonService.Down:input_type -> daemon.DownRequest
|
12, // 38: daemon.DaemonService.Up:input_type -> daemon.UpRequest
|
||||||
17, // 39: daemon.DaemonService.GetConfig:input_type -> daemon.GetConfigRequest
|
14, // 39: daemon.DaemonService.Status:input_type -> daemon.StatusRequest
|
||||||
28, // 40: daemon.DaemonService.ListNetworks:input_type -> daemon.ListNetworksRequest
|
16, // 40: daemon.DaemonService.Down:input_type -> daemon.DownRequest
|
||||||
30, // 41: daemon.DaemonService.SelectNetworks:input_type -> daemon.SelectNetworksRequest
|
18, // 41: daemon.DaemonService.GetConfig:input_type -> daemon.GetConfigRequest
|
||||||
30, // 42: daemon.DaemonService.DeselectNetworks:input_type -> daemon.SelectNetworksRequest
|
29, // 42: daemon.DaemonService.ListNetworks:input_type -> daemon.ListNetworksRequest
|
||||||
4, // 43: daemon.DaemonService.ForwardingRules:input_type -> daemon.EmptyRequest
|
31, // 43: daemon.DaemonService.SelectNetworks:input_type -> daemon.SelectNetworksRequest
|
||||||
37, // 44: daemon.DaemonService.DebugBundle:input_type -> daemon.DebugBundleRequest
|
31, // 44: daemon.DaemonService.DeselectNetworks:input_type -> daemon.SelectNetworksRequest
|
||||||
39, // 45: daemon.DaemonService.GetLogLevel:input_type -> daemon.GetLogLevelRequest
|
5, // 45: daemon.DaemonService.ForwardingRules:input_type -> daemon.EmptyRequest
|
||||||
41, // 46: daemon.DaemonService.SetLogLevel:input_type -> daemon.SetLogLevelRequest
|
38, // 46: daemon.DaemonService.DebugBundle:input_type -> daemon.DebugBundleRequest
|
||||||
44, // 47: daemon.DaemonService.ListStates:input_type -> daemon.ListStatesRequest
|
40, // 47: daemon.DaemonService.GetLogLevel:input_type -> daemon.GetLogLevelRequest
|
||||||
46, // 48: daemon.DaemonService.CleanState:input_type -> daemon.CleanStateRequest
|
42, // 48: daemon.DaemonService.SetLogLevel:input_type -> daemon.SetLogLevelRequest
|
||||||
48, // 49: daemon.DaemonService.DeleteState:input_type -> daemon.DeleteStateRequest
|
45, // 49: daemon.DaemonService.ListStates:input_type -> daemon.ListStatesRequest
|
||||||
50, // 50: daemon.DaemonService.SetSyncResponsePersistence:input_type -> daemon.SetSyncResponsePersistenceRequest
|
47, // 50: daemon.DaemonService.CleanState:input_type -> daemon.CleanStateRequest
|
||||||
53, // 51: daemon.DaemonService.TracePacket:input_type -> daemon.TracePacketRequest
|
49, // 51: daemon.DaemonService.DeleteState:input_type -> daemon.DeleteStateRequest
|
||||||
56, // 52: daemon.DaemonService.SubscribeEvents:input_type -> daemon.SubscribeRequest
|
51, // 52: daemon.DaemonService.SetSyncResponsePersistence:input_type -> daemon.SetSyncResponsePersistenceRequest
|
||||||
58, // 53: daemon.DaemonService.GetEvents:input_type -> daemon.GetEventsRequest
|
54, // 53: daemon.DaemonService.TracePacket:input_type -> daemon.TracePacketRequest
|
||||||
60, // 54: daemon.DaemonService.SwitchProfile:input_type -> daemon.SwitchProfileRequest
|
57, // 54: daemon.DaemonService.SubscribeEvents:input_type -> daemon.SubscribeRequest
|
||||||
62, // 55: daemon.DaemonService.SetConfig:input_type -> daemon.SetConfigRequest
|
59, // 55: daemon.DaemonService.GetEvents:input_type -> daemon.GetEventsRequest
|
||||||
64, // 56: daemon.DaemonService.AddProfile:input_type -> daemon.AddProfileRequest
|
61, // 56: daemon.DaemonService.SwitchProfile:input_type -> daemon.SwitchProfileRequest
|
||||||
66, // 57: daemon.DaemonService.RemoveProfile:input_type -> daemon.RemoveProfileRequest
|
63, // 57: daemon.DaemonService.SetConfig:input_type -> daemon.SetConfigRequest
|
||||||
68, // 58: daemon.DaemonService.ListProfiles:input_type -> daemon.ListProfilesRequest
|
65, // 58: daemon.DaemonService.AddProfile:input_type -> daemon.AddProfileRequest
|
||||||
71, // 59: daemon.DaemonService.GetActiveProfile:input_type -> daemon.GetActiveProfileRequest
|
67, // 59: daemon.DaemonService.RemoveProfile:input_type -> daemon.RemoveProfileRequest
|
||||||
73, // 60: daemon.DaemonService.Logout:input_type -> daemon.LogoutRequest
|
69, // 60: daemon.DaemonService.ListProfiles:input_type -> daemon.ListProfilesRequest
|
||||||
75, // 61: daemon.DaemonService.GetFeatures:input_type -> daemon.GetFeaturesRequest
|
72, // 61: daemon.DaemonService.GetActiveProfile:input_type -> daemon.GetActiveProfileRequest
|
||||||
77, // 62: daemon.DaemonService.GetPeerSSHHostKey:input_type -> daemon.GetPeerSSHHostKeyRequest
|
74, // 62: daemon.DaemonService.Logout:input_type -> daemon.LogoutRequest
|
||||||
79, // 63: daemon.DaemonService.RequestJWTAuth:input_type -> daemon.RequestJWTAuthRequest
|
76, // 63: daemon.DaemonService.GetFeatures:input_type -> daemon.GetFeaturesRequest
|
||||||
81, // 64: daemon.DaemonService.WaitJWTToken:input_type -> daemon.WaitJWTTokenRequest
|
78, // 64: daemon.DaemonService.GetPeerSSHHostKey:input_type -> daemon.GetPeerSSHHostKeyRequest
|
||||||
83, // 65: daemon.DaemonService.StartCPUProfile:input_type -> daemon.StartCPUProfileRequest
|
80, // 65: daemon.DaemonService.RequestJWTAuth:input_type -> daemon.RequestJWTAuthRequest
|
||||||
85, // 66: daemon.DaemonService.StopCPUProfile:input_type -> daemon.StopCPUProfileRequest
|
82, // 66: daemon.DaemonService.WaitJWTToken:input_type -> daemon.WaitJWTTokenRequest
|
||||||
5, // 67: daemon.DaemonService.NotifyOSLifecycle:input_type -> daemon.OSLifecycleRequest
|
84, // 67: daemon.DaemonService.StartCPUProfile:input_type -> daemon.StartCPUProfileRequest
|
||||||
87, // 68: daemon.DaemonService.GetInstallerResult:input_type -> daemon.InstallerResultRequest
|
86, // 68: daemon.DaemonService.StopCPUProfile:input_type -> daemon.StopCPUProfileRequest
|
||||||
8, // 69: daemon.DaemonService.Login:output_type -> daemon.LoginResponse
|
6, // 69: daemon.DaemonService.NotifyOSLifecycle:input_type -> daemon.OSLifecycleRequest
|
||||||
10, // 70: daemon.DaemonService.WaitSSOLogin:output_type -> daemon.WaitSSOLoginResponse
|
88, // 70: daemon.DaemonService.GetInstallerResult:input_type -> daemon.InstallerResultRequest
|
||||||
12, // 71: daemon.DaemonService.Up:output_type -> daemon.UpResponse
|
90, // 71: daemon.DaemonService.ExposeService:input_type -> daemon.ExposeServiceRequest
|
||||||
14, // 72: daemon.DaemonService.Status:output_type -> daemon.StatusResponse
|
9, // 72: daemon.DaemonService.Login:output_type -> daemon.LoginResponse
|
||||||
16, // 73: daemon.DaemonService.Down:output_type -> daemon.DownResponse
|
11, // 73: daemon.DaemonService.WaitSSOLogin:output_type -> daemon.WaitSSOLoginResponse
|
||||||
18, // 74: daemon.DaemonService.GetConfig:output_type -> daemon.GetConfigResponse
|
13, // 74: daemon.DaemonService.Up:output_type -> daemon.UpResponse
|
||||||
29, // 75: daemon.DaemonService.ListNetworks:output_type -> daemon.ListNetworksResponse
|
15, // 75: daemon.DaemonService.Status:output_type -> daemon.StatusResponse
|
||||||
31, // 76: daemon.DaemonService.SelectNetworks:output_type -> daemon.SelectNetworksResponse
|
17, // 76: daemon.DaemonService.Down:output_type -> daemon.DownResponse
|
||||||
31, // 77: daemon.DaemonService.DeselectNetworks:output_type -> daemon.SelectNetworksResponse
|
19, // 77: daemon.DaemonService.GetConfig:output_type -> daemon.GetConfigResponse
|
||||||
36, // 78: daemon.DaemonService.ForwardingRules:output_type -> daemon.ForwardingRulesResponse
|
30, // 78: daemon.DaemonService.ListNetworks:output_type -> daemon.ListNetworksResponse
|
||||||
38, // 79: daemon.DaemonService.DebugBundle:output_type -> daemon.DebugBundleResponse
|
32, // 79: daemon.DaemonService.SelectNetworks:output_type -> daemon.SelectNetworksResponse
|
||||||
40, // 80: daemon.DaemonService.GetLogLevel:output_type -> daemon.GetLogLevelResponse
|
32, // 80: daemon.DaemonService.DeselectNetworks:output_type -> daemon.SelectNetworksResponse
|
||||||
42, // 81: daemon.DaemonService.SetLogLevel:output_type -> daemon.SetLogLevelResponse
|
37, // 81: daemon.DaemonService.ForwardingRules:output_type -> daemon.ForwardingRulesResponse
|
||||||
45, // 82: daemon.DaemonService.ListStates:output_type -> daemon.ListStatesResponse
|
39, // 82: daemon.DaemonService.DebugBundle:output_type -> daemon.DebugBundleResponse
|
||||||
47, // 83: daemon.DaemonService.CleanState:output_type -> daemon.CleanStateResponse
|
41, // 83: daemon.DaemonService.GetLogLevel:output_type -> daemon.GetLogLevelResponse
|
||||||
49, // 84: daemon.DaemonService.DeleteState:output_type -> daemon.DeleteStateResponse
|
43, // 84: daemon.DaemonService.SetLogLevel:output_type -> daemon.SetLogLevelResponse
|
||||||
51, // 85: daemon.DaemonService.SetSyncResponsePersistence:output_type -> daemon.SetSyncResponsePersistenceResponse
|
46, // 85: daemon.DaemonService.ListStates:output_type -> daemon.ListStatesResponse
|
||||||
55, // 86: daemon.DaemonService.TracePacket:output_type -> daemon.TracePacketResponse
|
48, // 86: daemon.DaemonService.CleanState:output_type -> daemon.CleanStateResponse
|
||||||
57, // 87: daemon.DaemonService.SubscribeEvents:output_type -> daemon.SystemEvent
|
50, // 87: daemon.DaemonService.DeleteState:output_type -> daemon.DeleteStateResponse
|
||||||
59, // 88: daemon.DaemonService.GetEvents:output_type -> daemon.GetEventsResponse
|
52, // 88: daemon.DaemonService.SetSyncResponsePersistence:output_type -> daemon.SetSyncResponsePersistenceResponse
|
||||||
61, // 89: daemon.DaemonService.SwitchProfile:output_type -> daemon.SwitchProfileResponse
|
56, // 89: daemon.DaemonService.TracePacket:output_type -> daemon.TracePacketResponse
|
||||||
63, // 90: daemon.DaemonService.SetConfig:output_type -> daemon.SetConfigResponse
|
58, // 90: daemon.DaemonService.SubscribeEvents:output_type -> daemon.SystemEvent
|
||||||
65, // 91: daemon.DaemonService.AddProfile:output_type -> daemon.AddProfileResponse
|
60, // 91: daemon.DaemonService.GetEvents:output_type -> daemon.GetEventsResponse
|
||||||
67, // 92: daemon.DaemonService.RemoveProfile:output_type -> daemon.RemoveProfileResponse
|
62, // 92: daemon.DaemonService.SwitchProfile:output_type -> daemon.SwitchProfileResponse
|
||||||
69, // 93: daemon.DaemonService.ListProfiles:output_type -> daemon.ListProfilesResponse
|
64, // 93: daemon.DaemonService.SetConfig:output_type -> daemon.SetConfigResponse
|
||||||
72, // 94: daemon.DaemonService.GetActiveProfile:output_type -> daemon.GetActiveProfileResponse
|
66, // 94: daemon.DaemonService.AddProfile:output_type -> daemon.AddProfileResponse
|
||||||
74, // 95: daemon.DaemonService.Logout:output_type -> daemon.LogoutResponse
|
68, // 95: daemon.DaemonService.RemoveProfile:output_type -> daemon.RemoveProfileResponse
|
||||||
76, // 96: daemon.DaemonService.GetFeatures:output_type -> daemon.GetFeaturesResponse
|
70, // 96: daemon.DaemonService.ListProfiles:output_type -> daemon.ListProfilesResponse
|
||||||
78, // 97: daemon.DaemonService.GetPeerSSHHostKey:output_type -> daemon.GetPeerSSHHostKeyResponse
|
73, // 97: daemon.DaemonService.GetActiveProfile:output_type -> daemon.GetActiveProfileResponse
|
||||||
80, // 98: daemon.DaemonService.RequestJWTAuth:output_type -> daemon.RequestJWTAuthResponse
|
75, // 98: daemon.DaemonService.Logout:output_type -> daemon.LogoutResponse
|
||||||
82, // 99: daemon.DaemonService.WaitJWTToken:output_type -> daemon.WaitJWTTokenResponse
|
77, // 99: daemon.DaemonService.GetFeatures:output_type -> daemon.GetFeaturesResponse
|
||||||
84, // 100: daemon.DaemonService.StartCPUProfile:output_type -> daemon.StartCPUProfileResponse
|
79, // 100: daemon.DaemonService.GetPeerSSHHostKey:output_type -> daemon.GetPeerSSHHostKeyResponse
|
||||||
86, // 101: daemon.DaemonService.StopCPUProfile:output_type -> daemon.StopCPUProfileResponse
|
81, // 101: daemon.DaemonService.RequestJWTAuth:output_type -> daemon.RequestJWTAuthResponse
|
||||||
6, // 102: daemon.DaemonService.NotifyOSLifecycle:output_type -> daemon.OSLifecycleResponse
|
83, // 102: daemon.DaemonService.WaitJWTToken:output_type -> daemon.WaitJWTTokenResponse
|
||||||
88, // 103: daemon.DaemonService.GetInstallerResult:output_type -> daemon.InstallerResultResponse
|
85, // 103: daemon.DaemonService.StartCPUProfile:output_type -> daemon.StartCPUProfileResponse
|
||||||
69, // [69:104] is the sub-list for method output_type
|
87, // 104: daemon.DaemonService.StopCPUProfile:output_type -> daemon.StopCPUProfileResponse
|
||||||
34, // [34:69] is the sub-list for method input_type
|
7, // 105: daemon.DaemonService.NotifyOSLifecycle:output_type -> daemon.OSLifecycleResponse
|
||||||
34, // [34:34] is the sub-list for extension type_name
|
89, // 106: daemon.DaemonService.GetInstallerResult:output_type -> daemon.InstallerResultResponse
|
||||||
34, // [34:34] is the sub-list for extension extendee
|
91, // 107: daemon.DaemonService.ExposeService:output_type -> daemon.ExposeServiceEvent
|
||||||
0, // [0:34] is the sub-list for field type_name
|
72, // [72:108] is the sub-list for method output_type
|
||||||
|
36, // [36:72] is the sub-list for method input_type
|
||||||
|
36, // [36:36] is the sub-list for extension type_name
|
||||||
|
36, // [36:36] is the sub-list for extension extendee
|
||||||
|
0, // [0:36] is the sub-list for field type_name
|
||||||
}
|
}
|
||||||
|
|
||||||
func init() { file_daemon_proto_init() }
|
func init() { file_daemon_proto_init() }
|
||||||
@@ -6439,13 +6743,16 @@ func file_daemon_proto_init() {
|
|||||||
file_daemon_proto_msgTypes[58].OneofWrappers = []any{}
|
file_daemon_proto_msgTypes[58].OneofWrappers = []any{}
|
||||||
file_daemon_proto_msgTypes[69].OneofWrappers = []any{}
|
file_daemon_proto_msgTypes[69].OneofWrappers = []any{}
|
||||||
file_daemon_proto_msgTypes[75].OneofWrappers = []any{}
|
file_daemon_proto_msgTypes[75].OneofWrappers = []any{}
|
||||||
|
file_daemon_proto_msgTypes[86].OneofWrappers = []any{
|
||||||
|
(*ExposeServiceEvent_Ready)(nil),
|
||||||
|
}
|
||||||
type x struct{}
|
type x struct{}
|
||||||
out := protoimpl.TypeBuilder{
|
out := protoimpl.TypeBuilder{
|
||||||
File: protoimpl.DescBuilder{
|
File: protoimpl.DescBuilder{
|
||||||
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
|
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
|
||||||
RawDescriptor: unsafe.Slice(unsafe.StringData(file_daemon_proto_rawDesc), len(file_daemon_proto_rawDesc)),
|
RawDescriptor: unsafe.Slice(unsafe.StringData(file_daemon_proto_rawDesc), len(file_daemon_proto_rawDesc)),
|
||||||
NumEnums: 4,
|
NumEnums: 5,
|
||||||
NumMessages: 88,
|
NumMessages: 91,
|
||||||
NumExtensions: 0,
|
NumExtensions: 0,
|
||||||
NumServices: 1,
|
NumServices: 1,
|
||||||
},
|
},
|
||||||
|
|||||||
@@ -103,6 +103,9 @@ service DaemonService {
|
|||||||
rpc NotifyOSLifecycle(OSLifecycleRequest) returns(OSLifecycleResponse) {}
|
rpc NotifyOSLifecycle(OSLifecycleRequest) returns(OSLifecycleResponse) {}
|
||||||
|
|
||||||
rpc GetInstallerResult(InstallerResultRequest) returns (InstallerResultResponse) {}
|
rpc GetInstallerResult(InstallerResultRequest) returns (InstallerResultResponse) {}
|
||||||
|
|
||||||
|
// ExposeService exposes a local port via the NetBird reverse proxy
|
||||||
|
rpc ExposeService(ExposeServiceRequest) returns (stream ExposeServiceEvent) {}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -801,3 +804,32 @@ message InstallerResultResponse {
|
|||||||
bool success = 1;
|
bool success = 1;
|
||||||
string errorMsg = 2;
|
string errorMsg = 2;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
enum ExposeProtocol {
|
||||||
|
EXPOSE_HTTP = 0;
|
||||||
|
EXPOSE_HTTPS = 1;
|
||||||
|
EXPOSE_TCP = 2;
|
||||||
|
EXPOSE_UDP = 3;
|
||||||
|
}
|
||||||
|
|
||||||
|
message ExposeServiceRequest {
|
||||||
|
uint32 port = 1;
|
||||||
|
ExposeProtocol protocol = 2;
|
||||||
|
string pin = 3;
|
||||||
|
string password = 4;
|
||||||
|
repeated string user_groups = 5;
|
||||||
|
string domain = 6;
|
||||||
|
string name_prefix = 7;
|
||||||
|
}
|
||||||
|
|
||||||
|
message ExposeServiceEvent {
|
||||||
|
oneof event {
|
||||||
|
ExposeServiceReady ready = 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
message ExposeServiceReady {
|
||||||
|
string service_name = 1;
|
||||||
|
string service_url = 2;
|
||||||
|
string domain = 3;
|
||||||
|
}
|
||||||
|
|||||||
@@ -76,6 +76,8 @@ type DaemonServiceClient interface {
|
|||||||
StopCPUProfile(ctx context.Context, in *StopCPUProfileRequest, opts ...grpc.CallOption) (*StopCPUProfileResponse, error)
|
StopCPUProfile(ctx context.Context, in *StopCPUProfileRequest, opts ...grpc.CallOption) (*StopCPUProfileResponse, error)
|
||||||
NotifyOSLifecycle(ctx context.Context, in *OSLifecycleRequest, opts ...grpc.CallOption) (*OSLifecycleResponse, error)
|
NotifyOSLifecycle(ctx context.Context, in *OSLifecycleRequest, opts ...grpc.CallOption) (*OSLifecycleResponse, error)
|
||||||
GetInstallerResult(ctx context.Context, in *InstallerResultRequest, opts ...grpc.CallOption) (*InstallerResultResponse, error)
|
GetInstallerResult(ctx context.Context, in *InstallerResultRequest, opts ...grpc.CallOption) (*InstallerResultResponse, error)
|
||||||
|
// ExposeService exposes a local port via the NetBird reverse proxy
|
||||||
|
ExposeService(ctx context.Context, in *ExposeServiceRequest, opts ...grpc.CallOption) (DaemonService_ExposeServiceClient, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type daemonServiceClient struct {
|
type daemonServiceClient struct {
|
||||||
@@ -424,6 +426,38 @@ func (c *daemonServiceClient) GetInstallerResult(ctx context.Context, in *Instal
|
|||||||
return out, nil
|
return out, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *daemonServiceClient) ExposeService(ctx context.Context, in *ExposeServiceRequest, opts ...grpc.CallOption) (DaemonService_ExposeServiceClient, error) {
|
||||||
|
stream, err := c.cc.NewStream(ctx, &DaemonService_ServiceDesc.Streams[1], "/daemon.DaemonService/ExposeService", opts...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
x := &daemonServiceExposeServiceClient{stream}
|
||||||
|
if err := x.ClientStream.SendMsg(in); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if err := x.ClientStream.CloseSend(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return x, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type DaemonService_ExposeServiceClient interface {
|
||||||
|
Recv() (*ExposeServiceEvent, error)
|
||||||
|
grpc.ClientStream
|
||||||
|
}
|
||||||
|
|
||||||
|
type daemonServiceExposeServiceClient struct {
|
||||||
|
grpc.ClientStream
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *daemonServiceExposeServiceClient) Recv() (*ExposeServiceEvent, error) {
|
||||||
|
m := new(ExposeServiceEvent)
|
||||||
|
if err := x.ClientStream.RecvMsg(m); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return m, nil
|
||||||
|
}
|
||||||
|
|
||||||
// DaemonServiceServer is the server API for DaemonService service.
|
// DaemonServiceServer is the server API for DaemonService service.
|
||||||
// All implementations must embed UnimplementedDaemonServiceServer
|
// All implementations must embed UnimplementedDaemonServiceServer
|
||||||
// for forward compatibility
|
// for forward compatibility
|
||||||
@@ -486,6 +520,8 @@ type DaemonServiceServer interface {
|
|||||||
StopCPUProfile(context.Context, *StopCPUProfileRequest) (*StopCPUProfileResponse, error)
|
StopCPUProfile(context.Context, *StopCPUProfileRequest) (*StopCPUProfileResponse, error)
|
||||||
NotifyOSLifecycle(context.Context, *OSLifecycleRequest) (*OSLifecycleResponse, error)
|
NotifyOSLifecycle(context.Context, *OSLifecycleRequest) (*OSLifecycleResponse, error)
|
||||||
GetInstallerResult(context.Context, *InstallerResultRequest) (*InstallerResultResponse, error)
|
GetInstallerResult(context.Context, *InstallerResultRequest) (*InstallerResultResponse, error)
|
||||||
|
// ExposeService exposes a local port via the NetBird reverse proxy
|
||||||
|
ExposeService(*ExposeServiceRequest, DaemonService_ExposeServiceServer) error
|
||||||
mustEmbedUnimplementedDaemonServiceServer()
|
mustEmbedUnimplementedDaemonServiceServer()
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -598,6 +634,9 @@ func (UnimplementedDaemonServiceServer) NotifyOSLifecycle(context.Context, *OSLi
|
|||||||
func (UnimplementedDaemonServiceServer) GetInstallerResult(context.Context, *InstallerResultRequest) (*InstallerResultResponse, error) {
|
func (UnimplementedDaemonServiceServer) GetInstallerResult(context.Context, *InstallerResultRequest) (*InstallerResultResponse, error) {
|
||||||
return nil, status.Errorf(codes.Unimplemented, "method GetInstallerResult not implemented")
|
return nil, status.Errorf(codes.Unimplemented, "method GetInstallerResult not implemented")
|
||||||
}
|
}
|
||||||
|
func (UnimplementedDaemonServiceServer) ExposeService(*ExposeServiceRequest, DaemonService_ExposeServiceServer) error {
|
||||||
|
return status.Errorf(codes.Unimplemented, "method ExposeService not implemented")
|
||||||
|
}
|
||||||
func (UnimplementedDaemonServiceServer) mustEmbedUnimplementedDaemonServiceServer() {}
|
func (UnimplementedDaemonServiceServer) mustEmbedUnimplementedDaemonServiceServer() {}
|
||||||
|
|
||||||
// UnsafeDaemonServiceServer may be embedded to opt out of forward compatibility for this service.
|
// UnsafeDaemonServiceServer may be embedded to opt out of forward compatibility for this service.
|
||||||
@@ -1244,6 +1283,27 @@ func _DaemonService_GetInstallerResult_Handler(srv interface{}, ctx context.Cont
|
|||||||
return interceptor(ctx, in, info, handler)
|
return interceptor(ctx, in, info, handler)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func _DaemonService_ExposeService_Handler(srv interface{}, stream grpc.ServerStream) error {
|
||||||
|
m := new(ExposeServiceRequest)
|
||||||
|
if err := stream.RecvMsg(m); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return srv.(DaemonServiceServer).ExposeService(m, &daemonServiceExposeServiceServer{stream})
|
||||||
|
}
|
||||||
|
|
||||||
|
type DaemonService_ExposeServiceServer interface {
|
||||||
|
Send(*ExposeServiceEvent) error
|
||||||
|
grpc.ServerStream
|
||||||
|
}
|
||||||
|
|
||||||
|
type daemonServiceExposeServiceServer struct {
|
||||||
|
grpc.ServerStream
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *daemonServiceExposeServiceServer) Send(m *ExposeServiceEvent) error {
|
||||||
|
return x.ServerStream.SendMsg(m)
|
||||||
|
}
|
||||||
|
|
||||||
// DaemonService_ServiceDesc is the grpc.ServiceDesc for DaemonService service.
|
// DaemonService_ServiceDesc is the grpc.ServiceDesc for DaemonService service.
|
||||||
// It's only intended for direct use with grpc.RegisterService,
|
// It's only intended for direct use with grpc.RegisterService,
|
||||||
// and not to be introspected or modified (even as a copy)
|
// and not to be introspected or modified (even as a copy)
|
||||||
@@ -1394,6 +1454,11 @@ var DaemonService_ServiceDesc = grpc.ServiceDesc{
|
|||||||
Handler: _DaemonService_SubscribeEvents_Handler,
|
Handler: _DaemonService_SubscribeEvents_Handler,
|
||||||
ServerStreams: true,
|
ServerStreams: true,
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
StreamName: "ExposeService",
|
||||||
|
Handler: _DaemonService_ExposeService_Handler,
|
||||||
|
ServerStreams: true,
|
||||||
|
},
|
||||||
},
|
},
|
||||||
Metadata: "daemon.proto",
|
Metadata: "daemon.proto",
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,77 +0,0 @@
|
|||||||
package server
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal"
|
|
||||||
"github.com/netbirdio/netbird/client/proto"
|
|
||||||
)
|
|
||||||
|
|
||||||
// NotifyOSLifecycle handles operating system lifecycle events by executing appropriate logic based on the request type.
|
|
||||||
func (s *Server) NotifyOSLifecycle(callerCtx context.Context, req *proto.OSLifecycleRequest) (*proto.OSLifecycleResponse, error) {
|
|
||||||
switch req.GetType() {
|
|
||||||
case proto.OSLifecycleRequest_WAKEUP:
|
|
||||||
return s.handleWakeUp(callerCtx)
|
|
||||||
case proto.OSLifecycleRequest_SLEEP:
|
|
||||||
return s.handleSleep(callerCtx)
|
|
||||||
default:
|
|
||||||
log.Errorf("unknown OSLifecycleRequest type: %v", req.GetType())
|
|
||||||
}
|
|
||||||
return &proto.OSLifecycleResponse{}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// handleWakeUp processes a wake-up event by triggering the Up command if the system was previously put to sleep.
|
|
||||||
// It resets the sleep state and logs the process. Returns a response or an error if the Up command fails.
|
|
||||||
func (s *Server) handleWakeUp(callerCtx context.Context) (*proto.OSLifecycleResponse, error) {
|
|
||||||
if !s.sleepTriggeredDown.Load() {
|
|
||||||
log.Info("skipping up because wasn't sleep down")
|
|
||||||
return &proto.OSLifecycleResponse{}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// avoid other wakeup runs if sleep didn't make the computer sleep
|
|
||||||
s.sleepTriggeredDown.Store(false)
|
|
||||||
|
|
||||||
log.Info("running up after wake up")
|
|
||||||
_, err := s.Up(callerCtx, &proto.UpRequest{})
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("running up failed: %v", err)
|
|
||||||
return &proto.OSLifecycleResponse{}, err
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Info("running up command executed successfully")
|
|
||||||
return &proto.OSLifecycleResponse{}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// handleSleep handles the sleep event by initiating a "down" sequence if the system is in a connected or connecting state.
|
|
||||||
func (s *Server) handleSleep(callerCtx context.Context) (*proto.OSLifecycleResponse, error) {
|
|
||||||
s.mutex.Lock()
|
|
||||||
|
|
||||||
state := internal.CtxGetState(s.rootCtx)
|
|
||||||
status, err := state.Status()
|
|
||||||
if err != nil {
|
|
||||||
s.mutex.Unlock()
|
|
||||||
return &proto.OSLifecycleResponse{}, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if status != internal.StatusConnecting && status != internal.StatusConnected {
|
|
||||||
log.Infof("skipping setting the agent down because status is %s", status)
|
|
||||||
s.mutex.Unlock()
|
|
||||||
return &proto.OSLifecycleResponse{}, nil
|
|
||||||
}
|
|
||||||
s.mutex.Unlock()
|
|
||||||
|
|
||||||
log.Info("running down after system started sleeping")
|
|
||||||
|
|
||||||
_, err = s.Down(callerCtx, &proto.DownRequest{})
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("running down failed: %v", err)
|
|
||||||
return &proto.OSLifecycleResponse{}, err
|
|
||||||
}
|
|
||||||
|
|
||||||
s.sleepTriggeredDown.Store(true)
|
|
||||||
|
|
||||||
log.Info("running down executed successfully")
|
|
||||||
return &proto.OSLifecycleResponse{}, nil
|
|
||||||
}
|
|
||||||
@@ -1,219 +0,0 @@
|
|||||||
package server
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal"
|
|
||||||
"github.com/netbirdio/netbird/client/internal/peer"
|
|
||||||
"github.com/netbirdio/netbird/client/proto"
|
|
||||||
)
|
|
||||||
|
|
||||||
func newTestServer() *Server {
|
|
||||||
ctx := internal.CtxInitState(context.Background())
|
|
||||||
return &Server{
|
|
||||||
rootCtx: ctx,
|
|
||||||
statusRecorder: peer.NewRecorder(""),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestNotifyOSLifecycle_WakeUp_SkipsWhenNotSleepTriggered(t *testing.T) {
|
|
||||||
s := newTestServer()
|
|
||||||
|
|
||||||
// sleepTriggeredDown is false by default
|
|
||||||
assert.False(t, s.sleepTriggeredDown.Load())
|
|
||||||
|
|
||||||
resp, err := s.NotifyOSLifecycle(context.Background(), &proto.OSLifecycleRequest{
|
|
||||||
Type: proto.OSLifecycleRequest_WAKEUP,
|
|
||||||
})
|
|
||||||
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.NotNil(t, resp)
|
|
||||||
assert.False(t, s.sleepTriggeredDown.Load(), "flag should remain false")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestNotifyOSLifecycle_Sleep_SkipsWhenStatusIdle(t *testing.T) {
|
|
||||||
s := newTestServer()
|
|
||||||
|
|
||||||
state := internal.CtxGetState(s.rootCtx)
|
|
||||||
state.Set(internal.StatusIdle)
|
|
||||||
|
|
||||||
resp, err := s.NotifyOSLifecycle(context.Background(), &proto.OSLifecycleRequest{
|
|
||||||
Type: proto.OSLifecycleRequest_SLEEP,
|
|
||||||
})
|
|
||||||
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.NotNil(t, resp)
|
|
||||||
assert.False(t, s.sleepTriggeredDown.Load(), "flag should remain false when status is Idle")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestNotifyOSLifecycle_Sleep_SkipsWhenStatusNeedsLogin(t *testing.T) {
|
|
||||||
s := newTestServer()
|
|
||||||
|
|
||||||
state := internal.CtxGetState(s.rootCtx)
|
|
||||||
state.Set(internal.StatusNeedsLogin)
|
|
||||||
|
|
||||||
resp, err := s.NotifyOSLifecycle(context.Background(), &proto.OSLifecycleRequest{
|
|
||||||
Type: proto.OSLifecycleRequest_SLEEP,
|
|
||||||
})
|
|
||||||
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.NotNil(t, resp)
|
|
||||||
assert.False(t, s.sleepTriggeredDown.Load(), "flag should remain false when status is NeedsLogin")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestNotifyOSLifecycle_Sleep_SetsFlag_WhenConnecting(t *testing.T) {
|
|
||||||
s := newTestServer()
|
|
||||||
|
|
||||||
state := internal.CtxGetState(s.rootCtx)
|
|
||||||
state.Set(internal.StatusConnecting)
|
|
||||||
|
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
|
||||||
defer cancel()
|
|
||||||
s.actCancel = cancel
|
|
||||||
|
|
||||||
resp, err := s.NotifyOSLifecycle(ctx, &proto.OSLifecycleRequest{
|
|
||||||
Type: proto.OSLifecycleRequest_SLEEP,
|
|
||||||
})
|
|
||||||
|
|
||||||
require.NoError(t, err)
|
|
||||||
assert.NotNil(t, resp, "handleSleep returns not nil response on success")
|
|
||||||
assert.True(t, s.sleepTriggeredDown.Load(), "flag should be set after sleep when connecting")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestNotifyOSLifecycle_Sleep_SetsFlag_WhenConnected(t *testing.T) {
|
|
||||||
s := newTestServer()
|
|
||||||
|
|
||||||
state := internal.CtxGetState(s.rootCtx)
|
|
||||||
state.Set(internal.StatusConnected)
|
|
||||||
|
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
|
||||||
defer cancel()
|
|
||||||
s.actCancel = cancel
|
|
||||||
|
|
||||||
resp, err := s.NotifyOSLifecycle(ctx, &proto.OSLifecycleRequest{
|
|
||||||
Type: proto.OSLifecycleRequest_SLEEP,
|
|
||||||
})
|
|
||||||
|
|
||||||
require.NoError(t, err)
|
|
||||||
assert.NotNil(t, resp, "handleSleep returns not nil response on success")
|
|
||||||
assert.True(t, s.sleepTriggeredDown.Load(), "flag should be set after sleep when connected")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestNotifyOSLifecycle_WakeUp_ResetsFlag(t *testing.T) {
|
|
||||||
s := newTestServer()
|
|
||||||
|
|
||||||
// Manually set the flag to simulate prior sleep down
|
|
||||||
s.sleepTriggeredDown.Store(true)
|
|
||||||
|
|
||||||
// WakeUp will try to call Up which fails without proper setup, but flag should reset first
|
|
||||||
_, _ = s.NotifyOSLifecycle(context.Background(), &proto.OSLifecycleRequest{
|
|
||||||
Type: proto.OSLifecycleRequest_WAKEUP,
|
|
||||||
})
|
|
||||||
|
|
||||||
assert.False(t, s.sleepTriggeredDown.Load(), "flag should be reset after WakeUp attempt")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestNotifyOSLifecycle_MultipleWakeUpCalls(t *testing.T) {
|
|
||||||
s := newTestServer()
|
|
||||||
|
|
||||||
// First wakeup without prior sleep - should be no-op
|
|
||||||
resp, err := s.NotifyOSLifecycle(context.Background(), &proto.OSLifecycleRequest{
|
|
||||||
Type: proto.OSLifecycleRequest_WAKEUP,
|
|
||||||
})
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.NotNil(t, resp)
|
|
||||||
assert.False(t, s.sleepTriggeredDown.Load())
|
|
||||||
|
|
||||||
// Simulate prior sleep
|
|
||||||
s.sleepTriggeredDown.Store(true)
|
|
||||||
|
|
||||||
// First wakeup after sleep - should reset flag
|
|
||||||
_, _ = s.NotifyOSLifecycle(context.Background(), &proto.OSLifecycleRequest{
|
|
||||||
Type: proto.OSLifecycleRequest_WAKEUP,
|
|
||||||
})
|
|
||||||
assert.False(t, s.sleepTriggeredDown.Load())
|
|
||||||
|
|
||||||
// Second wakeup - should be no-op
|
|
||||||
resp, err = s.NotifyOSLifecycle(context.Background(), &proto.OSLifecycleRequest{
|
|
||||||
Type: proto.OSLifecycleRequest_WAKEUP,
|
|
||||||
})
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.NotNil(t, resp)
|
|
||||||
assert.False(t, s.sleepTriggeredDown.Load())
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestHandleWakeUp_SkipsWhenFlagFalse(t *testing.T) {
|
|
||||||
s := newTestServer()
|
|
||||||
|
|
||||||
resp, err := s.handleWakeUp(context.Background())
|
|
||||||
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.NotNil(t, resp)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestHandleWakeUp_ResetsFlagBeforeUp(t *testing.T) {
|
|
||||||
s := newTestServer()
|
|
||||||
s.sleepTriggeredDown.Store(true)
|
|
||||||
|
|
||||||
// Even if Up fails, flag should be reset
|
|
||||||
_, _ = s.handleWakeUp(context.Background())
|
|
||||||
|
|
||||||
assert.False(t, s.sleepTriggeredDown.Load(), "flag must be reset before calling Up")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestHandleSleep_SkipsForNonActiveStates(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
status internal.StatusType
|
|
||||||
}{
|
|
||||||
{"Idle", internal.StatusIdle},
|
|
||||||
{"NeedsLogin", internal.StatusNeedsLogin},
|
|
||||||
{"LoginFailed", internal.StatusLoginFailed},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
s := newTestServer()
|
|
||||||
state := internal.CtxGetState(s.rootCtx)
|
|
||||||
state.Set(tt.status)
|
|
||||||
|
|
||||||
resp, err := s.handleSleep(context.Background())
|
|
||||||
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.NotNil(t, resp)
|
|
||||||
assert.False(t, s.sleepTriggeredDown.Load())
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestHandleSleep_ProceedsForActiveStates(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
status internal.StatusType
|
|
||||||
}{
|
|
||||||
{"Connecting", internal.StatusConnecting},
|
|
||||||
{"Connected", internal.StatusConnected},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
s := newTestServer()
|
|
||||||
state := internal.CtxGetState(s.rootCtx)
|
|
||||||
state.Set(tt.status)
|
|
||||||
|
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
|
||||||
defer cancel()
|
|
||||||
s.actCancel = cancel
|
|
||||||
|
|
||||||
resp, err := s.handleSleep(ctx)
|
|
||||||
|
|
||||||
require.NoError(t, err)
|
|
||||||
assert.NotNil(t, resp)
|
|
||||||
assert.True(t, s.sleepTriggeredDown.Load())
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -21,7 +21,9 @@ import (
|
|||||||
gstatus "google.golang.org/grpc/status"
|
gstatus "google.golang.org/grpc/status"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal/auth"
|
"github.com/netbirdio/netbird/client/internal/auth"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/expose"
|
||||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||||
|
sleephandler "github.com/netbirdio/netbird/client/internal/sleep/handler"
|
||||||
"github.com/netbirdio/netbird/client/system"
|
"github.com/netbirdio/netbird/client/system"
|
||||||
mgm "github.com/netbirdio/netbird/shared/management/client"
|
mgm "github.com/netbirdio/netbird/shared/management/client"
|
||||||
"github.com/netbirdio/netbird/shared/management/domain"
|
"github.com/netbirdio/netbird/shared/management/domain"
|
||||||
@@ -85,8 +87,7 @@ type Server struct {
|
|||||||
profilesDisabled bool
|
profilesDisabled bool
|
||||||
updateSettingsDisabled bool
|
updateSettingsDisabled bool
|
||||||
|
|
||||||
// sleepTriggeredDown holds a state indicated if the sleep handler triggered the last client down
|
sleepHandler *sleephandler.SleepHandler
|
||||||
sleepTriggeredDown atomic.Bool
|
|
||||||
|
|
||||||
jwtCache *jwtCache
|
jwtCache *jwtCache
|
||||||
}
|
}
|
||||||
@@ -100,7 +101,7 @@ type oauthAuthFlow struct {
|
|||||||
|
|
||||||
// New server instance constructor.
|
// New server instance constructor.
|
||||||
func New(ctx context.Context, logFile string, configFile string, profilesDisabled bool, updateSettingsDisabled bool) *Server {
|
func New(ctx context.Context, logFile string, configFile string, profilesDisabled bool, updateSettingsDisabled bool) *Server {
|
||||||
return &Server{
|
s := &Server{
|
||||||
rootCtx: ctx,
|
rootCtx: ctx,
|
||||||
logFile: logFile,
|
logFile: logFile,
|
||||||
persistSyncResponse: true,
|
persistSyncResponse: true,
|
||||||
@@ -110,6 +111,10 @@ func New(ctx context.Context, logFile string, configFile string, profilesDisable
|
|||||||
updateSettingsDisabled: updateSettingsDisabled,
|
updateSettingsDisabled: updateSettingsDisabled,
|
||||||
jwtCache: newJWTCache(),
|
jwtCache: newJWTCache(),
|
||||||
}
|
}
|
||||||
|
agent := &serverAgent{s}
|
||||||
|
s.sleepHandler = sleephandler.New(agent)
|
||||||
|
|
||||||
|
return s
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) Start() error {
|
func (s *Server) Start() error {
|
||||||
@@ -636,8 +641,6 @@ func (s *Server) Up(callerCtx context.Context, msg *proto.UpRequest) (*proto.UpR
|
|||||||
|
|
||||||
return s.waitForUp(callerCtx)
|
return s.waitForUp(callerCtx)
|
||||||
}
|
}
|
||||||
defer s.mutex.Unlock()
|
|
||||||
|
|
||||||
if err := restoreResidualState(callerCtx, s.profileManager.GetStatePath()); err != nil {
|
if err := restoreResidualState(callerCtx, s.profileManager.GetStatePath()); err != nil {
|
||||||
log.Warnf(errRestoreResidualState, err)
|
log.Warnf(errRestoreResidualState, err)
|
||||||
}
|
}
|
||||||
@@ -649,10 +652,12 @@ func (s *Server) Up(callerCtx context.Context, msg *proto.UpRequest) (*proto.UpR
|
|||||||
// not in the progress or already successfully established connection.
|
// not in the progress or already successfully established connection.
|
||||||
status, err := state.Status()
|
status, err := state.Status()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
s.mutex.Unlock()
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if status != internal.StatusIdle {
|
if status != internal.StatusIdle {
|
||||||
|
s.mutex.Unlock()
|
||||||
return nil, fmt.Errorf("up already in progress: current status %s", status)
|
return nil, fmt.Errorf("up already in progress: current status %s", status)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -669,17 +674,20 @@ func (s *Server) Up(callerCtx context.Context, msg *proto.UpRequest) (*proto.UpR
|
|||||||
s.actCancel = cancel
|
s.actCancel = cancel
|
||||||
|
|
||||||
if s.config == nil {
|
if s.config == nil {
|
||||||
|
s.mutex.Unlock()
|
||||||
return nil, fmt.Errorf("config is not defined, please call login command first")
|
return nil, fmt.Errorf("config is not defined, please call login command first")
|
||||||
}
|
}
|
||||||
|
|
||||||
activeProf, err := s.profileManager.GetActiveProfileState()
|
activeProf, err := s.profileManager.GetActiveProfileState()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
s.mutex.Unlock()
|
||||||
log.Errorf("failed to get active profile state: %v", err)
|
log.Errorf("failed to get active profile state: %v", err)
|
||||||
return nil, fmt.Errorf("failed to get active profile state: %w", err)
|
return nil, fmt.Errorf("failed to get active profile state: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if msg != nil && msg.ProfileName != nil {
|
if msg != nil && msg.ProfileName != nil {
|
||||||
if err := s.switchProfileIfNeeded(*msg.ProfileName, msg.Username, activeProf); err != nil {
|
if err := s.switchProfileIfNeeded(*msg.ProfileName, msg.Username, activeProf); err != nil {
|
||||||
|
s.mutex.Unlock()
|
||||||
log.Errorf("failed to switch profile: %v", err)
|
log.Errorf("failed to switch profile: %v", err)
|
||||||
return nil, fmt.Errorf("failed to switch profile: %w", err)
|
return nil, fmt.Errorf("failed to switch profile: %w", err)
|
||||||
}
|
}
|
||||||
@@ -687,6 +695,7 @@ func (s *Server) Up(callerCtx context.Context, msg *proto.UpRequest) (*proto.UpR
|
|||||||
|
|
||||||
activeProf, err = s.profileManager.GetActiveProfileState()
|
activeProf, err = s.profileManager.GetActiveProfileState()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
s.mutex.Unlock()
|
||||||
log.Errorf("failed to get active profile state: %v", err)
|
log.Errorf("failed to get active profile state: %v", err)
|
||||||
return nil, fmt.Errorf("failed to get active profile state: %w", err)
|
return nil, fmt.Errorf("failed to get active profile state: %w", err)
|
||||||
}
|
}
|
||||||
@@ -695,6 +704,7 @@ func (s *Server) Up(callerCtx context.Context, msg *proto.UpRequest) (*proto.UpR
|
|||||||
|
|
||||||
config, _, err := s.getConfig(activeProf)
|
config, _, err := s.getConfig(activeProf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
s.mutex.Unlock()
|
||||||
log.Errorf("failed to get active profile config: %v", err)
|
log.Errorf("failed to get active profile config: %v", err)
|
||||||
return nil, fmt.Errorf("failed to get active profile config: %w", err)
|
return nil, fmt.Errorf("failed to get active profile config: %w", err)
|
||||||
}
|
}
|
||||||
@@ -713,6 +723,7 @@ func (s *Server) Up(callerCtx context.Context, msg *proto.UpRequest) (*proto.UpR
|
|||||||
}
|
}
|
||||||
go s.connectWithRetryRuns(ctx, s.config, s.statusRecorder, doAutoUpdate, s.clientRunningChan, s.clientGiveUpChan)
|
go s.connectWithRetryRuns(ctx, s.config, s.statusRecorder, doAutoUpdate, s.clientRunningChan, s.clientGiveUpChan)
|
||||||
|
|
||||||
|
s.mutex.Unlock()
|
||||||
return s.waitForUp(callerCtx)
|
return s.waitForUp(callerCtx)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -838,15 +849,27 @@ func (s *Server) cleanupConnection() error {
|
|||||||
if s.actCancel == nil {
|
if s.actCancel == nil {
|
||||||
return ErrServiceNotUp
|
return ErrServiceNotUp
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Capture the engine reference before cancelling the context.
|
||||||
|
// After actCancel(), the connectWithRetryRuns goroutine wakes up
|
||||||
|
// and sets connectClient.engine = nil, causing connectClient.Stop()
|
||||||
|
// to skip the engine shutdown entirely.
|
||||||
|
var engine *internal.Engine
|
||||||
|
if s.connectClient != nil {
|
||||||
|
engine = s.connectClient.Engine()
|
||||||
|
}
|
||||||
|
|
||||||
s.actCancel()
|
s.actCancel()
|
||||||
|
|
||||||
if s.connectClient == nil {
|
if s.connectClient == nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := s.connectClient.Stop(); err != nil {
|
if engine != nil {
|
||||||
|
if err := engine.Stop(); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
s.connectClient = nil
|
s.connectClient = nil
|
||||||
s.isSessionActive.Store(false)
|
s.isSessionActive.Store(false)
|
||||||
@@ -1312,6 +1335,60 @@ func (s *Server) WaitJWTToken(
|
|||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ExposeService exposes a local port via the NetBird reverse proxy.
|
||||||
|
func (s *Server) ExposeService(req *proto.ExposeServiceRequest, srv proto.DaemonService_ExposeServiceServer) error {
|
||||||
|
s.mutex.Lock()
|
||||||
|
if !s.clientRunning {
|
||||||
|
s.mutex.Unlock()
|
||||||
|
return gstatus.Errorf(codes.FailedPrecondition, "client is not running, run 'netbird up' first")
|
||||||
|
}
|
||||||
|
connectClient := s.connectClient
|
||||||
|
s.mutex.Unlock()
|
||||||
|
|
||||||
|
if connectClient == nil {
|
||||||
|
return gstatus.Errorf(codes.FailedPrecondition, "client not initialized")
|
||||||
|
}
|
||||||
|
|
||||||
|
engine := connectClient.Engine()
|
||||||
|
if engine == nil {
|
||||||
|
return gstatus.Errorf(codes.FailedPrecondition, "engine not initialized")
|
||||||
|
}
|
||||||
|
|
||||||
|
mgr := engine.GetExposeManager()
|
||||||
|
if mgr == nil {
|
||||||
|
return gstatus.Errorf(codes.Internal, "expose manager not available")
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := srv.Context()
|
||||||
|
|
||||||
|
exposeCtx, exposeCancel := context.WithTimeout(ctx, 30*time.Second)
|
||||||
|
defer exposeCancel()
|
||||||
|
|
||||||
|
mgmReq := expose.NewRequest(req)
|
||||||
|
result, err := mgr.Expose(exposeCtx, *mgmReq)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := srv.Send(&proto.ExposeServiceEvent{
|
||||||
|
Event: &proto.ExposeServiceEvent_Ready{
|
||||||
|
Ready: &proto.ExposeServiceReady{
|
||||||
|
ServiceName: result.ServiceName,
|
||||||
|
ServiceUrl: result.ServiceURL,
|
||||||
|
Domain: result.Domain,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
err = mgr.KeepAlive(ctx, result.Domain)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func isUnixRunningDesktop() bool {
|
func isUnixRunningDesktop() bool {
|
||||||
if runtime.GOOS != "linux" && runtime.GOOS != "freebsd" {
|
if runtime.GOOS != "linux" && runtime.GOOS != "freebsd" {
|
||||||
return false
|
return false
|
||||||
|
|||||||
46
client/server/sleep.go
Normal file
46
client/server/sleep.go
Normal file
@@ -0,0 +1,46 @@
|
|||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal"
|
||||||
|
"github.com/netbirdio/netbird/client/proto"
|
||||||
|
)
|
||||||
|
|
||||||
|
// serverAgent adapts Server to the handler.Agent and handler.StatusChecker interfaces
|
||||||
|
type serverAgent struct {
|
||||||
|
s *Server
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *serverAgent) Up(ctx context.Context) error {
|
||||||
|
_, err := a.s.Up(ctx, &proto.UpRequest{})
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *serverAgent) Down(ctx context.Context) error {
|
||||||
|
_, err := a.s.Down(ctx, &proto.DownRequest{})
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *serverAgent) Status() (internal.StatusType, error) {
|
||||||
|
return internal.CtxGetState(a.s.rootCtx).Status()
|
||||||
|
}
|
||||||
|
|
||||||
|
// NotifyOSLifecycle handles operating system lifecycle events by executing appropriate logic based on the request type.
|
||||||
|
func (s *Server) NotifyOSLifecycle(callerCtx context.Context, req *proto.OSLifecycleRequest) (*proto.OSLifecycleResponse, error) {
|
||||||
|
switch req.GetType() {
|
||||||
|
case proto.OSLifecycleRequest_WAKEUP:
|
||||||
|
if err := s.sleepHandler.HandleWakeUp(callerCtx); err != nil {
|
||||||
|
return &proto.OSLifecycleResponse{}, err
|
||||||
|
}
|
||||||
|
case proto.OSLifecycleRequest_SLEEP:
|
||||||
|
if err := s.sleepHandler.HandleSleep(callerCtx); err != nil {
|
||||||
|
return &proto.OSLifecycleResponse{}, err
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
log.Errorf("unknown OSLifecycleRequest type: %v", req.GetType())
|
||||||
|
}
|
||||||
|
return &proto.OSLifecycleResponse{}, nil
|
||||||
|
}
|
||||||
@@ -19,6 +19,7 @@ import (
|
|||||||
"google.golang.org/grpc"
|
"google.golang.org/grpc"
|
||||||
"google.golang.org/grpc/credentials/insecure"
|
"google.golang.org/grpc/credentials/insecure"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal/daemonaddr"
|
||||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||||
"github.com/netbirdio/netbird/client/proto"
|
"github.com/netbirdio/netbird/client/proto"
|
||||||
nbssh "github.com/netbirdio/netbird/client/ssh"
|
nbssh "github.com/netbirdio/netbird/client/ssh"
|
||||||
@@ -268,7 +269,7 @@ func getDefaultDaemonAddr() string {
|
|||||||
if runtime.GOOS == "windows" {
|
if runtime.GOOS == "windows" {
|
||||||
return DefaultDaemonAddrWindows
|
return DefaultDaemonAddrWindows
|
||||||
}
|
}
|
||||||
return DefaultDaemonAddr
|
return daemonaddr.ResolveUnixDaemonAddr(DefaultDaemonAddr)
|
||||||
}
|
}
|
||||||
|
|
||||||
// DialOptions contains options for SSH connections
|
// DialOptions contains options for SSH connections
|
||||||
|
|||||||
@@ -46,8 +46,10 @@ const (
|
|||||||
cmdSFTP = "<sftp>"
|
cmdSFTP = "<sftp>"
|
||||||
cmdNonInteractive = "<idle>"
|
cmdNonInteractive = "<idle>"
|
||||||
|
|
||||||
// DefaultJWTMaxTokenAge is the default maximum age for JWT tokens accepted by the SSH server
|
// DefaultJWTMaxTokenAge is the default maximum age for JWT tokens accepted by the SSH server.
|
||||||
DefaultJWTMaxTokenAge = 5 * 60
|
// Set to 10 minutes to accommodate identity providers like Azure Entra ID
|
||||||
|
// that backdate the iat claim by up to 5 minutes.
|
||||||
|
DefaultJWTMaxTokenAge = 10 * 60
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ import (
|
|||||||
"net/netip"
|
"net/netip"
|
||||||
"os"
|
"os"
|
||||||
"path"
|
"path"
|
||||||
|
"path/filepath"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -70,6 +71,8 @@ type ServerConfig struct {
|
|||||||
DisableGeoliteUpdate bool `yaml:"disableGeoliteUpdate"`
|
DisableGeoliteUpdate bool `yaml:"disableGeoliteUpdate"`
|
||||||
Auth AuthConfig `yaml:"auth"`
|
Auth AuthConfig `yaml:"auth"`
|
||||||
Store StoreConfig `yaml:"store"`
|
Store StoreConfig `yaml:"store"`
|
||||||
|
ActivityStore StoreConfig `yaml:"activityStore"`
|
||||||
|
AuthStore StoreConfig `yaml:"authStore"`
|
||||||
ReverseProxy ReverseProxyConfig `yaml:"reverseProxy"`
|
ReverseProxy ReverseProxyConfig `yaml:"reverseProxy"`
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -171,6 +174,7 @@ type StoreConfig struct {
|
|||||||
Engine string `yaml:"engine"`
|
Engine string `yaml:"engine"`
|
||||||
EncryptionKey string `yaml:"encryptionKey"`
|
EncryptionKey string `yaml:"encryptionKey"`
|
||||||
DSN string `yaml:"dsn"` // Connection string for postgres or mysql engines
|
DSN string `yaml:"dsn"` // Connection string for postgres or mysql engines
|
||||||
|
File string `yaml:"file"` // SQLite database file path (optional, defaults to dataDir)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ReverseProxyConfig contains reverse proxy settings
|
// ReverseProxyConfig contains reverse proxy settings
|
||||||
@@ -532,6 +536,74 @@ func stripSignalProtocol(uri string) string {
|
|||||||
return uri
|
return uri
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func buildRelayConfig(relays RelaysConfig) (*nbconfig.Relay, error) {
|
||||||
|
var ttl time.Duration
|
||||||
|
if relays.CredentialsTTL != "" {
|
||||||
|
var err error
|
||||||
|
ttl, err = time.ParseDuration(relays.CredentialsTTL)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("invalid relay credentials TTL %q: %w", relays.CredentialsTTL, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return &nbconfig.Relay{
|
||||||
|
Addresses: relays.Addresses,
|
||||||
|
CredentialsTTL: util.Duration{Duration: ttl},
|
||||||
|
Secret: relays.Secret,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// buildEmbeddedIdPConfig builds the embedded IdP configuration.
|
||||||
|
// authStore overrides auth.storage when set.
|
||||||
|
func (c *CombinedConfig) buildEmbeddedIdPConfig(mgmt ManagementConfig) (*idp.EmbeddedIdPConfig, error) {
|
||||||
|
authStorageType := mgmt.Auth.Storage.Type
|
||||||
|
authStorageDSN := c.Server.AuthStore.DSN
|
||||||
|
if c.Server.AuthStore.Engine != "" {
|
||||||
|
authStorageType = c.Server.AuthStore.Engine
|
||||||
|
}
|
||||||
|
if authStorageType == "" {
|
||||||
|
authStorageType = "sqlite3"
|
||||||
|
}
|
||||||
|
authStorageFile := ""
|
||||||
|
if authStorageType == "postgres" {
|
||||||
|
if authStorageDSN == "" {
|
||||||
|
return nil, fmt.Errorf("authStore.dsn is required when authStore.engine is postgres")
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
authStorageFile = path.Join(mgmt.DataDir, "idp.db")
|
||||||
|
if c.Server.AuthStore.File != "" {
|
||||||
|
authStorageFile = c.Server.AuthStore.File
|
||||||
|
if !filepath.IsAbs(authStorageFile) {
|
||||||
|
authStorageFile = filepath.Join(mgmt.DataDir, authStorageFile)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg := &idp.EmbeddedIdPConfig{
|
||||||
|
Enabled: true,
|
||||||
|
Issuer: mgmt.Auth.Issuer,
|
||||||
|
LocalAuthDisabled: mgmt.Auth.LocalAuthDisabled,
|
||||||
|
SignKeyRefreshEnabled: mgmt.Auth.SignKeyRefreshEnabled,
|
||||||
|
Storage: idp.EmbeddedStorageConfig{
|
||||||
|
Type: authStorageType,
|
||||||
|
Config: idp.EmbeddedStorageTypeConfig{
|
||||||
|
File: authStorageFile,
|
||||||
|
DSN: authStorageDSN,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
DashboardRedirectURIs: mgmt.Auth.DashboardRedirectURIs,
|
||||||
|
CLIRedirectURIs: mgmt.Auth.CLIRedirectURIs,
|
||||||
|
}
|
||||||
|
|
||||||
|
if mgmt.Auth.Owner != nil && mgmt.Auth.Owner.Email != "" {
|
||||||
|
cfg.Owner = &idp.OwnerConfig{
|
||||||
|
Email: mgmt.Auth.Owner.Email,
|
||||||
|
Hash: mgmt.Auth.Owner.Password,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return cfg, nil
|
||||||
|
}
|
||||||
|
|
||||||
// ToManagementConfig converts CombinedConfig to management server config
|
// ToManagementConfig converts CombinedConfig to management server config
|
||||||
func (c *CombinedConfig) ToManagementConfig() (*nbconfig.Config, error) {
|
func (c *CombinedConfig) ToManagementConfig() (*nbconfig.Config, error) {
|
||||||
mgmt := c.Management
|
mgmt := c.Management
|
||||||
@@ -550,19 +622,11 @@ func (c *CombinedConfig) ToManagementConfig() (*nbconfig.Config, error) {
|
|||||||
// Build relay config
|
// Build relay config
|
||||||
var relayConfig *nbconfig.Relay
|
var relayConfig *nbconfig.Relay
|
||||||
if len(mgmt.Relays.Addresses) > 0 || mgmt.Relays.Secret != "" {
|
if len(mgmt.Relays.Addresses) > 0 || mgmt.Relays.Secret != "" {
|
||||||
var ttl time.Duration
|
relay, err := buildRelayConfig(mgmt.Relays)
|
||||||
if mgmt.Relays.CredentialsTTL != "" {
|
|
||||||
var err error
|
|
||||||
ttl, err = time.ParseDuration(mgmt.Relays.CredentialsTTL)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("invalid relay credentials TTL %q: %w", mgmt.Relays.CredentialsTTL, err)
|
return nil, err
|
||||||
}
|
|
||||||
}
|
|
||||||
relayConfig = &nbconfig.Relay{
|
|
||||||
Addresses: mgmt.Relays.Addresses,
|
|
||||||
CredentialsTTL: util.Duration{Duration: ttl},
|
|
||||||
Secret: mgmt.Relays.Secret,
|
|
||||||
}
|
}
|
||||||
|
relayConfig = relay
|
||||||
}
|
}
|
||||||
|
|
||||||
// Build signal config
|
// Build signal config
|
||||||
@@ -598,31 +662,9 @@ func (c *CombinedConfig) ToManagementConfig() (*nbconfig.Config, error) {
|
|||||||
httpConfig := &nbconfig.HttpServerConfig{}
|
httpConfig := &nbconfig.HttpServerConfig{}
|
||||||
|
|
||||||
// Build embedded IDP config (always enabled in combined server)
|
// Build embedded IDP config (always enabled in combined server)
|
||||||
storageFile := mgmt.Auth.Storage.File
|
embeddedIdP, err := c.buildEmbeddedIdPConfig(mgmt)
|
||||||
if storageFile == "" {
|
if err != nil {
|
||||||
storageFile = path.Join(mgmt.DataDir, "idp.db")
|
return nil, err
|
||||||
}
|
|
||||||
|
|
||||||
embeddedIdP := &idp.EmbeddedIdPConfig{
|
|
||||||
Enabled: true,
|
|
||||||
Issuer: mgmt.Auth.Issuer,
|
|
||||||
LocalAuthDisabled: mgmt.Auth.LocalAuthDisabled,
|
|
||||||
SignKeyRefreshEnabled: mgmt.Auth.SignKeyRefreshEnabled,
|
|
||||||
Storage: idp.EmbeddedStorageConfig{
|
|
||||||
Type: mgmt.Auth.Storage.Type,
|
|
||||||
Config: idp.EmbeddedStorageTypeConfig{
|
|
||||||
File: storageFile,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
DashboardRedirectURIs: mgmt.Auth.DashboardRedirectURIs,
|
|
||||||
CLIRedirectURIs: mgmt.Auth.CLIRedirectURIs,
|
|
||||||
}
|
|
||||||
|
|
||||||
if mgmt.Auth.Owner != nil && mgmt.Auth.Owner.Email != "" {
|
|
||||||
embeddedIdP.Owner = &idp.OwnerConfig{
|
|
||||||
Email: mgmt.Auth.Owner.Email,
|
|
||||||
Hash: mgmt.Auth.Owner.Password, // Will be hashed if plain text
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Set HTTP config fields for embedded IDP
|
// Set HTTP config fields for embedded IDP
|
||||||
|
|||||||
@@ -140,6 +140,23 @@ func initializeConfig() error {
|
|||||||
os.Setenv("NB_STORE_ENGINE_MYSQL_DSN", dsn)
|
os.Setenv("NB_STORE_ENGINE_MYSQL_DSN", dsn)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if file := config.Server.Store.File; file != "" {
|
||||||
|
os.Setenv("NB_STORE_ENGINE_SQLITE_FILE", file)
|
||||||
|
}
|
||||||
|
|
||||||
|
if engine := config.Server.ActivityStore.Engine; engine != "" {
|
||||||
|
engineLower := strings.ToLower(engine)
|
||||||
|
if engineLower == "postgres" && config.Server.ActivityStore.DSN == "" {
|
||||||
|
return fmt.Errorf("activityStore.dsn is required when activityStore.engine is postgres")
|
||||||
|
}
|
||||||
|
os.Setenv("NB_ACTIVITY_EVENT_STORE_ENGINE", engineLower)
|
||||||
|
if dsn := config.Server.ActivityStore.DSN; dsn != "" {
|
||||||
|
os.Setenv("NB_ACTIVITY_EVENT_POSTGRES_DSN", dsn)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if file := config.Server.ActivityStore.File; file != "" {
|
||||||
|
os.Setenv("NB_ACTIVITY_EVENT_SQLITE_FILE", file)
|
||||||
|
}
|
||||||
|
|
||||||
log.Infof("Starting combined NetBird server")
|
log.Infof("Starting combined NetBird server")
|
||||||
logConfig(config)
|
logConfig(config)
|
||||||
@@ -476,9 +493,6 @@ func handleTLSConfig(cfg *CombinedConfig) (*tls.Config, bool, error) {
|
|||||||
func createManagementServer(cfg *CombinedConfig, mgmtConfig *nbconfig.Config) (*mgmtServer.BaseServer, error) {
|
func createManagementServer(cfg *CombinedConfig, mgmtConfig *nbconfig.Config) (*mgmtServer.BaseServer, error) {
|
||||||
mgmt := cfg.Management
|
mgmt := cfg.Management
|
||||||
|
|
||||||
dnsDomain := mgmt.DnsDomain
|
|
||||||
singleAccModeDomain := dnsDomain
|
|
||||||
|
|
||||||
// Extract port from listen address
|
// Extract port from listen address
|
||||||
_, portStr, err := net.SplitHostPort(cfg.Server.ListenAddress)
|
_, portStr, err := net.SplitHostPort(cfg.Server.ListenAddress)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -490,8 +504,9 @@ func createManagementServer(cfg *CombinedConfig, mgmtConfig *nbconfig.Config) (*
|
|||||||
mgmtSrv := mgmtServer.NewServer(
|
mgmtSrv := mgmtServer.NewServer(
|
||||||
&mgmtServer.Config{
|
&mgmtServer.Config{
|
||||||
NbConfig: mgmtConfig,
|
NbConfig: mgmtConfig,
|
||||||
DNSDomain: dnsDomain,
|
DNSDomain: "",
|
||||||
MgmtSingleAccModeDomain: singleAccModeDomain,
|
MgmtSingleAccModeDomain: "",
|
||||||
|
AutoResolveDomains: true,
|
||||||
MgmtPort: mgmtPort,
|
MgmtPort: mgmtPort,
|
||||||
MgmtMetricsPort: cfg.Server.MetricsPort,
|
MgmtMetricsPort: cfg.Server.MetricsPort,
|
||||||
DisableMetrics: mgmt.DisableAnonymousMetrics,
|
DisableMetrics: mgmt.DisableAnonymousMetrics,
|
||||||
@@ -668,8 +683,11 @@ func logEnvVars() {
|
|||||||
if strings.HasPrefix(env, "NB_") {
|
if strings.HasPrefix(env, "NB_") {
|
||||||
key, _, _ := strings.Cut(env, "=")
|
key, _, _ := strings.Cut(env, "=")
|
||||||
value := os.Getenv(key)
|
value := os.Getenv(key)
|
||||||
if strings.Contains(strings.ToLower(key), "secret") || strings.Contains(strings.ToLower(key), "key") || strings.Contains(strings.ToLower(key), "password") {
|
keyLower := strings.ToLower(key)
|
||||||
|
if strings.Contains(keyLower, "secret") || strings.Contains(keyLower, "key") || strings.Contains(keyLower, "password") {
|
||||||
value = maskSecret(value)
|
value = maskSecret(value)
|
||||||
|
} else if strings.Contains(keyLower, "dsn") {
|
||||||
|
value = maskDSNPassword(value)
|
||||||
}
|
}
|
||||||
log.Infof(" %s=%s", key, value)
|
log.Infof(" %s=%s", key, value)
|
||||||
found = true
|
found = true
|
||||||
|
|||||||
@@ -42,6 +42,9 @@ func withTokenStore(cmd *cobra.Command, fn func(ctx context.Context, s store.Sto
|
|||||||
os.Setenv("NB_STORE_ENGINE_MYSQL_DSN", dsn)
|
os.Setenv("NB_STORE_ENGINE_MYSQL_DSN", dsn)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if file := cfg.Server.Store.File; file != "" {
|
||||||
|
os.Setenv("NB_STORE_ENGINE_SQLITE_FILE", file)
|
||||||
|
}
|
||||||
|
|
||||||
datadir := cfg.Management.DataDir
|
datadir := cfg.Management.DataDir
|
||||||
engine := types.Engine(cfg.Management.Store.Engine)
|
engine := types.Engine(cfg.Management.Store.Engine)
|
||||||
|
|||||||
@@ -103,6 +103,19 @@ server:
|
|||||||
engine: "sqlite" # sqlite, postgres, or mysql
|
engine: "sqlite" # sqlite, postgres, or mysql
|
||||||
dsn: "" # Connection string for postgres or mysql
|
dsn: "" # Connection string for postgres or mysql
|
||||||
encryptionKey: ""
|
encryptionKey: ""
|
||||||
|
# file: "" # Custom SQLite file path (optional, defaults to {dataDir}/store.db)
|
||||||
|
|
||||||
|
# Activity events store configuration (optional, defaults to sqlite in dataDir)
|
||||||
|
# activityStore:
|
||||||
|
# engine: "sqlite" # sqlite or postgres
|
||||||
|
# dsn: "" # Connection string for postgres
|
||||||
|
# file: "" # Custom SQLite file path (optional, defaults to {dataDir}/events.db)
|
||||||
|
|
||||||
|
# Auth (embedded IdP) store configuration (optional, defaults to sqlite3 in dataDir/idp.db)
|
||||||
|
# authStore:
|
||||||
|
# engine: "sqlite3" # sqlite3 or postgres
|
||||||
|
# dsn: "" # Connection string for postgres (e.g., "host=localhost port=5432 user=postgres password=postgres dbname=netbird_idp sslmode=disable")
|
||||||
|
# file: "" # Custom SQLite file path (optional, defaults to {dataDir}/idp.db)
|
||||||
|
|
||||||
# Reverse proxy settings (optional)
|
# Reverse proxy settings (optional)
|
||||||
# reverseProxy:
|
# reverseProxy:
|
||||||
|
|||||||
@@ -5,7 +5,10 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
|
"net/url"
|
||||||
"os"
|
"os"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"golang.org/x/crypto/bcrypt"
|
"golang.org/x/crypto/bcrypt"
|
||||||
@@ -195,11 +198,175 @@ func (s *Storage) OpenStorage(logger *slog.Logger) (storage.Storage, error) {
|
|||||||
return nil, fmt.Errorf("sqlite3 storage requires 'file' config")
|
return nil, fmt.Errorf("sqlite3 storage requires 'file' config")
|
||||||
}
|
}
|
||||||
return (&sql.SQLite3{File: file}).Open(logger)
|
return (&sql.SQLite3{File: file}).Open(logger)
|
||||||
|
case "postgres":
|
||||||
|
dsn, _ := s.Config["dsn"].(string)
|
||||||
|
if dsn == "" {
|
||||||
|
return nil, fmt.Errorf("postgres storage requires 'dsn' config")
|
||||||
|
}
|
||||||
|
pg, err := parsePostgresDSN(dsn)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("invalid postgres DSN: %w", err)
|
||||||
|
}
|
||||||
|
return pg.Open(logger)
|
||||||
default:
|
default:
|
||||||
return nil, fmt.Errorf("unsupported storage type: %s", s.Type)
|
return nil, fmt.Errorf("unsupported storage type: %s", s.Type)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// parsePostgresDSN parses a DSN into a sql.Postgres config.
|
||||||
|
// It accepts both URI format (postgres://user:pass@host:port/dbname?sslmode=disable)
|
||||||
|
// and libpq key=value format (host=localhost port=5432 dbname=mydb), including quoted values.
|
||||||
|
func parsePostgresDSN(dsn string) (*sql.Postgres, error) {
|
||||||
|
var params map[string]string
|
||||||
|
var err error
|
||||||
|
|
||||||
|
if strings.HasPrefix(dsn, "postgres://") || strings.HasPrefix(dsn, "postgresql://") {
|
||||||
|
params, err = parsePostgresURI(dsn)
|
||||||
|
} else {
|
||||||
|
params, err = parsePostgresKeyValue(dsn)
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
host := params["host"]
|
||||||
|
if host == "" {
|
||||||
|
host = "localhost"
|
||||||
|
}
|
||||||
|
|
||||||
|
var port uint16 = 5432
|
||||||
|
if p, ok := params["port"]; ok && p != "" {
|
||||||
|
v, err := strconv.ParseUint(p, 10, 16)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("invalid port %q: %w", p, err)
|
||||||
|
}
|
||||||
|
if v == 0 {
|
||||||
|
return nil, fmt.Errorf("invalid port %q: must be non-zero", p)
|
||||||
|
}
|
||||||
|
port = uint16(v)
|
||||||
|
}
|
||||||
|
|
||||||
|
dbname := params["dbname"]
|
||||||
|
if dbname == "" {
|
||||||
|
return nil, fmt.Errorf("dbname is required in DSN")
|
||||||
|
}
|
||||||
|
|
||||||
|
pg := &sql.Postgres{
|
||||||
|
NetworkDB: sql.NetworkDB{
|
||||||
|
Host: host,
|
||||||
|
Port: port,
|
||||||
|
Database: dbname,
|
||||||
|
User: params["user"],
|
||||||
|
Password: params["password"],
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
if sslMode := params["sslmode"]; sslMode != "" {
|
||||||
|
switch sslMode {
|
||||||
|
case "disable", "allow", "prefer", "require", "verify-ca", "verify-full":
|
||||||
|
pg.SSL.Mode = sslMode
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("unsupported sslmode %q: valid values are disable, allow, prefer, require, verify-ca, verify-full", sslMode)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return pg, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// parsePostgresURI parses a postgres:// or postgresql:// URI into parameter key-value pairs.
|
||||||
|
func parsePostgresURI(dsn string) (map[string]string, error) {
|
||||||
|
u, err := url.Parse(dsn)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("invalid postgres URI: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
params := make(map[string]string)
|
||||||
|
|
||||||
|
if u.User != nil {
|
||||||
|
params["user"] = u.User.Username()
|
||||||
|
if p, ok := u.User.Password(); ok {
|
||||||
|
params["password"] = p
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if u.Hostname() != "" {
|
||||||
|
params["host"] = u.Hostname()
|
||||||
|
}
|
||||||
|
if u.Port() != "" {
|
||||||
|
params["port"] = u.Port()
|
||||||
|
}
|
||||||
|
|
||||||
|
dbname := strings.TrimPrefix(u.Path, "/")
|
||||||
|
if dbname != "" {
|
||||||
|
params["dbname"] = dbname
|
||||||
|
}
|
||||||
|
|
||||||
|
for k, v := range u.Query() {
|
||||||
|
if len(v) > 0 {
|
||||||
|
params[k] = v[0]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return params, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// parsePostgresKeyValue parses a libpq key=value DSN string, handling single-quoted values
|
||||||
|
// (e.g., password='my pass' host=localhost).
|
||||||
|
func parsePostgresKeyValue(dsn string) (map[string]string, error) {
|
||||||
|
params := make(map[string]string)
|
||||||
|
s := strings.TrimSpace(dsn)
|
||||||
|
|
||||||
|
for s != "" {
|
||||||
|
eqIdx := strings.IndexByte(s, '=')
|
||||||
|
if eqIdx < 0 {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
key := strings.TrimSpace(s[:eqIdx])
|
||||||
|
|
||||||
|
value, rest, err := parseDSNValue(s[eqIdx+1:])
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("%w for key %q", err, key)
|
||||||
|
}
|
||||||
|
|
||||||
|
params[key] = value
|
||||||
|
s = strings.TrimSpace(rest)
|
||||||
|
}
|
||||||
|
|
||||||
|
return params, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseDSNValue parses the next value from a libpq key=value string positioned after the '='.
|
||||||
|
// It returns the parsed value and the remaining unparsed string.
|
||||||
|
func parseDSNValue(s string) (value, rest string, err error) {
|
||||||
|
if len(s) > 0 && s[0] == '\'' {
|
||||||
|
return parseQuotedDSNValue(s[1:])
|
||||||
|
}
|
||||||
|
// Unquoted value: read until whitespace.
|
||||||
|
idx := strings.IndexAny(s, " \t\n")
|
||||||
|
if idx < 0 {
|
||||||
|
return s, "", nil
|
||||||
|
}
|
||||||
|
return s[:idx], s[idx:], nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseQuotedDSNValue parses a single-quoted value starting after the opening quote.
|
||||||
|
// Libpq uses ” to represent a literal single quote inside quoted values.
|
||||||
|
func parseQuotedDSNValue(s string) (value, rest string, err error) {
|
||||||
|
var buf strings.Builder
|
||||||
|
for len(s) > 0 {
|
||||||
|
if s[0] == '\'' {
|
||||||
|
if len(s) > 1 && s[1] == '\'' {
|
||||||
|
buf.WriteByte('\'')
|
||||||
|
s = s[2:]
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
return buf.String(), s[1:], nil
|
||||||
|
}
|
||||||
|
buf.WriteByte(s[0])
|
||||||
|
s = s[1:]
|
||||||
|
}
|
||||||
|
return "", "", fmt.Errorf("unterminated quoted value")
|
||||||
|
}
|
||||||
|
|
||||||
// Validate validates the configuration
|
// Validate validates the configuration
|
||||||
func (c *YAMLConfig) Validate() error {
|
func (c *YAMLConfig) Validate() error {
|
||||||
if c.Issuer == "" {
|
if c.Issuer == "" {
|
||||||
|
|||||||
@@ -63,6 +63,8 @@ type Controller struct {
|
|||||||
|
|
||||||
expNewNetworkMap bool
|
expNewNetworkMap bool
|
||||||
expNewNetworkMapAIDs map[string]struct{}
|
expNewNetworkMapAIDs map[string]struct{}
|
||||||
|
|
||||||
|
compactedNetworkMap bool
|
||||||
}
|
}
|
||||||
|
|
||||||
type bufferUpdate struct {
|
type bufferUpdate struct {
|
||||||
@@ -85,6 +87,12 @@ func NewController(ctx context.Context, store store.Store, metrics telemetry.App
|
|||||||
newNetworkMapBuilder = false
|
newNetworkMapBuilder = false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
compactedNetworkMap, err := strconv.ParseBool(os.Getenv(types.EnvNewNetworkMapCompacted))
|
||||||
|
if err != nil {
|
||||||
|
log.WithContext(ctx).Warnf("failed to parse %s, using default value false: %v", types.EnvNewNetworkMapCompacted, err)
|
||||||
|
compactedNetworkMap = false
|
||||||
|
}
|
||||||
|
|
||||||
ids := strings.Split(os.Getenv(network_map.EnvNewNetworkMapAccounts), ",")
|
ids := strings.Split(os.Getenv(network_map.EnvNewNetworkMapAccounts), ",")
|
||||||
expIDs := make(map[string]struct{}, len(ids))
|
expIDs := make(map[string]struct{}, len(ids))
|
||||||
for _, id := range ids {
|
for _, id := range ids {
|
||||||
@@ -108,6 +116,8 @@ func NewController(ctx context.Context, store store.Store, metrics telemetry.App
|
|||||||
holder: types.NewHolder(),
|
holder: types.NewHolder(),
|
||||||
expNewNetworkMap: newNetworkMapBuilder,
|
expNewNetworkMap: newNetworkMapBuilder,
|
||||||
expNewNetworkMapAIDs: expIDs,
|
expNewNetworkMapAIDs: expIDs,
|
||||||
|
|
||||||
|
compactedNetworkMap: compactedNetworkMap,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -230,9 +240,12 @@ func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID strin
|
|||||||
|
|
||||||
var remotePeerNetworkMap *types.NetworkMap
|
var remotePeerNetworkMap *types.NetworkMap
|
||||||
|
|
||||||
if c.experimentalNetworkMap(accountID) {
|
switch {
|
||||||
|
case c.experimentalNetworkMap(accountID):
|
||||||
remotePeerNetworkMap = c.getPeerNetworkMapExp(ctx, p.AccountID, p.ID, approvedPeersMap, peersCustomZone, accountZones, c.accountManagerMetrics)
|
remotePeerNetworkMap = c.getPeerNetworkMapExp(ctx, p.AccountID, p.ID, approvedPeersMap, peersCustomZone, accountZones, c.accountManagerMetrics)
|
||||||
} else {
|
case c.compactedNetworkMap:
|
||||||
|
remotePeerNetworkMap = account.GetPeerNetworkMapFromComponents(ctx, p.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
|
||||||
|
default:
|
||||||
remotePeerNetworkMap = account.GetPeerNetworkMap(ctx, p.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
|
remotePeerNetworkMap = account.GetPeerNetworkMap(ctx, p.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -355,9 +368,12 @@ func (c *Controller) UpdateAccountPeer(ctx context.Context, accountId string, pe
|
|||||||
|
|
||||||
var remotePeerNetworkMap *types.NetworkMap
|
var remotePeerNetworkMap *types.NetworkMap
|
||||||
|
|
||||||
if c.experimentalNetworkMap(accountId) {
|
switch {
|
||||||
|
case c.experimentalNetworkMap(accountId):
|
||||||
remotePeerNetworkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peer.ID, approvedPeersMap, peersCustomZone, accountZones, c.accountManagerMetrics)
|
remotePeerNetworkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peer.ID, approvedPeersMap, peersCustomZone, accountZones, c.accountManagerMetrics)
|
||||||
} else {
|
case c.compactedNetworkMap:
|
||||||
|
remotePeerNetworkMap = account.GetPeerNetworkMapFromComponents(ctx, peerId, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
|
||||||
|
default:
|
||||||
remotePeerNetworkMap = account.GetPeerNetworkMap(ctx, peerId, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
|
remotePeerNetworkMap = account.GetPeerNetworkMap(ctx, peerId, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -479,7 +495,12 @@ func (c *Controller) GetValidatedPeerWithMap(ctx context.Context, isRequiresAppr
|
|||||||
} else {
|
} else {
|
||||||
resourcePolicies := account.GetResourcePoliciesMap()
|
resourcePolicies := account.GetResourcePoliciesMap()
|
||||||
routers := account.GetResourceRoutersMap()
|
routers := account.GetResourceRoutersMap()
|
||||||
networkMap = account.GetPeerNetworkMap(ctx, peer.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, account.GetActiveGroupUsers())
|
groupIDToUserIDs := account.GetActiveGroupUsers()
|
||||||
|
if c.compactedNetworkMap {
|
||||||
|
networkMap = account.GetPeerNetworkMapFromComponents(ctx, peer.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
|
||||||
|
} else {
|
||||||
|
networkMap = account.GetPeerNetworkMap(ctx, peer.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
|
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
|
||||||
@@ -854,7 +875,12 @@ func (c *Controller) GetNetworkMap(ctx context.Context, peerID string) (*types.N
|
|||||||
account.InjectProxyPolicies(ctx)
|
account.InjectProxyPolicies(ctx)
|
||||||
resourcePolicies := account.GetResourcePoliciesMap()
|
resourcePolicies := account.GetResourcePoliciesMap()
|
||||||
routers := account.GetResourceRoutersMap()
|
routers := account.GetResourceRoutersMap()
|
||||||
networkMap = account.GetPeerNetworkMap(ctx, peer.ID, peersCustomZone, accountZones, validatedPeers, resourcePolicies, routers, nil, account.GetActiveGroupUsers())
|
groupIDToUserIDs := account.GetActiveGroupUsers()
|
||||||
|
if c.compactedNetworkMap {
|
||||||
|
networkMap = account.GetPeerNetworkMapFromComponents(ctx, peer.ID, peersCustomZone, accountZones, validatedPeers, resourcePolicies, routers, nil, groupIDToUserIDs)
|
||||||
|
} else {
|
||||||
|
networkMap = account.GetPeerNetworkMap(ctx, peer.ID, peersCustomZone, accountZones, validatedPeers, resourcePolicies, routers, nil, groupIDToUserIDs)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
|
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
|
||||||
|
|||||||
@@ -24,6 +24,8 @@ type AccessLogEntry struct {
|
|||||||
Reason string
|
Reason string
|
||||||
UserId string `gorm:"index"`
|
UserId string `gorm:"index"`
|
||||||
AuthMethodUsed string `gorm:"index"`
|
AuthMethodUsed string `gorm:"index"`
|
||||||
|
BytesUpload int64 `gorm:"index"`
|
||||||
|
BytesDownload int64 `gorm:"index"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// FromProto creates an AccessLogEntry from a proto.AccessLog
|
// FromProto creates an AccessLogEntry from a proto.AccessLog
|
||||||
@@ -39,6 +41,8 @@ func (a *AccessLogEntry) FromProto(serviceLog *proto.AccessLog) {
|
|||||||
a.UserId = serviceLog.GetUserId()
|
a.UserId = serviceLog.GetUserId()
|
||||||
a.AuthMethodUsed = serviceLog.GetAuthMechanism()
|
a.AuthMethodUsed = serviceLog.GetAuthMechanism()
|
||||||
a.AccountID = serviceLog.GetAccountId()
|
a.AccountID = serviceLog.GetAccountId()
|
||||||
|
a.BytesUpload = serviceLog.GetBytesUpload()
|
||||||
|
a.BytesDownload = serviceLog.GetBytesDownload()
|
||||||
|
|
||||||
if sourceIP := serviceLog.GetSourceIp(); sourceIP != "" {
|
if sourceIP := serviceLog.GetSourceIp(); sourceIP != "" {
|
||||||
if ip, err := netip.ParseAddr(sourceIP); err == nil {
|
if ip, err := netip.ParseAddr(sourceIP); err == nil {
|
||||||
@@ -101,5 +105,7 @@ func (a *AccessLogEntry) ToAPIResponse() *api.ProxyAccessLog {
|
|||||||
AuthMethodUsed: authMethod,
|
AuthMethodUsed: authMethod,
|
||||||
CountryCode: countryCode,
|
CountryCode: countryCode,
|
||||||
CityName: cityName,
|
CityName: cityName,
|
||||||
|
BytesUpload: a.BytesUpload,
|
||||||
|
BytesDownload: a.BytesDownload,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -15,3 +15,12 @@ type Domain struct {
|
|||||||
Type Type `gorm:"-"`
|
Type Type `gorm:"-"`
|
||||||
Validated bool
|
Validated bool
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// EventMeta returns activity event metadata for a domain
|
||||||
|
func (d *Domain) EventMeta() map[string]any {
|
||||||
|
return map[string]any{
|
||||||
|
"domain": d.Domain,
|
||||||
|
"target_cluster": d.TargetCluster,
|
||||||
|
"validated": d.Validated,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -9,4 +9,5 @@ type Manager interface {
|
|||||||
CreateDomain(ctx context.Context, accountID, userID, domainName, targetCluster string) (*Domain, error)
|
CreateDomain(ctx context.Context, accountID, userID, domainName, targetCluster string) (*Domain, error)
|
||||||
DeleteDomain(ctx context.Context, accountID, userID, domainID string) error
|
DeleteDomain(ctx context.Context, accountID, userID, domainID string) error
|
||||||
ValidateDomain(ctx context.Context, accountID, userID, domainID string)
|
ValidateDomain(ctx context.Context, accountID, userID, domainID string)
|
||||||
|
GetClusterDomains() []string
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -9,6 +9,8 @@ import (
|
|||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain"
|
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain"
|
||||||
|
"github.com/netbirdio/netbird/management/server/account"
|
||||||
|
"github.com/netbirdio/netbird/management/server/activity"
|
||||||
"github.com/netbirdio/netbird/management/server/permissions"
|
"github.com/netbirdio/netbird/management/server/permissions"
|
||||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||||
@@ -27,25 +29,25 @@ type store interface {
|
|||||||
DeleteCustomDomain(ctx context.Context, accountID string, domainID string) error
|
DeleteCustomDomain(ctx context.Context, accountID string, domainID string) error
|
||||||
}
|
}
|
||||||
|
|
||||||
type proxyURLProvider interface {
|
type proxyManager interface {
|
||||||
GetConnectedProxyURLs() []string
|
GetActiveClusterAddresses(ctx context.Context) ([]string, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type Manager struct {
|
type Manager struct {
|
||||||
store store
|
store store
|
||||||
validator domain.Validator
|
validator domain.Validator
|
||||||
proxyURLProvider proxyURLProvider
|
proxyManager proxyManager
|
||||||
permissionsManager permissions.Manager
|
permissionsManager permissions.Manager
|
||||||
|
accountManager account.Manager
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewManager(store store, proxyURLProvider proxyURLProvider, permissionsManager permissions.Manager) Manager {
|
func NewManager(store store, proxyMgr proxyManager, permissionsManager permissions.Manager, accountManager account.Manager) Manager {
|
||||||
return Manager{
|
return Manager{
|
||||||
store: store,
|
store: store,
|
||||||
proxyURLProvider: proxyURLProvider,
|
proxyManager: proxyMgr,
|
||||||
validator: domain.Validator{
|
validator: domain.Validator{Resolver: net.DefaultResolver},
|
||||||
Resolver: net.DefaultResolver,
|
|
||||||
},
|
|
||||||
permissionsManager: permissionsManager,
|
permissionsManager: permissionsManager,
|
||||||
|
accountManager: accountManager,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -67,8 +69,12 @@ func (m Manager) GetDomains(ctx context.Context, accountID, userID string) ([]*d
|
|||||||
|
|
||||||
// Add connected proxy clusters as free domains.
|
// Add connected proxy clusters as free domains.
|
||||||
// The cluster address itself is the free domain base (e.g., "eu.proxy.netbird.io").
|
// The cluster address itself is the free domain base (e.g., "eu.proxy.netbird.io").
|
||||||
allowList := m.proxyURLAllowList()
|
allowList, err := m.proxyManager.GetActiveClusterAddresses(ctx)
|
||||||
log.WithFields(log.Fields{
|
if err != nil {
|
||||||
|
log.WithContext(ctx).Errorf("failed to get active proxy cluster addresses: %v", err)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
log.WithContext(ctx).WithFields(log.Fields{
|
||||||
"accountID": accountID,
|
"accountID": accountID,
|
||||||
"proxyAllowList": allowList,
|
"proxyAllowList": allowList,
|
||||||
}).Debug("getting domains with proxy allow list")
|
}).Debug("getting domains with proxy allow list")
|
||||||
@@ -107,7 +113,10 @@ func (m Manager) CreateDomain(ctx context.Context, accountID, userID, domainName
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Verify the target cluster is in the available clusters
|
// Verify the target cluster is in the available clusters
|
||||||
allowList := m.proxyURLAllowList()
|
allowList, err := m.proxyManager.GetActiveClusterAddresses(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to get active proxy cluster addresses: %w", err)
|
||||||
|
}
|
||||||
clusterValid := false
|
clusterValid := false
|
||||||
for _, cluster := range allowList {
|
for _, cluster := range allowList {
|
||||||
if cluster == targetCluster {
|
if cluster == targetCluster {
|
||||||
@@ -129,6 +138,9 @@ func (m Manager) CreateDomain(ctx context.Context, accountID, userID, domainName
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return d, fmt.Errorf("create domain in store: %w", err)
|
return d, fmt.Errorf("create domain in store: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
m.accountManager.StoreEvent(ctx, userID, d.ID, accountID, activity.DomainAdded, d.EventMeta())
|
||||||
|
|
||||||
return d, nil
|
return d, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -141,10 +153,18 @@ func (m Manager) DeleteDomain(ctx context.Context, accountID, userID, domainID s
|
|||||||
return status.NewPermissionDeniedError()
|
return status.NewPermissionDeniedError()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
d, err := m.store.GetCustomDomain(ctx, accountID, domainID)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("get domain from store: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
if err := m.store.DeleteCustomDomain(ctx, accountID, domainID); err != nil {
|
if err := m.store.DeleteCustomDomain(ctx, accountID, domainID); err != nil {
|
||||||
// TODO: check for "no records" type error. Because that is a success condition.
|
// TODO: check for "no records" type error. Because that is a success condition.
|
||||||
return fmt.Errorf("delete domain from store: %w", err)
|
return fmt.Errorf("delete domain from store: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
m.accountManager.StoreEvent(ctx, userID, domainID, accountID, activity.DomainDeleted, d.EventMeta())
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -211,6 +231,8 @@ func (m Manager) ValidateDomain(ctx context.Context, accountID, userID, domainID
|
|||||||
}).WithError(err).Error("update custom domain in store")
|
}).WithError(err).Error("update custom domain in store")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
m.accountManager.StoreEvent(context.Background(), userID, domainID, accountID, activity.DomainValidated, d.EventMeta())
|
||||||
} else {
|
} else {
|
||||||
log.WithFields(log.Fields{
|
log.WithFields(log.Fields{
|
||||||
"accountID": accountID,
|
"accountID": accountID,
|
||||||
@@ -221,21 +243,26 @@ func (m Manager) ValidateDomain(ctx context.Context, accountID, userID, domainID
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// proxyURLAllowList retrieves a list of currently connected proxies and
|
// GetClusterDomains returns a list of proxy cluster domains.
|
||||||
// their URLs
|
func (m Manager) GetClusterDomains() []string {
|
||||||
func (m Manager) proxyURLAllowList() []string {
|
if m.proxyManager == nil {
|
||||||
var reverseProxyAddresses []string
|
return nil
|
||||||
if m.proxyURLProvider != nil {
|
|
||||||
reverseProxyAddresses = m.proxyURLProvider.GetConnectedProxyURLs()
|
|
||||||
}
|
}
|
||||||
return reverseProxyAddresses
|
addresses, err := m.proxyManager.GetActiveClusterAddresses(context.Background())
|
||||||
|
if err != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return addresses
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeriveClusterFromDomain determines the proxy cluster for a given domain.
|
// DeriveClusterFromDomain determines the proxy cluster for a given domain.
|
||||||
// For free domains (those ending with a known cluster suffix), the cluster is extracted from the domain.
|
// For free domains (those ending with a known cluster suffix), the cluster is extracted from the domain.
|
||||||
// For custom domains, the cluster is determined by checking the registered custom domain's target cluster.
|
// For custom domains, the cluster is determined by checking the registered custom domain's target cluster.
|
||||||
func (m Manager) DeriveClusterFromDomain(ctx context.Context, accountID, domain string) (string, error) {
|
func (m Manager) DeriveClusterFromDomain(ctx context.Context, accountID, domain string) (string, error) {
|
||||||
allowList := m.proxyURLAllowList()
|
allowList, err := m.proxyManager.GetActiveClusterAddresses(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("failed to get active proxy cluster addresses: %w", err)
|
||||||
|
}
|
||||||
if len(allowList) == 0 {
|
if len(allowList) == 0 {
|
||||||
return "", fmt.Errorf("no proxy clusters available")
|
return "", fmt.Errorf("no proxy clusters available")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,539 +0,0 @@
|
|||||||
package manager
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"fmt"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
|
|
||||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/sessionkey"
|
|
||||||
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
|
|
||||||
"github.com/netbirdio/netbird/management/server/account"
|
|
||||||
"github.com/netbirdio/netbird/management/server/activity"
|
|
||||||
"github.com/netbirdio/netbird/management/server/permissions"
|
|
||||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
|
||||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
|
||||||
"github.com/netbirdio/netbird/management/server/store"
|
|
||||||
"github.com/netbirdio/netbird/shared/management/status"
|
|
||||||
)
|
|
||||||
|
|
||||||
const unknownHostPlaceholder = "unknown"
|
|
||||||
|
|
||||||
// ClusterDeriver derives the proxy cluster from a domain.
|
|
||||||
type ClusterDeriver interface {
|
|
||||||
DeriveClusterFromDomain(ctx context.Context, accountID, domain string) (string, error)
|
|
||||||
}
|
|
||||||
|
|
||||||
type managerImpl struct {
|
|
||||||
store store.Store
|
|
||||||
accountManager account.Manager
|
|
||||||
permissionsManager permissions.Manager
|
|
||||||
proxyGRPCServer *nbgrpc.ProxyServiceServer
|
|
||||||
clusterDeriver ClusterDeriver
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewManager creates a new service manager.
|
|
||||||
func NewManager(store store.Store, accountManager account.Manager, permissionsManager permissions.Manager, proxyGRPCServer *nbgrpc.ProxyServiceServer, clusterDeriver ClusterDeriver) reverseproxy.Manager {
|
|
||||||
return &managerImpl{
|
|
||||||
store: store,
|
|
||||||
accountManager: accountManager,
|
|
||||||
permissionsManager: permissionsManager,
|
|
||||||
proxyGRPCServer: proxyGRPCServer,
|
|
||||||
clusterDeriver: clusterDeriver,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *managerImpl) GetAllServices(ctx context.Context, accountID, userID string) ([]*reverseproxy.Service, error) {
|
|
||||||
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Read)
|
|
||||||
if err != nil {
|
|
||||||
return nil, status.NewPermissionValidationError(err)
|
|
||||||
}
|
|
||||||
if !ok {
|
|
||||||
return nil, status.NewPermissionDeniedError()
|
|
||||||
}
|
|
||||||
|
|
||||||
services, err := m.store.GetAccountServices(ctx, store.LockingStrengthNone, accountID)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to get services: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, service := range services {
|
|
||||||
err = m.replaceHostByLookup(ctx, accountID, service)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return services, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *managerImpl) replaceHostByLookup(ctx context.Context, accountID string, service *reverseproxy.Service) error {
|
|
||||||
for _, target := range service.Targets {
|
|
||||||
switch target.TargetType {
|
|
||||||
case reverseproxy.TargetTypePeer:
|
|
||||||
peer, err := m.store.GetPeerByID(ctx, store.LockingStrengthNone, accountID, target.TargetId)
|
|
||||||
if err != nil {
|
|
||||||
log.WithContext(ctx).Warnf("failed to get peer by id %s for service %s: %v", target.TargetId, service.ID, err)
|
|
||||||
target.Host = unknownHostPlaceholder
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
target.Host = peer.IP.String()
|
|
||||||
case reverseproxy.TargetTypeHost:
|
|
||||||
resource, err := m.store.GetNetworkResourceByID(ctx, store.LockingStrengthNone, accountID, target.TargetId)
|
|
||||||
if err != nil {
|
|
||||||
log.WithContext(ctx).Warnf("failed to get resource by id %s for service %s: %v", target.TargetId, service.ID, err)
|
|
||||||
target.Host = unknownHostPlaceholder
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
target.Host = resource.Prefix.Addr().String()
|
|
||||||
case reverseproxy.TargetTypeDomain:
|
|
||||||
resource, err := m.store.GetNetworkResourceByID(ctx, store.LockingStrengthNone, accountID, target.TargetId)
|
|
||||||
if err != nil {
|
|
||||||
log.WithContext(ctx).Warnf("failed to get resource by id %s for service %s: %v", target.TargetId, service.ID, err)
|
|
||||||
target.Host = unknownHostPlaceholder
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
target.Host = resource.Domain
|
|
||||||
case reverseproxy.TargetTypeSubnet:
|
|
||||||
// For subnets we do not do any lookups on the resource
|
|
||||||
default:
|
|
||||||
return fmt.Errorf("unknown target type: %s", target.TargetType)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *managerImpl) GetService(ctx context.Context, accountID, userID, serviceID string) (*reverseproxy.Service, error) {
|
|
||||||
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Read)
|
|
||||||
if err != nil {
|
|
||||||
return nil, status.NewPermissionValidationError(err)
|
|
||||||
}
|
|
||||||
if !ok {
|
|
||||||
return nil, status.NewPermissionDeniedError()
|
|
||||||
}
|
|
||||||
|
|
||||||
service, err := m.store.GetServiceByID(ctx, store.LockingStrengthNone, accountID, serviceID)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to get service: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
err = m.replaceHostByLookup(ctx, accountID, service)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err)
|
|
||||||
}
|
|
||||||
return service, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *managerImpl) CreateService(ctx context.Context, accountID, userID string, service *reverseproxy.Service) (*reverseproxy.Service, error) {
|
|
||||||
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Create)
|
|
||||||
if err != nil {
|
|
||||||
return nil, status.NewPermissionValidationError(err)
|
|
||||||
}
|
|
||||||
if !ok {
|
|
||||||
return nil, status.NewPermissionDeniedError()
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := m.initializeServiceForCreate(ctx, accountID, service); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := m.persistNewService(ctx, accountID, service); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
m.accountManager.StoreEvent(ctx, userID, service.ID, accountID, activity.ServiceCreated, service.EventMeta())
|
|
||||||
|
|
||||||
err = m.replaceHostByLookup(ctx, accountID, service)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
m.proxyGRPCServer.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Create, "", m.proxyGRPCServer.GetOIDCValidationConfig()), service.ProxyCluster)
|
|
||||||
|
|
||||||
m.accountManager.UpdateAccountPeers(ctx, accountID)
|
|
||||||
|
|
||||||
return service, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *managerImpl) initializeServiceForCreate(ctx context.Context, accountID string, service *reverseproxy.Service) error {
|
|
||||||
if m.clusterDeriver != nil {
|
|
||||||
proxyCluster, err := m.clusterDeriver.DeriveClusterFromDomain(ctx, accountID, service.Domain)
|
|
||||||
if err != nil {
|
|
||||||
log.WithError(err).Warnf("could not derive cluster from domain %s, updates will broadcast to all proxy servers", service.Domain)
|
|
||||||
return status.Errorf(status.PreconditionFailed, "could not derive cluster from domain %s: %v", service.Domain, err)
|
|
||||||
}
|
|
||||||
service.ProxyCluster = proxyCluster
|
|
||||||
}
|
|
||||||
|
|
||||||
service.AccountID = accountID
|
|
||||||
service.InitNewRecord()
|
|
||||||
|
|
||||||
if err := service.Auth.HashSecrets(); err != nil {
|
|
||||||
return fmt.Errorf("hash secrets: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
keyPair, err := sessionkey.GenerateKeyPair()
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("generate session keys: %w", err)
|
|
||||||
}
|
|
||||||
service.SessionPrivateKey = keyPair.PrivateKey
|
|
||||||
service.SessionPublicKey = keyPair.PublicKey
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *managerImpl) persistNewService(ctx context.Context, accountID string, service *reverseproxy.Service) error {
|
|
||||||
return m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
|
||||||
if err := m.checkDomainAvailable(ctx, transaction, accountID, service.Domain, ""); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := validateTargetReferences(ctx, transaction, accountID, service.Targets); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := transaction.CreateService(ctx, service); err != nil {
|
|
||||||
return fmt.Errorf("failed to create service: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *managerImpl) checkDomainAvailable(ctx context.Context, transaction store.Store, accountID, domain, excludeServiceID string) error {
|
|
||||||
existingService, err := transaction.GetServiceByDomain(ctx, accountID, domain)
|
|
||||||
if err != nil {
|
|
||||||
if sErr, ok := status.FromError(err); !ok || sErr.Type() != status.NotFound {
|
|
||||||
return fmt.Errorf("failed to check existing service: %w", err)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if existingService != nil && existingService.ID != excludeServiceID {
|
|
||||||
return status.Errorf(status.AlreadyExists, "service with domain %s already exists", domain)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *managerImpl) UpdateService(ctx context.Context, accountID, userID string, service *reverseproxy.Service) (*reverseproxy.Service, error) {
|
|
||||||
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Update)
|
|
||||||
if err != nil {
|
|
||||||
return nil, status.NewPermissionValidationError(err)
|
|
||||||
}
|
|
||||||
if !ok {
|
|
||||||
return nil, status.NewPermissionDeniedError()
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := service.Auth.HashSecrets(); err != nil {
|
|
||||||
return nil, fmt.Errorf("hash secrets: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
updateInfo, err := m.persistServiceUpdate(ctx, accountID, service)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
m.accountManager.StoreEvent(ctx, userID, service.ID, accountID, activity.ServiceUpdated, service.EventMeta())
|
|
||||||
|
|
||||||
if err := m.replaceHostByLookup(ctx, accountID, service); err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
m.sendServiceUpdateNotifications(service, updateInfo)
|
|
||||||
m.accountManager.UpdateAccountPeers(ctx, accountID)
|
|
||||||
|
|
||||||
return service, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
type serviceUpdateInfo struct {
|
|
||||||
oldCluster string
|
|
||||||
domainChanged bool
|
|
||||||
serviceEnabledChanged bool
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *managerImpl) persistServiceUpdate(ctx context.Context, accountID string, service *reverseproxy.Service) (*serviceUpdateInfo, error) {
|
|
||||||
var updateInfo serviceUpdateInfo
|
|
||||||
|
|
||||||
err := m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
|
||||||
existingService, err := transaction.GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, service.ID)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
updateInfo.oldCluster = existingService.ProxyCluster
|
|
||||||
updateInfo.domainChanged = existingService.Domain != service.Domain
|
|
||||||
|
|
||||||
if updateInfo.domainChanged {
|
|
||||||
if err := m.handleDomainChange(ctx, transaction, accountID, service); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
service.ProxyCluster = existingService.ProxyCluster
|
|
||||||
}
|
|
||||||
|
|
||||||
m.preserveExistingAuthSecrets(service, existingService)
|
|
||||||
m.preserveServiceMetadata(service, existingService)
|
|
||||||
updateInfo.serviceEnabledChanged = existingService.Enabled != service.Enabled
|
|
||||||
|
|
||||||
if err := validateTargetReferences(ctx, transaction, accountID, service.Targets); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := transaction.UpdateService(ctx, service); err != nil {
|
|
||||||
return fmt.Errorf("update service: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
|
|
||||||
return &updateInfo, err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *managerImpl) handleDomainChange(ctx context.Context, transaction store.Store, accountID string, service *reverseproxy.Service) error {
|
|
||||||
if err := m.checkDomainAvailable(ctx, transaction, accountID, service.Domain, service.ID); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if m.clusterDeriver != nil {
|
|
||||||
newCluster, err := m.clusterDeriver.DeriveClusterFromDomain(ctx, accountID, service.Domain)
|
|
||||||
if err != nil {
|
|
||||||
log.WithError(err).Warnf("could not derive cluster from domain %s", service.Domain)
|
|
||||||
} else {
|
|
||||||
service.ProxyCluster = newCluster
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *managerImpl) preserveExistingAuthSecrets(service, existingService *reverseproxy.Service) {
|
|
||||||
if service.Auth.PasswordAuth != nil && service.Auth.PasswordAuth.Enabled &&
|
|
||||||
existingService.Auth.PasswordAuth != nil && existingService.Auth.PasswordAuth.Enabled &&
|
|
||||||
service.Auth.PasswordAuth.Password == "" {
|
|
||||||
service.Auth.PasswordAuth = existingService.Auth.PasswordAuth
|
|
||||||
}
|
|
||||||
|
|
||||||
if service.Auth.PinAuth != nil && service.Auth.PinAuth.Enabled &&
|
|
||||||
existingService.Auth.PinAuth != nil && existingService.Auth.PinAuth.Enabled &&
|
|
||||||
service.Auth.PinAuth.Pin == "" {
|
|
||||||
service.Auth.PinAuth = existingService.Auth.PinAuth
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *managerImpl) preserveServiceMetadata(service, existingService *reverseproxy.Service) {
|
|
||||||
service.Meta = existingService.Meta
|
|
||||||
service.SessionPrivateKey = existingService.SessionPrivateKey
|
|
||||||
service.SessionPublicKey = existingService.SessionPublicKey
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *managerImpl) sendServiceUpdateNotifications(service *reverseproxy.Service, updateInfo *serviceUpdateInfo) {
|
|
||||||
oidcCfg := m.proxyGRPCServer.GetOIDCValidationConfig()
|
|
||||||
|
|
||||||
switch {
|
|
||||||
case updateInfo.domainChanged && updateInfo.oldCluster != service.ProxyCluster:
|
|
||||||
m.proxyGRPCServer.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Delete, "", oidcCfg), updateInfo.oldCluster)
|
|
||||||
m.proxyGRPCServer.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Create, "", oidcCfg), service.ProxyCluster)
|
|
||||||
case !service.Enabled && updateInfo.serviceEnabledChanged:
|
|
||||||
m.proxyGRPCServer.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Delete, "", oidcCfg), service.ProxyCluster)
|
|
||||||
case service.Enabled && updateInfo.serviceEnabledChanged:
|
|
||||||
m.proxyGRPCServer.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Create, "", oidcCfg), service.ProxyCluster)
|
|
||||||
default:
|
|
||||||
m.proxyGRPCServer.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Update, "", oidcCfg), service.ProxyCluster)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// validateTargetReferences checks that all target IDs reference existing peers or resources in the account.
|
|
||||||
func validateTargetReferences(ctx context.Context, transaction store.Store, accountID string, targets []*reverseproxy.Target) error {
|
|
||||||
for _, target := range targets {
|
|
||||||
switch target.TargetType {
|
|
||||||
case reverseproxy.TargetTypePeer:
|
|
||||||
if _, err := transaction.GetPeerByID(ctx, store.LockingStrengthShare, accountID, target.TargetId); err != nil {
|
|
||||||
if sErr, ok := status.FromError(err); ok && sErr.Type() == status.NotFound {
|
|
||||||
return status.Errorf(status.InvalidArgument, "peer target %q not found in account", target.TargetId)
|
|
||||||
}
|
|
||||||
return fmt.Errorf("look up peer target %q: %w", target.TargetId, err)
|
|
||||||
}
|
|
||||||
case reverseproxy.TargetTypeHost, reverseproxy.TargetTypeSubnet, reverseproxy.TargetTypeDomain:
|
|
||||||
if _, err := transaction.GetNetworkResourceByID(ctx, store.LockingStrengthShare, accountID, target.TargetId); err != nil {
|
|
||||||
if sErr, ok := status.FromError(err); ok && sErr.Type() == status.NotFound {
|
|
||||||
return status.Errorf(status.InvalidArgument, "resource target %q not found in account", target.TargetId)
|
|
||||||
}
|
|
||||||
return fmt.Errorf("look up resource target %q: %w", target.TargetId, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *managerImpl) DeleteService(ctx context.Context, accountID, userID, serviceID string) error {
|
|
||||||
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Delete)
|
|
||||||
if err != nil {
|
|
||||||
return status.NewPermissionValidationError(err)
|
|
||||||
}
|
|
||||||
if !ok {
|
|
||||||
return status.NewPermissionDeniedError()
|
|
||||||
}
|
|
||||||
|
|
||||||
var service *reverseproxy.Service
|
|
||||||
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
|
||||||
var err error
|
|
||||||
service, err = transaction.GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, serviceID)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if err = transaction.DeleteService(ctx, accountID, serviceID); err != nil {
|
|
||||||
return fmt.Errorf("failed to delete service: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
m.accountManager.StoreEvent(ctx, userID, serviceID, accountID, activity.ServiceDeleted, service.EventMeta())
|
|
||||||
|
|
||||||
m.proxyGRPCServer.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Delete, "", m.proxyGRPCServer.GetOIDCValidationConfig()), service.ProxyCluster)
|
|
||||||
|
|
||||||
m.accountManager.UpdateAccountPeers(ctx, accountID)
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetCertificateIssuedAt sets the certificate issued timestamp to the current time.
|
|
||||||
// Call this when receiving a gRPC notification that the certificate was issued.
|
|
||||||
func (m *managerImpl) SetCertificateIssuedAt(ctx context.Context, accountID, serviceID string) error {
|
|
||||||
return m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
|
||||||
service, err := transaction.GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, serviceID)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to get service: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
service.Meta.CertificateIssuedAt = time.Now()
|
|
||||||
|
|
||||||
if err = transaction.UpdateService(ctx, service); err != nil {
|
|
||||||
return fmt.Errorf("failed to update service certificate timestamp: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetStatus updates the status of the service (e.g., "active", "tunnel_not_created", etc.)
|
|
||||||
func (m *managerImpl) SetStatus(ctx context.Context, accountID, serviceID string, status reverseproxy.ProxyStatus) error {
|
|
||||||
return m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
|
||||||
service, err := transaction.GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, serviceID)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to get service: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
service.Meta.Status = string(status)
|
|
||||||
|
|
||||||
if err = transaction.UpdateService(ctx, service); err != nil {
|
|
||||||
return fmt.Errorf("failed to update service status: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *managerImpl) ReloadService(ctx context.Context, accountID, serviceID string) error {
|
|
||||||
service, err := m.store.GetServiceByID(ctx, store.LockingStrengthNone, accountID, serviceID)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to get service: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
err = m.replaceHostByLookup(ctx, accountID, service)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
m.proxyGRPCServer.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Update, "", m.proxyGRPCServer.GetOIDCValidationConfig()), service.ProxyCluster)
|
|
||||||
|
|
||||||
m.accountManager.UpdateAccountPeers(ctx, accountID)
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *managerImpl) ReloadAllServicesForAccount(ctx context.Context, accountID string) error {
|
|
||||||
services, err := m.store.GetAccountServices(ctx, store.LockingStrengthNone, accountID)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to get services: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, service := range services {
|
|
||||||
err = m.replaceHostByLookup(ctx, accountID, service)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err)
|
|
||||||
}
|
|
||||||
m.proxyGRPCServer.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Update, "", m.proxyGRPCServer.GetOIDCValidationConfig()), service.ProxyCluster)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *managerImpl) GetGlobalServices(ctx context.Context) ([]*reverseproxy.Service, error) {
|
|
||||||
services, err := m.store.GetServices(ctx, store.LockingStrengthNone)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to get services: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, service := range services {
|
|
||||||
err = m.replaceHostByLookup(ctx, service.AccountID, service)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return services, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *managerImpl) GetServiceByID(ctx context.Context, accountID, serviceID string) (*reverseproxy.Service, error) {
|
|
||||||
service, err := m.store.GetServiceByID(ctx, store.LockingStrengthNone, accountID, serviceID)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to get service: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
err = m.replaceHostByLookup(ctx, accountID, service)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return service, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *managerImpl) GetAccountServices(ctx context.Context, accountID string) ([]*reverseproxy.Service, error) {
|
|
||||||
services, err := m.store.GetAccountServices(ctx, store.LockingStrengthNone, accountID)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to get services: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, service := range services {
|
|
||||||
err = m.replaceHostByLookup(ctx, accountID, service)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return services, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *managerImpl) GetServiceIDByTargetID(ctx context.Context, accountID string, resourceID string) (string, error) {
|
|
||||||
target, err := m.store.GetServiceTargetByTargetID(ctx, store.LockingStrengthNone, accountID, resourceID)
|
|
||||||
if err != nil {
|
|
||||||
if s, ok := status.FromError(err); ok && s.Type() == status.NotFound {
|
|
||||||
return "", nil
|
|
||||||
}
|
|
||||||
return "", fmt.Errorf("failed to get service target by resource ID: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if target == nil {
|
|
||||||
return "", nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return target.ServiceID, nil
|
|
||||||
}
|
|
||||||
@@ -1,375 +0,0 @@
|
|||||||
package manager
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"errors"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/golang/mock/gomock"
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
|
|
||||||
"github.com/netbirdio/netbird/management/server/store"
|
|
||||||
"github.com/netbirdio/netbird/shared/management/status"
|
|
||||||
)
|
|
||||||
|
|
||||||
// TestInitializeServiceForCreate verifies that initializeServiceForCreate
// populates a new service's account ID, ID, and session key pair, and that
// the proxy cluster stays empty when no cluster deriver is configured.
func TestInitializeServiceForCreate(t *testing.T) {
	ctx := context.Background()
	accountID := "test-account"

	t.Run("successful initialization without cluster deriver", func(t *testing.T) {
		mgr := &managerImpl{
			clusterDeriver: nil,
		}

		service := &reverseproxy.Service{
			Domain: "example.com",
			Auth:   reverseproxy.AuthConfig{},
		}

		err := mgr.initializeServiceForCreate(ctx, accountID, service)

		assert.NoError(t, err)
		assert.Equal(t, accountID, service.AccountID)
		assert.Empty(t, service.ProxyCluster, "proxy cluster should be empty when no deriver")
		assert.NotEmpty(t, service.ID, "service ID should be initialized")
		assert.NotEmpty(t, service.SessionPrivateKey, "session private key should be generated")
		assert.NotEmpty(t, service.SessionPublicKey, "session public key should be generated")
	})

	t.Run("verifies session keys are different", func(t *testing.T) {
		mgr := &managerImpl{
			clusterDeriver: nil,
		}

		service1 := &reverseproxy.Service{Domain: "test1.com", Auth: reverseproxy.AuthConfig{}}
		service2 := &reverseproxy.Service{Domain: "test2.com", Auth: reverseproxy.AuthConfig{}}

		err1 := mgr.initializeServiceForCreate(ctx, accountID, service1)
		err2 := mgr.initializeServiceForCreate(ctx, accountID, service2)

		// Two independently initialized services must not share key material.
		assert.NoError(t, err1)
		assert.NoError(t, err2)
		assert.NotEqual(t, service1.SessionPrivateKey, service2.SessionPrivateKey, "private keys should be unique")
		assert.NotEqual(t, service1.SessionPublicKey, service2.SessionPublicKey, "public keys should be unique")
	})
}
|
|
||||||
|
|
||||||
// TestCheckDomainAvailable is a table-driven test of checkDomainAvailable:
// a NotFound lookup means the domain is free; a hit is a conflict unless the
// found service's ID matches excludeServiceID (i.e. the caller is updating
// that same service); any other store error is propagated.
func TestCheckDomainAvailable(t *testing.T) {
	ctx := context.Background()
	accountID := "test-account"

	tests := []struct {
		name             string
		domain           string
		excludeServiceID string
		setupMock        func(*store.MockStore)
		expectedError    bool
		// errorType is checked only when non-zero.
		errorType status.Type
	}{
		{
			name:             "domain available - not found",
			domain:           "available.com",
			excludeServiceID: "",
			setupMock: func(ms *store.MockStore) {
				ms.EXPECT().
					GetServiceByDomain(ctx, accountID, "available.com").
					Return(nil, status.Errorf(status.NotFound, "not found"))
			},
			expectedError: false,
		},
		{
			name:             "domain already exists",
			domain:           "exists.com",
			excludeServiceID: "",
			setupMock: func(ms *store.MockStore) {
				ms.EXPECT().
					GetServiceByDomain(ctx, accountID, "exists.com").
					Return(&reverseproxy.Service{ID: "existing-id", Domain: "exists.com"}, nil)
			},
			expectedError: true,
			errorType:     status.AlreadyExists,
		},
		{
			name:             "domain exists but excluded (same ID)",
			domain:           "exists.com",
			excludeServiceID: "service-123",
			setupMock: func(ms *store.MockStore) {
				ms.EXPECT().
					GetServiceByDomain(ctx, accountID, "exists.com").
					Return(&reverseproxy.Service{ID: "service-123", Domain: "exists.com"}, nil)
			},
			expectedError: false,
		},
		{
			name:             "domain exists with different ID",
			domain:           "exists.com",
			excludeServiceID: "service-456",
			setupMock: func(ms *store.MockStore) {
				ms.EXPECT().
					GetServiceByDomain(ctx, accountID, "exists.com").
					Return(&reverseproxy.Service{ID: "service-123", Domain: "exists.com"}, nil)
			},
			expectedError: true,
			errorType:     status.AlreadyExists,
		},
		{
			name:             "store error (non-NotFound)",
			domain:           "error.com",
			excludeServiceID: "",
			setupMock: func(ms *store.MockStore) {
				ms.EXPECT().
					GetServiceByDomain(ctx, accountID, "error.com").
					Return(nil, errors.New("database error"))
			},
			expectedError: true,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			ctrl := gomock.NewController(t)
			defer ctrl.Finish()

			mockStore := store.NewMockStore(ctrl)
			tt.setupMock(mockStore)

			// The store is passed explicitly, so an empty managerImpl suffices.
			mgr := &managerImpl{}
			err := mgr.checkDomainAvailable(ctx, mockStore, accountID, tt.domain, tt.excludeServiceID)

			if tt.expectedError {
				require.Error(t, err)
				if tt.errorType != 0 {
					sErr, ok := status.FromError(err)
					require.True(t, ok, "error should be a status error")
					assert.Equal(t, tt.errorType, sErr.Type())
				}
			} else {
				assert.NoError(t, err)
			}
		})
	}
}
|
|
||||||
|
|
||||||
// TestCheckDomainAvailable_EdgeCases covers corner inputs of
// checkDomainAvailable: an empty domain, an empty exclude ID against an
// existing service, and a (nil, nil) store response.
func TestCheckDomainAvailable_EdgeCases(t *testing.T) {
	ctx := context.Background()
	accountID := "test-account"

	t.Run("empty domain", func(t *testing.T) {
		ctrl := gomock.NewController(t)
		defer ctrl.Finish()

		// An empty domain is still looked up; NotFound means it is available.
		mockStore := store.NewMockStore(ctrl)
		mockStore.EXPECT().
			GetServiceByDomain(ctx, accountID, "").
			Return(nil, status.Errorf(status.NotFound, "not found"))

		mgr := &managerImpl{}
		err := mgr.checkDomainAvailable(ctx, mockStore, accountID, "", "")

		assert.NoError(t, err)
	})

	t.Run("empty exclude ID with existing service", func(t *testing.T) {
		ctrl := gomock.NewController(t)
		defer ctrl.Finish()

		mockStore := store.NewMockStore(ctrl)
		mockStore.EXPECT().
			GetServiceByDomain(ctx, accountID, "test.com").
			Return(&reverseproxy.Service{ID: "some-id", Domain: "test.com"}, nil)

		mgr := &managerImpl{}
		err := mgr.checkDomainAvailable(ctx, mockStore, accountID, "test.com", "")

		// With no exclusion, any existing service is a conflict.
		assert.Error(t, err)
		sErr, ok := status.FromError(err)
		require.True(t, ok)
		assert.Equal(t, status.AlreadyExists, sErr.Type())
	})

	t.Run("nil existing service with nil error", func(t *testing.T) {
		ctrl := gomock.NewController(t)
		defer ctrl.Finish()

		// (nil, nil) from the store must be treated as "domain available".
		mockStore := store.NewMockStore(ctrl)
		mockStore.EXPECT().
			GetServiceByDomain(ctx, accountID, "nil.com").
			Return(nil, nil)

		mgr := &managerImpl{}
		err := mgr.checkDomainAvailable(ctx, mockStore, accountID, "nil.com", "")

		assert.NoError(t, err)
	})
}
|
|
||||||
|
|
||||||
// TestPersistNewService exercises persistNewService through a mocked
// transaction: ExecuteInTransaction is stubbed to run the callback against a
// second mock store that plays the role of the transaction handle.
func TestPersistNewService(t *testing.T) {
	ctx := context.Background()
	accountID := "test-account"

	t.Run("successful service creation with no targets", func(t *testing.T) {
		ctrl := gomock.NewController(t)
		defer ctrl.Finish()

		mockStore := store.NewMockStore(ctrl)
		service := &reverseproxy.Service{
			ID:      "service-123",
			Domain:  "new.com",
			Targets: []*reverseproxy.Target{},
		}

		// Mock ExecuteInTransaction to execute the function immediately
		mockStore.EXPECT().
			ExecuteInTransaction(ctx, gomock.Any()).
			DoAndReturn(func(ctx context.Context, fn func(store.Store) error) error {
				// Create another mock for the transaction
				txMock := store.NewMockStore(ctrl)
				// Inside the transaction: domain check passes (NotFound),
				// then the service is created.
				txMock.EXPECT().
					GetServiceByDomain(ctx, accountID, "new.com").
					Return(nil, status.Errorf(status.NotFound, "not found"))
				txMock.EXPECT().
					CreateService(ctx, service).
					Return(nil)

				return fn(txMock)
			})

		mgr := &managerImpl{store: mockStore}
		err := mgr.persistNewService(ctx, accountID, service)

		assert.NoError(t, err)
	})

	t.Run("domain already exists", func(t *testing.T) {
		ctrl := gomock.NewController(t)
		defer ctrl.Finish()

		mockStore := store.NewMockStore(ctrl)
		service := &reverseproxy.Service{
			ID:      "service-123",
			Domain:  "existing.com",
			Targets: []*reverseproxy.Target{},
		}

		mockStore.EXPECT().
			ExecuteInTransaction(ctx, gomock.Any()).
			DoAndReturn(func(ctx context.Context, fn func(store.Store) error) error {
				txMock := store.NewMockStore(ctrl)
				// The domain belongs to another service, so CreateService
				// must never be called.
				txMock.EXPECT().
					GetServiceByDomain(ctx, accountID, "existing.com").
					Return(&reverseproxy.Service{ID: "other-id", Domain: "existing.com"}, nil)

				return fn(txMock)
			})

		mgr := &managerImpl{store: mockStore}
		err := mgr.persistNewService(ctx, accountID, service)

		require.Error(t, err)
		sErr, ok := status.FromError(err)
		require.True(t, ok)
		assert.Equal(t, status.AlreadyExists, sErr.Type())
	})
}
|
|
||||||
// TestPreserveExistingAuthSecrets verifies that preserveExistingAuthSecrets
// copies the stored password/PIN into an update whose secret field is empty
// (the API omits secrets on read-modify-write), and leaves explicitly
// provided secrets untouched.
func TestPreserveExistingAuthSecrets(t *testing.T) {
	mgr := &managerImpl{}

	t.Run("preserve password when empty", func(t *testing.T) {
		existing := &reverseproxy.Service{
			Auth: reverseproxy.AuthConfig{
				PasswordAuth: &reverseproxy.PasswordAuthConfig{
					Enabled:  true,
					Password: "hashed-password",
				},
			},
		}

		// Empty password on the update means "keep the stored one".
		updated := &reverseproxy.Service{
			Auth: reverseproxy.AuthConfig{
				PasswordAuth: &reverseproxy.PasswordAuthConfig{
					Enabled:  true,
					Password: "",
				},
			},
		}

		mgr.preserveExistingAuthSecrets(updated, existing)

		assert.Equal(t, existing.Auth.PasswordAuth, updated.Auth.PasswordAuth)
	})

	t.Run("preserve pin when empty", func(t *testing.T) {
		existing := &reverseproxy.Service{
			Auth: reverseproxy.AuthConfig{
				PinAuth: &reverseproxy.PINAuthConfig{
					Enabled: true,
					Pin:     "hashed-pin",
				},
			},
		}

		updated := &reverseproxy.Service{
			Auth: reverseproxy.AuthConfig{
				PinAuth: &reverseproxy.PINAuthConfig{
					Enabled: true,
					Pin:     "",
				},
			},
		}

		mgr.preserveExistingAuthSecrets(updated, existing)

		assert.Equal(t, existing.Auth.PinAuth, updated.Auth.PinAuth)
	})

	t.Run("do not preserve when password is provided", func(t *testing.T) {
		existing := &reverseproxy.Service{
			Auth: reverseproxy.AuthConfig{
				PasswordAuth: &reverseproxy.PasswordAuthConfig{
					Enabled:  true,
					Password: "old-password",
				},
			},
		}

		// A non-empty password on the update must win over the stored one.
		updated := &reverseproxy.Service{
			Auth: reverseproxy.AuthConfig{
				PasswordAuth: &reverseproxy.PasswordAuthConfig{
					Enabled:  true,
					Password: "new-password",
				},
			},
		}

		mgr.preserveExistingAuthSecrets(updated, existing)

		assert.Equal(t, "new-password", updated.Auth.PasswordAuth.Password)
		assert.NotEqual(t, existing.Auth.PasswordAuth, updated.Auth.PasswordAuth)
	})
}
|
|
||||||
|
|
||||||
// TestPreserveServiceMetadata verifies that preserveServiceMetadata carries
// over server-managed fields (Meta and the session key pair) from the stored
// service onto an update payload that does not include them.
func TestPreserveServiceMetadata(t *testing.T) {
	mgr := &managerImpl{}

	existing := &reverseproxy.Service{
		Meta: reverseproxy.ServiceMeta{
			CertificateIssuedAt: time.Now(),
			Status:              "active",
		},
		SessionPrivateKey: "private-key",
		SessionPublicKey:  "public-key",
	}

	// The update only changes the domain; metadata and keys are absent.
	updated := &reverseproxy.Service{
		Domain: "updated.com",
	}

	mgr.preserveServiceMetadata(updated, existing)

	assert.Equal(t, existing.Meta, updated.Meta)
	assert.Equal(t, existing.SessionPrivateKey, updated.SessionPrivateKey)
	assert.Equal(t, existing.SessionPublicKey, updated.SessionPublicKey)
}
|
|
||||||
36
management/internals/modules/reverseproxy/proxy/manager.go
Normal file
36
management/internals/modules/reverseproxy/proxy/manager.go
Normal file
@@ -0,0 +1,36 @@
|
|||||||
|
// Package proxy declares the interfaces and shared configuration types used
// by the reverse-proxy management layer.
package proxy

//go:generate go run github.com/golang/mock/mockgen -package proxy -destination=manager_mock.go -source=./manager.go -build_flags=-mod=mod

import (
	"context"
	"time"

	"github.com/netbirdio/netbird/shared/management/proto"
)

// Manager defines the interface for proxy operations
type Manager interface {
	// Connect registers a proxy connection, recording its cluster address
	// and IP address.
	Connect(ctx context.Context, proxyID, clusterAddress, ipAddress string) error
	// Disconnect marks the proxy as disconnected.
	Disconnect(ctx context.Context, proxyID string) error
	// Heartbeat refreshes the proxy's last-seen timestamp.
	Heartbeat(ctx context.Context, proxyID string) error
	// GetActiveClusterAddresses returns the unique cluster addresses of
	// currently active proxies.
	GetActiveClusterAddresses(ctx context.Context) ([]string, error)
	// CleanupStale removes proxies that have not sent a heartbeat within
	// inactivityDuration.
	CleanupStale(ctx context.Context, inactivityDuration time.Duration) error
}

// OIDCValidationConfig contains the OIDC configuration needed for token validation.
type OIDCValidationConfig struct {
	// Issuer is the expected token issuer.
	Issuer string
	// Audiences lists the accepted token audiences.
	Audiences []string
	// KeysLocation points at the signing-key source (presumably a JWKS
	// endpoint — confirm with the validator implementation).
	KeysLocation string
	// MaxTokenAgeSeconds bounds the accepted token age, in seconds.
	MaxTokenAgeSeconds int64
}

// Controller is responsible for managing proxy clusters and routing service updates.
type Controller interface {
	// SendServiceUpdateToCluster delivers a service mapping update to the
	// cluster at clusterAddr.
	SendServiceUpdateToCluster(ctx context.Context, accountID string, update *proto.ProxyMapping, clusterAddr string)
	// GetOIDCValidationConfig returns the OIDC configuration used for
	// validating tokens.
	GetOIDCValidationConfig() OIDCValidationConfig
	// RegisterProxyToCluster associates a proxy with a cluster for routing.
	RegisterProxyToCluster(ctx context.Context, clusterAddr, proxyID string) error
	// UnregisterProxyFromCluster removes a proxy's association with a cluster.
	UnregisterProxyFromCluster(ctx context.Context, clusterAddr, proxyID string) error
	// GetProxiesForCluster returns the IDs of all proxies registered to
	// clusterAddr.
	GetProxiesForCluster(clusterAddr string) []string
}
|
||||||
@@ -0,0 +1,88 @@
|
|||||||
|
package manager
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"go.opentelemetry.io/otel/metric"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
|
||||||
|
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/proto"
|
||||||
|
)
|
||||||
|
|
||||||
|
// GRPCController is a concrete implementation that manages proxy clusters and sends updates directly via gRPC.
type GRPCController struct {
	// proxyGRPCServer is the transport used to push updates to proxies.
	proxyGRPCServer *nbgrpc.ProxyServiceServer
	// Map of cluster address -> set of proxy IDs. Values are *sync.Map used
	// as sets (proxy ID -> struct{}).
	clusterProxies sync.Map
	// metrics records connection and update counters.
	metrics *metrics
}
|
||||||
|
|
||||||
|
// NewGRPCController creates a new GRPCController.
|
||||||
|
func NewGRPCController(proxyGRPCServer *nbgrpc.ProxyServiceServer, meter metric.Meter) (*GRPCController, error) {
|
||||||
|
m, err := newMetrics(meter)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return &GRPCController{
|
||||||
|
proxyGRPCServer: proxyGRPCServer,
|
||||||
|
metrics: m,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SendServiceUpdateToCluster sends a service update to a specific proxy cluster.
//
// NOTE(review): accountID is not used by this implementation; it exists to
// satisfy the Controller interface signature — confirm whether other
// implementations need it or it can be dropped from the interface.
func (c *GRPCController) SendServiceUpdateToCluster(ctx context.Context, accountID string, update *proto.ProxyMapping, clusterAddr string) {
	c.proxyGRPCServer.SendServiceUpdateToCluster(ctx, update, clusterAddr)
	// Count every pushed update, labeled by destination cluster.
	c.metrics.IncrementServiceUpdateSendCount(clusterAddr)
}
|
||||||
|
|
||||||
|
// GetOIDCValidationConfig returns the OIDC validation configuration from the gRPC server.
// Thin passthrough: the underlying gRPC server owns the configuration.
func (c *GRPCController) GetOIDCValidationConfig() proxy.OIDCValidationConfig {
	return c.proxyGRPCServer.GetOIDCValidationConfig()
}
|
||||||
|
|
||||||
|
// RegisterProxyToCluster registers a proxy to a specific cluster for routing.
|
||||||
|
func (c *GRPCController) RegisterProxyToCluster(ctx context.Context, clusterAddr, proxyID string) error {
|
||||||
|
if clusterAddr == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
proxySet, _ := c.clusterProxies.LoadOrStore(clusterAddr, &sync.Map{})
|
||||||
|
proxySet.(*sync.Map).Store(proxyID, struct{}{})
|
||||||
|
log.WithContext(ctx).Debugf("Registered proxy %s to cluster %s", proxyID, clusterAddr)
|
||||||
|
|
||||||
|
c.metrics.IncrementProxyConnectionCount(clusterAddr)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// UnregisterProxyFromCluster removes a proxy from a cluster.
|
||||||
|
func (c *GRPCController) UnregisterProxyFromCluster(ctx context.Context, clusterAddr, proxyID string) error {
|
||||||
|
if clusterAddr == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if proxySet, ok := c.clusterProxies.Load(clusterAddr); ok {
|
||||||
|
proxySet.(*sync.Map).Delete(proxyID)
|
||||||
|
log.WithContext(ctx).Debugf("Unregistered proxy %s from cluster %s", proxyID, clusterAddr)
|
||||||
|
|
||||||
|
c.metrics.DecrementProxyConnectionCount(clusterAddr)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetProxiesForCluster returns all proxy IDs registered for a specific cluster.
|
||||||
|
func (c *GRPCController) GetProxiesForCluster(clusterAddr string) []string {
|
||||||
|
proxySet, ok := c.clusterProxies.Load(clusterAddr)
|
||||||
|
if !ok {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var proxies []string
|
||||||
|
proxySet.(*sync.Map).Range(func(key, _ interface{}) bool {
|
||||||
|
proxies = append(proxies, key.(string))
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
return proxies
|
||||||
|
}
|
||||||
@@ -0,0 +1,115 @@
|
|||||||
|
package manager
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"go.opentelemetry.io/otel/metric"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
|
||||||
|
)
|
||||||
|
|
||||||
|
// store defines the interface for proxy persistence operations
type store interface {
	// SaveProxy persists the given proxy record (presumably an upsert —
	// Disconnect writes a partial record; confirm the store's semantics).
	SaveProxy(ctx context.Context, p *proxy.Proxy) error
	// UpdateProxyHeartbeat refreshes the proxy's last-seen timestamp.
	UpdateProxyHeartbeat(ctx context.Context, proxyID string) error
	// GetActiveProxyClusterAddresses returns cluster addresses of active proxies.
	GetActiveProxyClusterAddresses(ctx context.Context) ([]string, error)
	// CleanupStaleProxies removes proxies inactive for longer than inactivityDuration.
	CleanupStaleProxies(ctx context.Context, inactivityDuration time.Duration) error
}
|
||||||
|
|
||||||
|
// Manager handles all proxy operations
type Manager struct {
	// store persists proxy connection state.
	store store
	// metrics records proxy-related telemetry (heartbeat counts, etc.).
	metrics *metrics
}
|
||||||
|
|
||||||
|
// NewManager creates a new proxy Manager
|
||||||
|
func NewManager(store store, meter metric.Meter) (*Manager, error) {
|
||||||
|
m, err := newMetrics(meter)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return &Manager{
|
||||||
|
store: store,
|
||||||
|
metrics: m,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Connect registers a new proxy connection in the database
|
||||||
|
func (m Manager) Connect(ctx context.Context, proxyID, clusterAddress, ipAddress string) error {
|
||||||
|
now := time.Now()
|
||||||
|
p := &proxy.Proxy{
|
||||||
|
ID: proxyID,
|
||||||
|
ClusterAddress: clusterAddress,
|
||||||
|
IPAddress: ipAddress,
|
||||||
|
LastSeen: now,
|
||||||
|
ConnectedAt: &now,
|
||||||
|
Status: "connected",
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := m.store.SaveProxy(ctx, p); err != nil {
|
||||||
|
log.WithContext(ctx).Errorf("failed to register proxy %s: %v", proxyID, err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
log.WithContext(ctx).WithFields(log.Fields{
|
||||||
|
"proxyID": proxyID,
|
||||||
|
"clusterAddress": clusterAddress,
|
||||||
|
"ipAddress": ipAddress,
|
||||||
|
}).Info("proxy connected")
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Disconnect marks a proxy as disconnected in the database
|
||||||
|
func (m Manager) Disconnect(ctx context.Context, proxyID string) error {
|
||||||
|
now := time.Now()
|
||||||
|
p := &proxy.Proxy{
|
||||||
|
ID: proxyID,
|
||||||
|
Status: "disconnected",
|
||||||
|
DisconnectedAt: &now,
|
||||||
|
LastSeen: now,
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := m.store.SaveProxy(ctx, p); err != nil {
|
||||||
|
log.WithContext(ctx).Errorf("failed to disconnect proxy %s: %v", proxyID, err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
log.WithContext(ctx).WithFields(log.Fields{
|
||||||
|
"proxyID": proxyID,
|
||||||
|
}).Info("proxy disconnected")
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Heartbeat updates the proxy's last seen timestamp
|
||||||
|
func (m Manager) Heartbeat(ctx context.Context, proxyID string) error {
|
||||||
|
if err := m.store.UpdateProxyHeartbeat(ctx, proxyID); err != nil {
|
||||||
|
log.WithContext(ctx).Debugf("failed to update proxy %s heartbeat: %v", proxyID, err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
m.metrics.IncrementProxyHeartbeatCount()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetActiveClusterAddresses returns all unique cluster addresses for active proxies
|
||||||
|
func (m Manager) GetActiveClusterAddresses(ctx context.Context) ([]string, error) {
|
||||||
|
addresses, err := m.store.GetActiveProxyClusterAddresses(ctx)
|
||||||
|
if err != nil {
|
||||||
|
log.WithContext(ctx).Errorf("failed to get active proxy cluster addresses: %v", err)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return addresses, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// CleanupStale removes proxies that haven't sent heartbeat in the specified duration
|
||||||
|
func (m Manager) CleanupStale(ctx context.Context, inactivityDuration time.Duration) error {
|
||||||
|
if err := m.store.CleanupStaleProxies(ctx, inactivityDuration); err != nil {
|
||||||
|
log.WithContext(ctx).Errorf("failed to cleanup stale proxies: %v", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
@@ -0,0 +1,74 @@
|
|||||||
|
package manager
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
|
||||||
|
"go.opentelemetry.io/otel/attribute"
|
||||||
|
"go.opentelemetry.io/otel/metric"
|
||||||
|
)
|
||||||
|
|
||||||
|
// metrics bundles the OpenTelemetry instruments used for proxy telemetry.
type metrics struct {
	// proxyConnectionCount tracks the number of active proxy connections,
	// labeled per cluster.
	proxyConnectionCount metric.Int64UpDownCounter
	// serviceUpdateSendCount counts service updates sent to proxies.
	serviceUpdateSendCount metric.Int64Counter
	// proxyHeartbeatCount counts heartbeats received from proxies.
	proxyHeartbeatCount metric.Int64Counter
}
|
||||||
|
|
||||||
|
func newMetrics(meter metric.Meter) (*metrics, error) {
|
||||||
|
proxyConnectionCount, err := meter.Int64UpDownCounter(
|
||||||
|
"management_proxy_connection_count",
|
||||||
|
metric.WithDescription("Number of active proxy connections"),
|
||||||
|
metric.WithUnit("{connection}"),
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
serviceUpdateSendCount, err := meter.Int64Counter(
|
||||||
|
"management_proxy_service_update_send_count",
|
||||||
|
metric.WithDescription("Total number of service updates sent to proxies"),
|
||||||
|
metric.WithUnit("{update}"),
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
proxyHeartbeatCount, err := meter.Int64Counter(
|
||||||
|
"management_proxy_heartbeat_count",
|
||||||
|
metric.WithDescription("Total number of proxy heartbeats received"),
|
||||||
|
metric.WithUnit("{heartbeat}"),
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return &metrics{
|
||||||
|
proxyConnectionCount: proxyConnectionCount,
|
||||||
|
serviceUpdateSendCount: serviceUpdateSendCount,
|
||||||
|
proxyHeartbeatCount: proxyHeartbeatCount,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *metrics) IncrementProxyConnectionCount(clusterAddr string) {
|
||||||
|
m.proxyConnectionCount.Add(context.Background(), 1,
|
||||||
|
metric.WithAttributes(
|
||||||
|
attribute.String("cluster", clusterAddr),
|
||||||
|
))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *metrics) DecrementProxyConnectionCount(clusterAddr string) {
|
||||||
|
m.proxyConnectionCount.Add(context.Background(), -1,
|
||||||
|
metric.WithAttributes(
|
||||||
|
attribute.String("cluster", clusterAddr),
|
||||||
|
))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *metrics) IncrementServiceUpdateSendCount(clusterAddr string) {
|
||||||
|
m.serviceUpdateSendCount.Add(context.Background(), 1,
|
||||||
|
metric.WithAttributes(
|
||||||
|
attribute.String("cluster", clusterAddr),
|
||||||
|
))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *metrics) IncrementProxyHeartbeatCount() {
|
||||||
|
m.proxyHeartbeatCount.Add(context.Background(), 1)
|
||||||
|
}
|
||||||
199
management/internals/modules/reverseproxy/proxy/manager_mock.go
Normal file
199
management/internals/modules/reverseproxy/proxy/manager_mock.go
Normal file
@@ -0,0 +1,199 @@
|
|||||||
|
// Code generated by MockGen. DO NOT EDIT.
|
||||||
|
// Source: ./manager.go
|
||||||
|
|
||||||
|
// Package proxy is a generated GoMock package.
|
||||||
|
package proxy
|
||||||
|
|
||||||
|
import (
|
||||||
|
context "context"
|
||||||
|
reflect "reflect"
|
||||||
|
time "time"
|
||||||
|
|
||||||
|
gomock "github.com/golang/mock/gomock"
|
||||||
|
proto "github.com/netbirdio/netbird/shared/management/proto"
|
||||||
|
)
|
||||||
|
|
||||||
|
// MockManager is a mock of Manager interface.
|
||||||
|
type MockManager struct {
|
||||||
|
ctrl *gomock.Controller
|
||||||
|
recorder *MockManagerMockRecorder
|
||||||
|
}
|
||||||
|
|
||||||
|
// MockManagerMockRecorder is the mock recorder for MockManager.
|
||||||
|
type MockManagerMockRecorder struct {
|
||||||
|
mock *MockManager
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewMockManager creates a new mock instance.
|
||||||
|
func NewMockManager(ctrl *gomock.Controller) *MockManager {
|
||||||
|
mock := &MockManager{ctrl: ctrl}
|
||||||
|
mock.recorder = &MockManagerMockRecorder{mock}
|
||||||
|
return mock
|
||||||
|
}
|
||||||
|
|
||||||
|
// EXPECT returns an object that allows the caller to indicate expected use.
|
||||||
|
func (m *MockManager) EXPECT() *MockManagerMockRecorder {
|
||||||
|
return m.recorder
|
||||||
|
}
|
||||||
|
|
||||||
|
// CleanupStale mocks base method.
|
||||||
|
func (m *MockManager) CleanupStale(ctx context.Context, inactivityDuration time.Duration) error {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
ret := m.ctrl.Call(m, "CleanupStale", ctx, inactivityDuration)
|
||||||
|
ret0, _ := ret[0].(error)
|
||||||
|
return ret0
|
||||||
|
}
|
||||||
|
|
||||||
|
// CleanupStale indicates an expected call of CleanupStale.
|
||||||
|
func (mr *MockManagerMockRecorder) CleanupStale(ctx, inactivityDuration interface{}) *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CleanupStale", reflect.TypeOf((*MockManager)(nil).CleanupStale), ctx, inactivityDuration)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Connect mocks base method.
|
||||||
|
func (m *MockManager) Connect(ctx context.Context, proxyID, clusterAddress, ipAddress string) error {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
ret := m.ctrl.Call(m, "Connect", ctx, proxyID, clusterAddress, ipAddress)
|
||||||
|
ret0, _ := ret[0].(error)
|
||||||
|
return ret0
|
||||||
|
}
|
||||||
|
|
||||||
|
// Connect indicates an expected call of Connect.
|
||||||
|
func (mr *MockManagerMockRecorder) Connect(ctx, proxyID, clusterAddress, ipAddress interface{}) *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Connect", reflect.TypeOf((*MockManager)(nil).Connect), ctx, proxyID, clusterAddress, ipAddress)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Disconnect mocks base method.
|
||||||
|
func (m *MockManager) Disconnect(ctx context.Context, proxyID string) error {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
ret := m.ctrl.Call(m, "Disconnect", ctx, proxyID)
|
||||||
|
ret0, _ := ret[0].(error)
|
||||||
|
return ret0
|
||||||
|
}
|
||||||
|
|
||||||
|
// Disconnect indicates an expected call of Disconnect.
|
||||||
|
func (mr *MockManagerMockRecorder) Disconnect(ctx, proxyID interface{}) *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Disconnect", reflect.TypeOf((*MockManager)(nil).Disconnect), ctx, proxyID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetActiveClusterAddresses mocks base method.
|
||||||
|
func (m *MockManager) GetActiveClusterAddresses(ctx context.Context) ([]string, error) {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
ret := m.ctrl.Call(m, "GetActiveClusterAddresses", ctx)
|
||||||
|
ret0, _ := ret[0].([]string)
|
||||||
|
ret1, _ := ret[1].(error)
|
||||||
|
return ret0, ret1
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetActiveClusterAddresses indicates an expected call of GetActiveClusterAddresses.
|
||||||
|
func (mr *MockManagerMockRecorder) GetActiveClusterAddresses(ctx interface{}) *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetActiveClusterAddresses", reflect.TypeOf((*MockManager)(nil).GetActiveClusterAddresses), ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Heartbeat mocks base method.
|
||||||
|
func (m *MockManager) Heartbeat(ctx context.Context, proxyID string) error {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
ret := m.ctrl.Call(m, "Heartbeat", ctx, proxyID)
|
||||||
|
ret0, _ := ret[0].(error)
|
||||||
|
return ret0
|
||||||
|
}
|
||||||
|
|
||||||
|
// Heartbeat indicates an expected call of Heartbeat.
|
||||||
|
func (mr *MockManagerMockRecorder) Heartbeat(ctx, proxyID interface{}) *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Heartbeat", reflect.TypeOf((*MockManager)(nil).Heartbeat), ctx, proxyID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// MockController is a mock of Controller interface.
|
||||||
|
type MockController struct {
|
||||||
|
ctrl *gomock.Controller
|
||||||
|
recorder *MockControllerMockRecorder
|
||||||
|
}
|
||||||
|
|
||||||
|
// MockControllerMockRecorder is the mock recorder for MockController.
|
||||||
|
type MockControllerMockRecorder struct {
|
||||||
|
mock *MockController
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewMockController creates a new mock instance.
|
||||||
|
func NewMockController(ctrl *gomock.Controller) *MockController {
|
||||||
|
mock := &MockController{ctrl: ctrl}
|
||||||
|
mock.recorder = &MockControllerMockRecorder{mock}
|
||||||
|
return mock
|
||||||
|
}
|
||||||
|
|
||||||
|
// EXPECT returns an object that allows the caller to indicate expected use.
|
||||||
|
func (m *MockController) EXPECT() *MockControllerMockRecorder {
|
||||||
|
return m.recorder
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetOIDCValidationConfig mocks base method.
|
||||||
|
func (m *MockController) GetOIDCValidationConfig() OIDCValidationConfig {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
ret := m.ctrl.Call(m, "GetOIDCValidationConfig")
|
||||||
|
ret0, _ := ret[0].(OIDCValidationConfig)
|
||||||
|
return ret0
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetOIDCValidationConfig indicates an expected call of GetOIDCValidationConfig.
|
||||||
|
func (mr *MockControllerMockRecorder) GetOIDCValidationConfig() *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetOIDCValidationConfig", reflect.TypeOf((*MockController)(nil).GetOIDCValidationConfig))
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetProxiesForCluster mocks base method.
|
||||||
|
func (m *MockController) GetProxiesForCluster(clusterAddr string) []string {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
ret := m.ctrl.Call(m, "GetProxiesForCluster", clusterAddr)
|
||||||
|
ret0, _ := ret[0].([]string)
|
||||||
|
return ret0
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetProxiesForCluster indicates an expected call of GetProxiesForCluster.
|
||||||
|
func (mr *MockControllerMockRecorder) GetProxiesForCluster(clusterAddr interface{}) *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetProxiesForCluster", reflect.TypeOf((*MockController)(nil).GetProxiesForCluster), clusterAddr)
|
||||||
|
}
|
||||||
|
|
||||||
|
// RegisterProxyToCluster mocks base method.
|
||||||
|
func (m *MockController) RegisterProxyToCluster(ctx context.Context, clusterAddr, proxyID string) error {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
ret := m.ctrl.Call(m, "RegisterProxyToCluster", ctx, clusterAddr, proxyID)
|
||||||
|
ret0, _ := ret[0].(error)
|
||||||
|
return ret0
|
||||||
|
}
|
||||||
|
|
||||||
|
// RegisterProxyToCluster indicates an expected call of RegisterProxyToCluster.
|
||||||
|
func (mr *MockControllerMockRecorder) RegisterProxyToCluster(ctx, clusterAddr, proxyID interface{}) *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RegisterProxyToCluster", reflect.TypeOf((*MockController)(nil).RegisterProxyToCluster), ctx, clusterAddr, proxyID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SendServiceUpdateToCluster mocks base method.
|
||||||
|
func (m *MockController) SendServiceUpdateToCluster(ctx context.Context, accountID string, update *proto.ProxyMapping, clusterAddr string) {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
m.ctrl.Call(m, "SendServiceUpdateToCluster", ctx, accountID, update, clusterAddr)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SendServiceUpdateToCluster indicates an expected call of SendServiceUpdateToCluster.
|
||||||
|
func (mr *MockControllerMockRecorder) SendServiceUpdateToCluster(ctx, accountID, update, clusterAddr interface{}) *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SendServiceUpdateToCluster", reflect.TypeOf((*MockController)(nil).SendServiceUpdateToCluster), ctx, accountID, update, clusterAddr)
|
||||||
|
}
|
||||||
|
|
||||||
|
// UnregisterProxyFromCluster mocks base method.
|
||||||
|
func (m *MockController) UnregisterProxyFromCluster(ctx context.Context, clusterAddr, proxyID string) error {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
ret := m.ctrl.Call(m, "UnregisterProxyFromCluster", ctx, clusterAddr, proxyID)
|
||||||
|
ret0, _ := ret[0].(error)
|
||||||
|
return ret0
|
||||||
|
}
|
||||||
|
|
||||||
|
// UnregisterProxyFromCluster indicates an expected call of UnregisterProxyFromCluster.
|
||||||
|
func (mr *MockControllerMockRecorder) UnregisterProxyFromCluster(ctx, clusterAddr, proxyID interface{}) *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UnregisterProxyFromCluster", reflect.TypeOf((*MockController)(nil).UnregisterProxyFromCluster), ctx, clusterAddr, proxyID)
|
||||||
|
}
|
||||||
20
management/internals/modules/reverseproxy/proxy/proxy.go
Normal file
20
management/internals/modules/reverseproxy/proxy/proxy.go
Normal file
@@ -0,0 +1,20 @@
|
|||||||
|
package proxy
|
||||||
|
|
||||||
|
import "time"
|
||||||
|
|
||||||
|
// Proxy represents a reverse proxy instance persisted via GORM.
// Lifecycle timestamps (ConnectedAt/DisconnectedAt) are nullable pointers so
// the zero state ("never connected") is distinguishable in the database.
type Proxy struct {
	// ID uniquely identifies the proxy instance.
	ID string `gorm:"primaryKey;type:varchar(255)"`
	// ClusterAddress is the cluster this proxy belongs to; shares a composite
	// index with Status for cluster/status lookups.
	ClusterAddress string `gorm:"type:varchar(255);not null;index:idx_proxy_cluster_status"`
	// IPAddress holds the proxy's IP; varchar(45) fits IPv6 text form.
	IPAddress string `gorm:"type:varchar(45)"`
	// LastSeen is the last heartbeat time, indexed for stale-proxy cleanup.
	LastSeen time.Time `gorm:"not null;index:idx_proxy_last_seen"`
	// ConnectedAt is when the proxy connected; nil if never connected.
	ConnectedAt *time.Time
	// DisconnectedAt is when the proxy disconnected; nil while connected.
	DisconnectedAt *time.Time
	// Status is the proxy's current state string.
	Status string `gorm:"type:varchar(20);not null;index:idx_proxy_cluster_status"`
	// CreatedAt / UpdatedAt are managed by GORM conventions.
	CreatedAt time.Time
	UpdatedAt time.Time
}

// TableName tells GORM which table stores Proxy records.
func (Proxy) TableName() string {
	return "proxies"
}
|
||||||
@@ -1,463 +0,0 @@
|
|||||||
package reverseproxy
|
|
||||||
|
|
||||||
import (
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"net"
|
|
||||||
"net/url"
|
|
||||||
"strconv"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/rs/xid"
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/shared/hash/argon2id"
|
|
||||||
"github.com/netbirdio/netbird/util/crypt"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
|
||||||
"github.com/netbirdio/netbird/shared/management/proto"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Operation describes a mutation applied to a reverse-proxy service.
type Operation string

// Supported service operations.
const (
	Create Operation = "create"
	Update Operation = "update"
	Delete Operation = "delete"
)

// ProxyStatus is the lifecycle state of a reverse-proxy service.
type ProxyStatus string

// Service lifecycle states and supported target types.
const (
	StatusPending            ProxyStatus = "pending"
	StatusActive             ProxyStatus = "active"
	StatusTunnelNotCreated   ProxyStatus = "tunnel_not_created"
	StatusCertificatePending ProxyStatus = "certificate_pending"
	StatusCertificateFailed  ProxyStatus = "certificate_failed"
	StatusError              ProxyStatus = "error"

	TargetTypePeer   = "peer"
	TargetTypeHost   = "host"
	TargetTypeDomain = "domain"
	TargetTypeSubnet = "subnet"
)

// Target is a single upstream destination of a reverse-proxy service.
type Target struct {
	ID        uint    `gorm:"primaryKey" json:"-"`
	AccountID string  `gorm:"index:idx_target_account;not null" json:"-"`
	ServiceID string  `gorm:"index:idx_service_targets;not null" json:"-"`
	Path      *string `json:"path,omitempty"`
	// Host is only used for subnet targets, otherwise ignored.
	Host       string `json:"host"`
	Port       int    `gorm:"index:idx_target_port" json:"port"`
	Protocol   string `gorm:"index:idx_target_protocol" json:"protocol"`
	TargetId   string `gorm:"index:idx_target_id" json:"target_id"`
	TargetType string `gorm:"index:idx_target_type" json:"target_type"`
	Enabled    bool   `gorm:"index:idx_target_enabled" json:"enabled"`
}

// PasswordAuthConfig enables password-based access to a service.
type PasswordAuthConfig struct {
	Enabled  bool   `json:"enabled"`
	Password string `json:"password"`
}

// PINAuthConfig enables PIN-based access to a service.
type PINAuthConfig struct {
	Enabled bool   `json:"enabled"`
	Pin     string `json:"pin"`
}

// BearerAuthConfig enables bearer-token (OIDC) access to a service.
type BearerAuthConfig struct {
	Enabled            bool     `json:"enabled"`
	DistributionGroups []string `json:"distribution_groups,omitempty" gorm:"serializer:json"`
}

// AuthConfig aggregates the optional authentication mechanisms of a service.
// Each field is nil when the corresponding mechanism is not configured.
type AuthConfig struct {
	PasswordAuth *PasswordAuthConfig `json:"password_auth,omitempty" gorm:"serializer:json"`
	PinAuth      *PINAuthConfig      `json:"pin_auth,omitempty" gorm:"serializer:json"`
	BearerAuth   *BearerAuthConfig   `json:"bearer_auth,omitempty" gorm:"serializer:json"`
}
|
|
||||||
|
|
||||||
func (a *AuthConfig) HashSecrets() error {
|
|
||||||
if a.PasswordAuth != nil && a.PasswordAuth.Enabled && a.PasswordAuth.Password != "" {
|
|
||||||
hashedPassword, err := argon2id.Hash(a.PasswordAuth.Password)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("hash password: %w", err)
|
|
||||||
}
|
|
||||||
a.PasswordAuth.Password = hashedPassword
|
|
||||||
}
|
|
||||||
|
|
||||||
if a.PinAuth != nil && a.PinAuth.Enabled && a.PinAuth.Pin != "" {
|
|
||||||
hashedPin, err := argon2id.Hash(a.PinAuth.Pin)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("hash pin: %w", err)
|
|
||||||
}
|
|
||||||
a.PinAuth.Pin = hashedPin
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *AuthConfig) ClearSecrets() {
|
|
||||||
if a.PasswordAuth != nil {
|
|
||||||
a.PasswordAuth.Password = ""
|
|
||||||
}
|
|
||||||
if a.PinAuth != nil {
|
|
||||||
a.PinAuth.Pin = ""
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
type OIDCValidationConfig struct {
|
|
||||||
Issuer string
|
|
||||||
Audiences []string
|
|
||||||
KeysLocation string
|
|
||||||
MaxTokenAgeSeconds int64
|
|
||||||
}
|
|
||||||
|
|
||||||
type ServiceMeta struct {
|
|
||||||
CreatedAt time.Time
|
|
||||||
CertificateIssuedAt time.Time
|
|
||||||
Status string
|
|
||||||
}
|
|
||||||
|
|
||||||
type Service struct {
|
|
||||||
ID string `gorm:"primaryKey"`
|
|
||||||
AccountID string `gorm:"index"`
|
|
||||||
Name string
|
|
||||||
Domain string `gorm:"index"`
|
|
||||||
ProxyCluster string `gorm:"index"`
|
|
||||||
Targets []*Target `gorm:"foreignKey:ServiceID;constraint:OnDelete:CASCADE"`
|
|
||||||
Enabled bool
|
|
||||||
PassHostHeader bool
|
|
||||||
RewriteRedirects bool
|
|
||||||
Auth AuthConfig `gorm:"serializer:json"`
|
|
||||||
Meta ServiceMeta `gorm:"embedded;embeddedPrefix:meta_"`
|
|
||||||
SessionPrivateKey string `gorm:"column:session_private_key"`
|
|
||||||
SessionPublicKey string `gorm:"column:session_public_key"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewService(accountID, name, domain, proxyCluster string, targets []*Target, enabled bool) *Service {
|
|
||||||
for _, target := range targets {
|
|
||||||
target.AccountID = accountID
|
|
||||||
}
|
|
||||||
|
|
||||||
s := &Service{
|
|
||||||
AccountID: accountID,
|
|
||||||
Name: name,
|
|
||||||
Domain: domain,
|
|
||||||
ProxyCluster: proxyCluster,
|
|
||||||
Targets: targets,
|
|
||||||
Enabled: enabled,
|
|
||||||
}
|
|
||||||
s.InitNewRecord()
|
|
||||||
return s
|
|
||||||
}
|
|
||||||
|
|
||||||
// InitNewRecord generates a new unique ID and resets metadata for a newly created
|
|
||||||
// Service record. This overwrites any existing ID and Meta fields and should
|
|
||||||
// only be called during initial creation, not for updates.
|
|
||||||
func (s *Service) InitNewRecord() {
|
|
||||||
s.ID = xid.New().String()
|
|
||||||
s.Meta = ServiceMeta{
|
|
||||||
CreatedAt: time.Now(),
|
|
||||||
Status: string(StatusPending),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Service) ToAPIResponse() *api.Service {
|
|
||||||
s.Auth.ClearSecrets()
|
|
||||||
|
|
||||||
authConfig := api.ServiceAuthConfig{}
|
|
||||||
|
|
||||||
if s.Auth.PasswordAuth != nil {
|
|
||||||
authConfig.PasswordAuth = &api.PasswordAuthConfig{
|
|
||||||
Enabled: s.Auth.PasswordAuth.Enabled,
|
|
||||||
Password: s.Auth.PasswordAuth.Password,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if s.Auth.PinAuth != nil {
|
|
||||||
authConfig.PinAuth = &api.PINAuthConfig{
|
|
||||||
Enabled: s.Auth.PinAuth.Enabled,
|
|
||||||
Pin: s.Auth.PinAuth.Pin,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if s.Auth.BearerAuth != nil {
|
|
||||||
authConfig.BearerAuth = &api.BearerAuthConfig{
|
|
||||||
Enabled: s.Auth.BearerAuth.Enabled,
|
|
||||||
DistributionGroups: &s.Auth.BearerAuth.DistributionGroups,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Convert internal targets to API targets
|
|
||||||
apiTargets := make([]api.ServiceTarget, 0, len(s.Targets))
|
|
||||||
for _, target := range s.Targets {
|
|
||||||
apiTargets = append(apiTargets, api.ServiceTarget{
|
|
||||||
Path: target.Path,
|
|
||||||
Host: &target.Host,
|
|
||||||
Port: target.Port,
|
|
||||||
Protocol: api.ServiceTargetProtocol(target.Protocol),
|
|
||||||
TargetId: target.TargetId,
|
|
||||||
TargetType: api.ServiceTargetTargetType(target.TargetType),
|
|
||||||
Enabled: target.Enabled,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
meta := api.ServiceMeta{
|
|
||||||
CreatedAt: s.Meta.CreatedAt,
|
|
||||||
Status: api.ServiceMetaStatus(s.Meta.Status),
|
|
||||||
}
|
|
||||||
|
|
||||||
if !s.Meta.CertificateIssuedAt.IsZero() {
|
|
||||||
meta.CertificateIssuedAt = &s.Meta.CertificateIssuedAt
|
|
||||||
}
|
|
||||||
|
|
||||||
resp := &api.Service{
|
|
||||||
Id: s.ID,
|
|
||||||
Name: s.Name,
|
|
||||||
Domain: s.Domain,
|
|
||||||
Targets: apiTargets,
|
|
||||||
Enabled: s.Enabled,
|
|
||||||
PassHostHeader: &s.PassHostHeader,
|
|
||||||
RewriteRedirects: &s.RewriteRedirects,
|
|
||||||
Auth: authConfig,
|
|
||||||
Meta: meta,
|
|
||||||
}
|
|
||||||
|
|
||||||
if s.ProxyCluster != "" {
|
|
||||||
resp.ProxyCluster = &s.ProxyCluster
|
|
||||||
}
|
|
||||||
|
|
||||||
return resp
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Service) ToProtoMapping(operation Operation, authToken string, oidcConfig OIDCValidationConfig) *proto.ProxyMapping {
|
|
||||||
pathMappings := make([]*proto.PathMapping, 0, len(s.Targets))
|
|
||||||
for _, target := range s.Targets {
|
|
||||||
if !target.Enabled {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO: Make path prefix stripping configurable per-target.
|
|
||||||
// Currently the matching prefix is baked into the target URL path,
|
|
||||||
// so the proxy strips-then-re-adds it (effectively a no-op).
|
|
||||||
targetURL := url.URL{
|
|
||||||
Scheme: target.Protocol,
|
|
||||||
Host: target.Host,
|
|
||||||
Path: "/", // TODO: support service path
|
|
||||||
}
|
|
||||||
if target.Port > 0 && !isDefaultPort(target.Protocol, target.Port) {
|
|
||||||
targetURL.Host = net.JoinHostPort(targetURL.Host, strconv.Itoa(target.Port))
|
|
||||||
}
|
|
||||||
|
|
||||||
path := "/"
|
|
||||||
if target.Path != nil {
|
|
||||||
path = *target.Path
|
|
||||||
}
|
|
||||||
pathMappings = append(pathMappings, &proto.PathMapping{
|
|
||||||
Path: path,
|
|
||||||
Target: targetURL.String(),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
auth := &proto.Authentication{
|
|
||||||
SessionKey: s.SessionPublicKey,
|
|
||||||
MaxSessionAgeSeconds: int64((time.Hour * 24).Seconds()),
|
|
||||||
}
|
|
||||||
|
|
||||||
if s.Auth.PasswordAuth != nil && s.Auth.PasswordAuth.Enabled {
|
|
||||||
auth.Password = true
|
|
||||||
}
|
|
||||||
|
|
||||||
if s.Auth.PinAuth != nil && s.Auth.PinAuth.Enabled {
|
|
||||||
auth.Pin = true
|
|
||||||
}
|
|
||||||
|
|
||||||
if s.Auth.BearerAuth != nil && s.Auth.BearerAuth.Enabled {
|
|
||||||
auth.Oidc = true
|
|
||||||
}
|
|
||||||
|
|
||||||
return &proto.ProxyMapping{
|
|
||||||
Type: operationToProtoType(operation),
|
|
||||||
Id: s.ID,
|
|
||||||
Domain: s.Domain,
|
|
||||||
Path: pathMappings,
|
|
||||||
AuthToken: authToken,
|
|
||||||
Auth: auth,
|
|
||||||
AccountId: s.AccountID,
|
|
||||||
PassHostHeader: s.PassHostHeader,
|
|
||||||
RewriteRedirects: s.RewriteRedirects,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func operationToProtoType(op Operation) proto.ProxyMappingUpdateType {
|
|
||||||
switch op {
|
|
||||||
case Create:
|
|
||||||
return proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED
|
|
||||||
case Update:
|
|
||||||
return proto.ProxyMappingUpdateType_UPDATE_TYPE_MODIFIED
|
|
||||||
case Delete:
|
|
||||||
return proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED
|
|
||||||
default:
|
|
||||||
log.Fatalf("unknown operation type: %v", op)
|
|
||||||
return proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// isDefaultPort reports whether port is the standard default for the given
// scheme (443 for https, 80 for http). Any other scheme returns false.
func isDefaultPort(scheme string, port int) bool {
	switch scheme {
	case "http":
		return port == 80
	case "https":
		return port == 443
	default:
		return false
	}
}
|
|
||||||
|
|
||||||
func (s *Service) FromAPIRequest(req *api.ServiceRequest, accountID string) {
|
|
||||||
s.Name = req.Name
|
|
||||||
s.Domain = req.Domain
|
|
||||||
s.AccountID = accountID
|
|
||||||
|
|
||||||
targets := make([]*Target, 0, len(req.Targets))
|
|
||||||
for _, apiTarget := range req.Targets {
|
|
||||||
target := &Target{
|
|
||||||
AccountID: accountID,
|
|
||||||
Path: apiTarget.Path,
|
|
||||||
Port: apiTarget.Port,
|
|
||||||
Protocol: string(apiTarget.Protocol),
|
|
||||||
TargetId: apiTarget.TargetId,
|
|
||||||
TargetType: string(apiTarget.TargetType),
|
|
||||||
Enabled: apiTarget.Enabled,
|
|
||||||
}
|
|
||||||
if apiTarget.Host != nil {
|
|
||||||
target.Host = *apiTarget.Host
|
|
||||||
}
|
|
||||||
targets = append(targets, target)
|
|
||||||
}
|
|
||||||
s.Targets = targets
|
|
||||||
|
|
||||||
s.Enabled = req.Enabled
|
|
||||||
|
|
||||||
if req.PassHostHeader != nil {
|
|
||||||
s.PassHostHeader = *req.PassHostHeader
|
|
||||||
}
|
|
||||||
|
|
||||||
if req.RewriteRedirects != nil {
|
|
||||||
s.RewriteRedirects = *req.RewriteRedirects
|
|
||||||
}
|
|
||||||
|
|
||||||
if req.Auth.PasswordAuth != nil {
|
|
||||||
s.Auth.PasswordAuth = &PasswordAuthConfig{
|
|
||||||
Enabled: req.Auth.PasswordAuth.Enabled,
|
|
||||||
Password: req.Auth.PasswordAuth.Password,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if req.Auth.PinAuth != nil {
|
|
||||||
s.Auth.PinAuth = &PINAuthConfig{
|
|
||||||
Enabled: req.Auth.PinAuth.Enabled,
|
|
||||||
Pin: req.Auth.PinAuth.Pin,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if req.Auth.BearerAuth != nil {
|
|
||||||
bearerAuth := &BearerAuthConfig{
|
|
||||||
Enabled: req.Auth.BearerAuth.Enabled,
|
|
||||||
}
|
|
||||||
if req.Auth.BearerAuth.DistributionGroups != nil {
|
|
||||||
bearerAuth.DistributionGroups = *req.Auth.BearerAuth.DistributionGroups
|
|
||||||
}
|
|
||||||
s.Auth.BearerAuth = bearerAuth
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Service) Validate() error {
|
|
||||||
if s.Name == "" {
|
|
||||||
return errors.New("service name is required")
|
|
||||||
}
|
|
||||||
if len(s.Name) > 255 {
|
|
||||||
return errors.New("service name exceeds maximum length of 255 characters")
|
|
||||||
}
|
|
||||||
|
|
||||||
if s.Domain == "" {
|
|
||||||
return errors.New("service domain is required")
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(s.Targets) == 0 {
|
|
||||||
return errors.New("at least one target is required")
|
|
||||||
}
|
|
||||||
|
|
||||||
for i, target := range s.Targets {
|
|
||||||
switch target.TargetType {
|
|
||||||
case TargetTypePeer, TargetTypeHost, TargetTypeDomain:
|
|
||||||
// host field will be ignored
|
|
||||||
case TargetTypeSubnet:
|
|
||||||
if target.Host == "" {
|
|
||||||
return fmt.Errorf("target %d has empty host but target_type is %q", i, target.TargetType)
|
|
||||||
}
|
|
||||||
default:
|
|
||||||
return fmt.Errorf("target %d has invalid target_type %q", i, target.TargetType)
|
|
||||||
}
|
|
||||||
if target.TargetId == "" {
|
|
||||||
return fmt.Errorf("target %d has empty target_id", i)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Service) EventMeta() map[string]any {
|
|
||||||
return map[string]any{"name": s.Name, "domain": s.Domain, "proxy_cluster": s.ProxyCluster}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Service) Copy() *Service {
|
|
||||||
targets := make([]*Target, len(s.Targets))
|
|
||||||
for i, target := range s.Targets {
|
|
||||||
targetCopy := *target
|
|
||||||
targets[i] = &targetCopy
|
|
||||||
}
|
|
||||||
|
|
||||||
return &Service{
|
|
||||||
ID: s.ID,
|
|
||||||
AccountID: s.AccountID,
|
|
||||||
Name: s.Name,
|
|
||||||
Domain: s.Domain,
|
|
||||||
ProxyCluster: s.ProxyCluster,
|
|
||||||
Targets: targets,
|
|
||||||
Enabled: s.Enabled,
|
|
||||||
PassHostHeader: s.PassHostHeader,
|
|
||||||
RewriteRedirects: s.RewriteRedirects,
|
|
||||||
Auth: s.Auth,
|
|
||||||
Meta: s.Meta,
|
|
||||||
SessionPrivateKey: s.SessionPrivateKey,
|
|
||||||
SessionPublicKey: s.SessionPublicKey,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Service) EncryptSensitiveData(enc *crypt.FieldEncrypt) error {
|
|
||||||
if enc == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if s.SessionPrivateKey != "" {
|
|
||||||
var err error
|
|
||||||
s.SessionPrivateKey, err = enc.Encrypt(s.SessionPrivateKey)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Service) DecryptSensitiveData(enc *crypt.FieldEncrypt) error {
|
|
||||||
if enc == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if s.SessionPrivateKey != "" {
|
|
||||||
var err error
|
|
||||||
s.SessionPrivateKey, err = enc.Decrypt(s.SessionPrivateKey)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
@@ -1,405 +0,0 @@
|
|||||||
package reverseproxy
|
|
||||||
|
|
||||||
import (
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"strings"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/shared/hash/argon2id"
|
|
||||||
"github.com/netbirdio/netbird/shared/management/proto"
|
|
||||||
)
|
|
||||||
|
|
||||||
func validProxy() *Service {
|
|
||||||
return &Service{
|
|
||||||
Name: "test",
|
|
||||||
Domain: "example.com",
|
|
||||||
Targets: []*Target{
|
|
||||||
{TargetId: "peer-1", TargetType: TargetTypePeer, Host: "10.0.0.1", Port: 80, Protocol: "http", Enabled: true},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestValidate_Valid(t *testing.T) {
|
|
||||||
require.NoError(t, validProxy().Validate())
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestValidate_EmptyName(t *testing.T) {
|
|
||||||
rp := validProxy()
|
|
||||||
rp.Name = ""
|
|
||||||
assert.ErrorContains(t, rp.Validate(), "name is required")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestValidate_EmptyDomain(t *testing.T) {
|
|
||||||
rp := validProxy()
|
|
||||||
rp.Domain = ""
|
|
||||||
assert.ErrorContains(t, rp.Validate(), "domain is required")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestValidate_NoTargets(t *testing.T) {
|
|
||||||
rp := validProxy()
|
|
||||||
rp.Targets = nil
|
|
||||||
assert.ErrorContains(t, rp.Validate(), "at least one target")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestValidate_EmptyTargetId(t *testing.T) {
|
|
||||||
rp := validProxy()
|
|
||||||
rp.Targets[0].TargetId = ""
|
|
||||||
assert.ErrorContains(t, rp.Validate(), "empty target_id")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestValidate_InvalidTargetType(t *testing.T) {
|
|
||||||
rp := validProxy()
|
|
||||||
rp.Targets[0].TargetType = "invalid"
|
|
||||||
assert.ErrorContains(t, rp.Validate(), "invalid target_type")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestValidate_ResourceTarget(t *testing.T) {
|
|
||||||
rp := validProxy()
|
|
||||||
rp.Targets = append(rp.Targets, &Target{
|
|
||||||
TargetId: "resource-1",
|
|
||||||
TargetType: TargetTypeHost,
|
|
||||||
Host: "example.org",
|
|
||||||
Port: 443,
|
|
||||||
Protocol: "https",
|
|
||||||
Enabled: true,
|
|
||||||
})
|
|
||||||
require.NoError(t, rp.Validate())
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestValidate_MultipleTargetsOneInvalid(t *testing.T) {
|
|
||||||
rp := validProxy()
|
|
||||||
rp.Targets = append(rp.Targets, &Target{
|
|
||||||
TargetId: "",
|
|
||||||
TargetType: TargetTypePeer,
|
|
||||||
Host: "10.0.0.2",
|
|
||||||
Port: 80,
|
|
||||||
Protocol: "http",
|
|
||||||
Enabled: true,
|
|
||||||
})
|
|
||||||
err := rp.Validate()
|
|
||||||
require.Error(t, err)
|
|
||||||
assert.Contains(t, err.Error(), "target 1")
|
|
||||||
assert.Contains(t, err.Error(), "empty target_id")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestIsDefaultPort(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
scheme string
|
|
||||||
port int
|
|
||||||
want bool
|
|
||||||
}{
|
|
||||||
{"http", 80, true},
|
|
||||||
{"https", 443, true},
|
|
||||||
{"http", 443, false},
|
|
||||||
{"https", 80, false},
|
|
||||||
{"http", 8080, false},
|
|
||||||
{"https", 8443, false},
|
|
||||||
{"http", 0, false},
|
|
||||||
{"https", 0, false},
|
|
||||||
}
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(fmt.Sprintf("%s/%d", tt.scheme, tt.port), func(t *testing.T) {
|
|
||||||
assert.Equal(t, tt.want, isDefaultPort(tt.scheme, tt.port))
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestToProtoMapping_PortInTargetURL(t *testing.T) {
|
|
||||||
oidcConfig := OIDCValidationConfig{}
|
|
||||||
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
protocol string
|
|
||||||
host string
|
|
||||||
port int
|
|
||||||
wantTarget string
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "http with default port 80 omits port",
|
|
||||||
protocol: "http",
|
|
||||||
host: "10.0.0.1",
|
|
||||||
port: 80,
|
|
||||||
wantTarget: "http://10.0.0.1/",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "https with default port 443 omits port",
|
|
||||||
protocol: "https",
|
|
||||||
host: "10.0.0.1",
|
|
||||||
port: 443,
|
|
||||||
wantTarget: "https://10.0.0.1/",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "port 0 omits port",
|
|
||||||
protocol: "http",
|
|
||||||
host: "10.0.0.1",
|
|
||||||
port: 0,
|
|
||||||
wantTarget: "http://10.0.0.1/",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "non-default port is included",
|
|
||||||
protocol: "http",
|
|
||||||
host: "10.0.0.1",
|
|
||||||
port: 8080,
|
|
||||||
wantTarget: "http://10.0.0.1:8080/",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "https with non-default port is included",
|
|
||||||
protocol: "https",
|
|
||||||
host: "10.0.0.1",
|
|
||||||
port: 8443,
|
|
||||||
wantTarget: "https://10.0.0.1:8443/",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "http port 443 is included",
|
|
||||||
protocol: "http",
|
|
||||||
host: "10.0.0.1",
|
|
||||||
port: 443,
|
|
||||||
wantTarget: "http://10.0.0.1:443/",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "https port 80 is included",
|
|
||||||
protocol: "https",
|
|
||||||
host: "10.0.0.1",
|
|
||||||
port: 80,
|
|
||||||
wantTarget: "https://10.0.0.1:80/",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
rp := &Service{
|
|
||||||
ID: "test-id",
|
|
||||||
AccountID: "acc-1",
|
|
||||||
Domain: "example.com",
|
|
||||||
Targets: []*Target{
|
|
||||||
{
|
|
||||||
TargetId: "peer-1",
|
|
||||||
TargetType: TargetTypePeer,
|
|
||||||
Host: tt.host,
|
|
||||||
Port: tt.port,
|
|
||||||
Protocol: tt.protocol,
|
|
||||||
Enabled: true,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
pm := rp.ToProtoMapping(Create, "token", oidcConfig)
|
|
||||||
require.Len(t, pm.Path, 1, "should have one path mapping")
|
|
||||||
assert.Equal(t, tt.wantTarget, pm.Path[0].Target)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestToProtoMapping_DisabledTargetSkipped(t *testing.T) {
|
|
||||||
rp := &Service{
|
|
||||||
ID: "test-id",
|
|
||||||
AccountID: "acc-1",
|
|
||||||
Domain: "example.com",
|
|
||||||
Targets: []*Target{
|
|
||||||
{TargetId: "peer-1", TargetType: TargetTypePeer, Host: "10.0.0.1", Port: 8080, Protocol: "http", Enabled: false},
|
|
||||||
{TargetId: "peer-2", TargetType: TargetTypePeer, Host: "10.0.0.2", Port: 9090, Protocol: "http", Enabled: true},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
pm := rp.ToProtoMapping(Create, "token", OIDCValidationConfig{})
|
|
||||||
require.Len(t, pm.Path, 1)
|
|
||||||
assert.Equal(t, "http://10.0.0.2:9090/", pm.Path[0].Target)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestToProtoMapping_OperationTypes(t *testing.T) {
|
|
||||||
rp := validProxy()
|
|
||||||
tests := []struct {
|
|
||||||
op Operation
|
|
||||||
want proto.ProxyMappingUpdateType
|
|
||||||
}{
|
|
||||||
{Create, proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED},
|
|
||||||
{Update, proto.ProxyMappingUpdateType_UPDATE_TYPE_MODIFIED},
|
|
||||||
{Delete, proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED},
|
|
||||||
}
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(string(tt.op), func(t *testing.T) {
|
|
||||||
pm := rp.ToProtoMapping(tt.op, "", OIDCValidationConfig{})
|
|
||||||
assert.Equal(t, tt.want, pm.Type)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestAuthConfig_HashSecrets(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
config *AuthConfig
|
|
||||||
wantErr bool
|
|
||||||
validate func(*testing.T, *AuthConfig)
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "hash password successfully",
|
|
||||||
config: &AuthConfig{
|
|
||||||
PasswordAuth: &PasswordAuthConfig{
|
|
||||||
Enabled: true,
|
|
||||||
Password: "testPassword123",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
wantErr: false,
|
|
||||||
validate: func(t *testing.T, config *AuthConfig) {
|
|
||||||
if !strings.HasPrefix(config.PasswordAuth.Password, "$argon2id$") {
|
|
||||||
t.Errorf("Password not hashed with argon2id, got: %s", config.PasswordAuth.Password)
|
|
||||||
}
|
|
||||||
// Verify the hash can be verified
|
|
||||||
if err := argon2id.Verify("testPassword123", config.PasswordAuth.Password); err != nil {
|
|
||||||
t.Errorf("Hash verification failed: %v", err)
|
|
||||||
}
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "hash PIN successfully",
|
|
||||||
config: &AuthConfig{
|
|
||||||
PinAuth: &PINAuthConfig{
|
|
||||||
Enabled: true,
|
|
||||||
Pin: "123456",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
wantErr: false,
|
|
||||||
validate: func(t *testing.T, config *AuthConfig) {
|
|
||||||
if !strings.HasPrefix(config.PinAuth.Pin, "$argon2id$") {
|
|
||||||
t.Errorf("PIN not hashed with argon2id, got: %s", config.PinAuth.Pin)
|
|
||||||
}
|
|
||||||
// Verify the hash can be verified
|
|
||||||
if err := argon2id.Verify("123456", config.PinAuth.Pin); err != nil {
|
|
||||||
t.Errorf("Hash verification failed: %v", err)
|
|
||||||
}
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "hash both password and PIN",
|
|
||||||
config: &AuthConfig{
|
|
||||||
PasswordAuth: &PasswordAuthConfig{
|
|
||||||
Enabled: true,
|
|
||||||
Password: "password",
|
|
||||||
},
|
|
||||||
PinAuth: &PINAuthConfig{
|
|
||||||
Enabled: true,
|
|
||||||
Pin: "9999",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
wantErr: false,
|
|
||||||
validate: func(t *testing.T, config *AuthConfig) {
|
|
||||||
if !strings.HasPrefix(config.PasswordAuth.Password, "$argon2id$") {
|
|
||||||
t.Errorf("Password not hashed with argon2id")
|
|
||||||
}
|
|
||||||
if !strings.HasPrefix(config.PinAuth.Pin, "$argon2id$") {
|
|
||||||
t.Errorf("PIN not hashed with argon2id")
|
|
||||||
}
|
|
||||||
if err := argon2id.Verify("password", config.PasswordAuth.Password); err != nil {
|
|
||||||
t.Errorf("Password hash verification failed: %v", err)
|
|
||||||
}
|
|
||||||
if err := argon2id.Verify("9999", config.PinAuth.Pin); err != nil {
|
|
||||||
t.Errorf("PIN hash verification failed: %v", err)
|
|
||||||
}
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "skip disabled password auth",
|
|
||||||
config: &AuthConfig{
|
|
||||||
PasswordAuth: &PasswordAuthConfig{
|
|
||||||
Enabled: false,
|
|
||||||
Password: "password",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
wantErr: false,
|
|
||||||
validate: func(t *testing.T, config *AuthConfig) {
|
|
||||||
if config.PasswordAuth.Password != "password" {
|
|
||||||
t.Errorf("Disabled password auth should not be hashed")
|
|
||||||
}
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "skip empty password",
|
|
||||||
config: &AuthConfig{
|
|
||||||
PasswordAuth: &PasswordAuthConfig{
|
|
||||||
Enabled: true,
|
|
||||||
Password: "",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
wantErr: false,
|
|
||||||
validate: func(t *testing.T, config *AuthConfig) {
|
|
||||||
if config.PasswordAuth.Password != "" {
|
|
||||||
t.Errorf("Empty password should remain empty")
|
|
||||||
}
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "skip nil password auth",
|
|
||||||
config: &AuthConfig{
|
|
||||||
PasswordAuth: nil,
|
|
||||||
PinAuth: &PINAuthConfig{
|
|
||||||
Enabled: true,
|
|
||||||
Pin: "1234",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
wantErr: false,
|
|
||||||
validate: func(t *testing.T, config *AuthConfig) {
|
|
||||||
if config.PasswordAuth != nil {
|
|
||||||
t.Errorf("PasswordAuth should remain nil")
|
|
||||||
}
|
|
||||||
if !strings.HasPrefix(config.PinAuth.Pin, "$argon2id$") {
|
|
||||||
t.Errorf("PIN should still be hashed")
|
|
||||||
}
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
err := tt.config.HashSecrets()
|
|
||||||
if (err != nil) != tt.wantErr {
|
|
||||||
t.Errorf("HashSecrets() error = %v, wantErr %v", err, tt.wantErr)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if tt.validate != nil {
|
|
||||||
tt.validate(t, tt.config)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestAuthConfig_HashSecrets_VerifyIncorrectSecret(t *testing.T) {
|
|
||||||
config := &AuthConfig{
|
|
||||||
PasswordAuth: &PasswordAuthConfig{
|
|
||||||
Enabled: true,
|
|
||||||
Password: "correctPassword",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := config.HashSecrets(); err != nil {
|
|
||||||
t.Fatalf("HashSecrets() error = %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Verify with wrong password should fail
|
|
||||||
err := argon2id.Verify("wrongPassword", config.PasswordAuth.Password)
|
|
||||||
if !errors.Is(err, argon2id.ErrMismatchedHashAndPassword) {
|
|
||||||
t.Errorf("Expected ErrMismatchedHashAndPassword, got %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestAuthConfig_ClearSecrets(t *testing.T) {
|
|
||||||
config := &AuthConfig{
|
|
||||||
PasswordAuth: &PasswordAuthConfig{
|
|
||||||
Enabled: true,
|
|
||||||
Password: "hashedPassword",
|
|
||||||
},
|
|
||||||
PinAuth: &PINAuthConfig{
|
|
||||||
Enabled: true,
|
|
||||||
Pin: "hashedPin",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
config.ClearSecrets()
|
|
||||||
|
|
||||||
if config.PasswordAuth.Password != "" {
|
|
||||||
t.Errorf("Password not cleared, got: %s", config.PasswordAuth.Password)
|
|
||||||
}
|
|
||||||
if config.PinAuth.Pin != "" {
|
|
||||||
t.Errorf("PIN not cleared, got: %s", config.PinAuth.Pin)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,6 +1,6 @@
|
|||||||
package reverseproxy
|
package service
|
||||||
|
|
||||||
//go:generate go run github.com/golang/mock/mockgen -package reverseproxy -destination=interface_mock.go -source=./interface.go -build_flags=-mod=mod
|
//go:generate go run github.com/golang/mock/mockgen -package service -destination=interface_mock.go -source=./interface.go -build_flags=-mod=mod
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
@@ -12,12 +12,17 @@ type Manager interface {
|
|||||||
CreateService(ctx context.Context, accountID, userID string, service *Service) (*Service, error)
|
CreateService(ctx context.Context, accountID, userID string, service *Service) (*Service, error)
|
||||||
UpdateService(ctx context.Context, accountID, userID string, service *Service) (*Service, error)
|
UpdateService(ctx context.Context, accountID, userID string, service *Service) (*Service, error)
|
||||||
DeleteService(ctx context.Context, accountID, userID, serviceID string) error
|
DeleteService(ctx context.Context, accountID, userID, serviceID string) error
|
||||||
|
DeleteAllServices(ctx context.Context, accountID, userID string) error
|
||||||
SetCertificateIssuedAt(ctx context.Context, accountID, serviceID string) error
|
SetCertificateIssuedAt(ctx context.Context, accountID, serviceID string) error
|
||||||
SetStatus(ctx context.Context, accountID, serviceID string, status ProxyStatus) error
|
SetStatus(ctx context.Context, accountID, serviceID string, status Status) error
|
||||||
ReloadAllServicesForAccount(ctx context.Context, accountID string) error
|
ReloadAllServicesForAccount(ctx context.Context, accountID string) error
|
||||||
ReloadService(ctx context.Context, accountID, serviceID string) error
|
ReloadService(ctx context.Context, accountID, serviceID string) error
|
||||||
GetGlobalServices(ctx context.Context) ([]*Service, error)
|
GetGlobalServices(ctx context.Context) ([]*Service, error)
|
||||||
GetServiceByID(ctx context.Context, accountID, serviceID string) (*Service, error)
|
GetServiceByID(ctx context.Context, accountID, serviceID string) (*Service, error)
|
||||||
GetAccountServices(ctx context.Context, accountID string) ([]*Service, error)
|
GetAccountServices(ctx context.Context, accountID string) ([]*Service, error)
|
||||||
GetServiceIDByTargetID(ctx context.Context, accountID string, resourceID string) (string, error)
|
GetServiceIDByTargetID(ctx context.Context, accountID string, resourceID string) (string, error)
|
||||||
|
CreateServiceFromPeer(ctx context.Context, accountID, peerID string, req *ExposeServiceRequest) (*ExposeServiceResponse, error)
|
||||||
|
RenewServiceFromPeer(ctx context.Context, accountID, peerID, domain string) error
|
||||||
|
StopServiceFromPeer(ctx context.Context, accountID, peerID, domain string) error
|
||||||
|
StartExposeReaper(ctx context.Context)
|
||||||
}
|
}
|
||||||
@@ -1,8 +1,8 @@
|
|||||||
// Code generated by MockGen. DO NOT EDIT.
|
// Code generated by MockGen. DO NOT EDIT.
|
||||||
// Source: ./interface.go
|
// Source: ./interface.go
|
||||||
|
|
||||||
// Package reverseproxy is a generated GoMock package.
|
// Package service is a generated GoMock package.
|
||||||
package reverseproxy
|
package service
|
||||||
|
|
||||||
import (
|
import (
|
||||||
context "context"
|
context "context"
|
||||||
@@ -49,6 +49,35 @@ func (mr *MockManagerMockRecorder) CreateService(ctx, accountID, userID, service
|
|||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateService", reflect.TypeOf((*MockManager)(nil).CreateService), ctx, accountID, userID, service)
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateService", reflect.TypeOf((*MockManager)(nil).CreateService), ctx, accountID, userID, service)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CreateServiceFromPeer mocks base method.
|
||||||
|
func (m *MockManager) CreateServiceFromPeer(ctx context.Context, accountID, peerID string, req *ExposeServiceRequest) (*ExposeServiceResponse, error) {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
ret := m.ctrl.Call(m, "CreateServiceFromPeer", ctx, accountID, peerID, req)
|
||||||
|
ret0, _ := ret[0].(*ExposeServiceResponse)
|
||||||
|
ret1, _ := ret[1].(error)
|
||||||
|
return ret0, ret1
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateServiceFromPeer indicates an expected call of CreateServiceFromPeer.
|
||||||
|
func (mr *MockManagerMockRecorder) CreateServiceFromPeer(ctx, accountID, peerID, req interface{}) *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateServiceFromPeer", reflect.TypeOf((*MockManager)(nil).CreateServiceFromPeer), ctx, accountID, peerID, req)
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteAllServices mocks base method.
|
||||||
|
func (m *MockManager) DeleteAllServices(ctx context.Context, accountID, userID string) error {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
ret := m.ctrl.Call(m, "DeleteAllServices", ctx, accountID, userID)
|
||||||
|
ret0, _ := ret[0].(error)
|
||||||
|
return ret0
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteAllServices indicates an expected call of DeleteAllServices.
|
||||||
|
func (mr *MockManagerMockRecorder) DeleteAllServices(ctx, accountID, userID interface{}) *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteAllServices", reflect.TypeOf((*MockManager)(nil).DeleteAllServices), ctx, accountID, userID)
|
||||||
|
}
|
||||||
|
|
||||||
// DeleteService mocks base method.
|
// DeleteService mocks base method.
|
||||||
func (m *MockManager) DeleteService(ctx context.Context, accountID, userID, serviceID string) error {
|
func (m *MockManager) DeleteService(ctx context.Context, accountID, userID, serviceID string) error {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
@@ -181,6 +210,20 @@ func (mr *MockManagerMockRecorder) ReloadService(ctx, accountID, serviceID inter
|
|||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReloadService", reflect.TypeOf((*MockManager)(nil).ReloadService), ctx, accountID, serviceID)
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReloadService", reflect.TypeOf((*MockManager)(nil).ReloadService), ctx, accountID, serviceID)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RenewServiceFromPeer mocks base method.
|
||||||
|
func (m *MockManager) RenewServiceFromPeer(ctx context.Context, accountID, peerID, domain string) error {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
ret := m.ctrl.Call(m, "RenewServiceFromPeer", ctx, accountID, peerID, domain)
|
||||||
|
ret0, _ := ret[0].(error)
|
||||||
|
return ret0
|
||||||
|
}
|
||||||
|
|
||||||
|
// RenewServiceFromPeer indicates an expected call of RenewServiceFromPeer.
|
||||||
|
func (mr *MockManagerMockRecorder) RenewServiceFromPeer(ctx, accountID, peerID, domain interface{}) *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RenewServiceFromPeer", reflect.TypeOf((*MockManager)(nil).RenewServiceFromPeer), ctx, accountID, peerID, domain)
|
||||||
|
}
|
||||||
|
|
||||||
// SetCertificateIssuedAt mocks base method.
|
// SetCertificateIssuedAt mocks base method.
|
||||||
func (m *MockManager) SetCertificateIssuedAt(ctx context.Context, accountID, serviceID string) error {
|
func (m *MockManager) SetCertificateIssuedAt(ctx context.Context, accountID, serviceID string) error {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
@@ -196,7 +239,7 @@ func (mr *MockManagerMockRecorder) SetCertificateIssuedAt(ctx, accountID, servic
|
|||||||
}
|
}
|
||||||
|
|
||||||
// SetStatus mocks base method.
|
// SetStatus mocks base method.
|
||||||
func (m *MockManager) SetStatus(ctx context.Context, accountID, serviceID string, status ProxyStatus) error {
|
func (m *MockManager) SetStatus(ctx context.Context, accountID, serviceID string, status Status) error {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
ret := m.ctrl.Call(m, "SetStatus", ctx, accountID, serviceID, status)
|
ret := m.ctrl.Call(m, "SetStatus", ctx, accountID, serviceID, status)
|
||||||
ret0, _ := ret[0].(error)
|
ret0, _ := ret[0].(error)
|
||||||
@@ -209,6 +252,32 @@ func (mr *MockManagerMockRecorder) SetStatus(ctx, accountID, serviceID, status i
|
|||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetStatus", reflect.TypeOf((*MockManager)(nil).SetStatus), ctx, accountID, serviceID, status)
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetStatus", reflect.TypeOf((*MockManager)(nil).SetStatus), ctx, accountID, serviceID, status)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// StartExposeReaper mocks base method.
|
||||||
|
func (m *MockManager) StartExposeReaper(ctx context.Context) {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
m.ctrl.Call(m, "StartExposeReaper", ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
// StartExposeReaper indicates an expected call of StartExposeReaper.
|
||||||
|
func (mr *MockManagerMockRecorder) StartExposeReaper(ctx interface{}) *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StartExposeReaper", reflect.TypeOf((*MockManager)(nil).StartExposeReaper), ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
// StopServiceFromPeer mocks base method.
|
||||||
|
func (m *MockManager) StopServiceFromPeer(ctx context.Context, accountID, peerID, domain string) error {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
ret := m.ctrl.Call(m, "StopServiceFromPeer", ctx, accountID, peerID, domain)
|
||||||
|
ret0, _ := ret[0].(error)
|
||||||
|
return ret0
|
||||||
|
}
|
||||||
|
|
||||||
|
// StopServiceFromPeer indicates an expected call of StopServiceFromPeer.
|
||||||
|
func (mr *MockManagerMockRecorder) StopServiceFromPeer(ctx, accountID, peerID, domain interface{}) *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StopServiceFromPeer", reflect.TypeOf((*MockManager)(nil).StopServiceFromPeer), ctx, accountID, peerID, domain)
|
||||||
|
}
|
||||||
|
|
||||||
// UpdateService mocks base method.
|
// UpdateService mocks base method.
|
||||||
func (m *MockManager) UpdateService(ctx context.Context, accountID, userID string, service *Service) (*Service, error) {
|
func (m *MockManager) UpdateService(ctx context.Context, accountID, userID string, service *Service) (*Service, error) {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
@@ -6,10 +6,10 @@ import (
|
|||||||
|
|
||||||
"github.com/gorilla/mux"
|
"github.com/gorilla/mux"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
|
|
||||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs"
|
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs"
|
||||||
accesslogsmanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs/manager"
|
accesslogsmanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs/manager"
|
||||||
domainmanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain/manager"
|
domainmanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain/manager"
|
||||||
|
rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||||
@@ -17,11 +17,11 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type handler struct {
|
type handler struct {
|
||||||
manager reverseproxy.Manager
|
manager rpservice.Manager
|
||||||
}
|
}
|
||||||
|
|
||||||
// RegisterEndpoints registers all service HTTP endpoints.
|
// RegisterEndpoints registers all service HTTP endpoints.
|
||||||
func RegisterEndpoints(manager reverseproxy.Manager, domainManager domainmanager.Manager, accessLogsManager accesslogs.Manager, router *mux.Router) {
|
func RegisterEndpoints(manager rpservice.Manager, domainManager domainmanager.Manager, accessLogsManager accesslogs.Manager, router *mux.Router) {
|
||||||
h := &handler{
|
h := &handler{
|
||||||
manager: manager,
|
manager: manager,
|
||||||
}
|
}
|
||||||
@@ -72,8 +72,11 @@ func (h *handler) createService(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
service := new(reverseproxy.Service)
|
service := new(rpservice.Service)
|
||||||
service.FromAPIRequest(&req, userAuth.AccountId)
|
if err = service.FromAPIRequest(&req, userAuth.AccountId); err != nil {
|
||||||
|
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "%s", err.Error()), w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
if err = service.Validate(); err != nil {
|
if err = service.Validate(); err != nil {
|
||||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "%s", err.Error()), w)
|
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "%s", err.Error()), w)
|
||||||
@@ -130,9 +133,12 @@ func (h *handler) updateService(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
service := new(reverseproxy.Service)
|
service := new(rpservice.Service)
|
||||||
service.ID = serviceID
|
service.ID = serviceID
|
||||||
service.FromAPIRequest(&req, userAuth.AccountId)
|
if err = service.FromAPIRequest(&req, userAuth.AccountId); err != nil {
|
||||||
|
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "%s", err.Error()), w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
if err = service.Validate(); err != nil {
|
if err = service.Validate(); err != nil {
|
||||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "%s", err.Error()), w)
|
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "%s", err.Error()), w)
|
||||||
@@ -0,0 +1,65 @@
|
|||||||
|
package manager
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"math/rand/v2"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/shared/management/status"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
exposeTTL = 90 * time.Second
|
||||||
|
exposeReapInterval = 30 * time.Second
|
||||||
|
maxExposesPerPeer = 10
|
||||||
|
exposeReapBatch = 100
|
||||||
|
)
|
||||||
|
|
||||||
|
type exposeReaper struct {
|
||||||
|
manager *Manager
|
||||||
|
}
|
||||||
|
|
||||||
|
// StartExposeReaper starts a background goroutine that reaps expired ephemeral services from the DB.
|
||||||
|
func (r *exposeReaper) StartExposeReaper(ctx context.Context) {
|
||||||
|
go func() {
|
||||||
|
// start with a random delay
|
||||||
|
rn := rand.IntN(10)
|
||||||
|
time.Sleep(time.Duration(rn) * time.Second)
|
||||||
|
|
||||||
|
ticker := time.NewTicker(exposeReapInterval)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
case <-ticker.C:
|
||||||
|
r.reapExpiredExposes(ctx)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *exposeReaper) reapExpiredExposes(ctx context.Context) {
|
||||||
|
expired, err := r.manager.store.GetExpiredEphemeralServices(ctx, exposeTTL, exposeReapBatch)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed to get expired ephemeral services: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, svc := range expired {
|
||||||
|
log.Infof("reaping expired expose session for peer %s, domain %s", svc.SourcePeer, svc.Domain)
|
||||||
|
|
||||||
|
err := r.manager.deleteExpiredPeerService(ctx, svc.AccountID, svc.SourcePeer, svc.ID)
|
||||||
|
if err == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if s, ok := status.FromError(err); ok && s.ErrorType == status.NotFound {
|
||||||
|
log.Debugf("service %s was already deleted by another instance", svc.Domain)
|
||||||
|
} else {
|
||||||
|
log.Errorf("failed to delete expired peer-exposed service for domain %s: %v", svc.Domain, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,208 @@
|
|||||||
|
package manager
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||||
|
"github.com/netbirdio/netbird/management/server/store"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestReapExpiredExposes(t *testing.T) {
|
||||||
|
mgr, testStore := setupIntegrationTest(t)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
|
||||||
|
Port: 8080,
|
||||||
|
Protocol: "http",
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Manually expire the service by backdating meta_last_renewed_at
|
||||||
|
expireEphemeralService(t, testStore, testAccountID, resp.Domain)
|
||||||
|
|
||||||
|
// Create a non-expired service
|
||||||
|
resp2, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
|
||||||
|
Port: 8081,
|
||||||
|
Protocol: "http",
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
mgr.exposeReaper.reapExpiredExposes(ctx)
|
||||||
|
|
||||||
|
// Expired service should be deleted
|
||||||
|
_, err = testStore.GetServiceByDomain(ctx, resp.Domain)
|
||||||
|
require.Error(t, err, "expired service should be deleted")
|
||||||
|
|
||||||
|
// Non-expired service should remain
|
||||||
|
_, err = testStore.GetServiceByDomain(ctx, resp2.Domain)
|
||||||
|
require.NoError(t, err, "active service should remain")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReapAlreadyDeletedService(t *testing.T) {
|
||||||
|
mgr, testStore := setupIntegrationTest(t)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
|
||||||
|
Port: 8080,
|
||||||
|
Protocol: "http",
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
expireEphemeralService(t, testStore, testAccountID, resp.Domain)
|
||||||
|
|
||||||
|
// Delete the service before reaping
|
||||||
|
err = mgr.StopServiceFromPeer(ctx, testAccountID, testPeerID, resp.Domain)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Reaping should handle the already-deleted service gracefully
|
||||||
|
mgr.exposeReaper.reapExpiredExposes(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConcurrentReapAndRenew(t *testing.T) {
|
||||||
|
mgr, testStore := setupIntegrationTest(t)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
for i := range 5 {
|
||||||
|
_, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
|
||||||
|
Port: 8080 + i,
|
||||||
|
Protocol: "http",
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Expire all services
|
||||||
|
services, err := testStore.GetAccountServices(ctx, store.LockingStrengthNone, testAccountID)
|
||||||
|
require.NoError(t, err)
|
||||||
|
for _, svc := range services {
|
||||||
|
if svc.Source == rpservice.SourceEphemeral {
|
||||||
|
expireEphemeralService(t, testStore, testAccountID, svc.Domain)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
wg.Add(2)
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
mgr.exposeReaper.reapExpiredExposes(ctx)
|
||||||
|
}()
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
_, _ = mgr.store.CountEphemeralServicesByPeer(ctx, store.LockingStrengthNone, testAccountID, testPeerID)
|
||||||
|
}()
|
||||||
|
wg.Wait()
|
||||||
|
|
||||||
|
count, err := mgr.store.CountEphemeralServicesByPeer(ctx, store.LockingStrengthNone, testAccountID, testPeerID)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, int64(0), count, "all expired services should be reaped")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRenewEphemeralService(t *testing.T) {
|
||||||
|
mgr, _ := setupIntegrationTest(t)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
t.Run("renew succeeds for active service", func(t *testing.T) {
|
||||||
|
resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
|
||||||
|
Port: 8082,
|
||||||
|
Protocol: "http",
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
err = mgr.RenewServiceFromPeer(ctx, testAccountID, testPeerID, resp.Domain)
|
||||||
|
require.NoError(t, err)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("renew fails for nonexistent domain", func(t *testing.T) {
|
||||||
|
err := mgr.RenewServiceFromPeer(ctx, testAccountID, testPeerID, "nonexistent.com")
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "no active expose session")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCountAndExistsEphemeralServices(t *testing.T) {
|
||||||
|
mgr, _ := setupIntegrationTest(t)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
count, err := mgr.store.CountEphemeralServicesByPeer(ctx, store.LockingStrengthNone, testAccountID, testPeerID)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, int64(0), count)
|
||||||
|
|
||||||
|
resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
|
||||||
|
Port: 8083,
|
||||||
|
Protocol: "http",
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
count, err = mgr.store.CountEphemeralServicesByPeer(ctx, store.LockingStrengthNone, testAccountID, testPeerID)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, int64(1), count)
|
||||||
|
|
||||||
|
exists, err := mgr.store.EphemeralServiceExists(ctx, store.LockingStrengthNone, testAccountID, testPeerID, resp.Domain)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.True(t, exists, "service should exist")
|
||||||
|
|
||||||
|
exists, err = mgr.store.EphemeralServiceExists(ctx, store.LockingStrengthNone, testAccountID, testPeerID, "no-such.domain")
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.False(t, exists, "non-existent service should not exist")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMaxExposesPerPeerEnforced(t *testing.T) {
|
||||||
|
mgr, _ := setupIntegrationTest(t)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
for i := range maxExposesPerPeer {
|
||||||
|
_, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
|
||||||
|
Port: 8090 + i,
|
||||||
|
Protocol: "http",
|
||||||
|
})
|
||||||
|
require.NoError(t, err, "expose %d should succeed", i)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
|
||||||
|
Port: 9999,
|
||||||
|
Protocol: "http",
|
||||||
|
})
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "maximum number of active expose sessions")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReapSkipsRenewedService(t *testing.T) {
|
||||||
|
mgr, testStore := setupIntegrationTest(t)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
|
||||||
|
Port: 8086,
|
||||||
|
Protocol: "http",
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Expire the service
|
||||||
|
expireEphemeralService(t, testStore, testAccountID, resp.Domain)
|
||||||
|
|
||||||
|
// Renew it before the reaper runs
|
||||||
|
err = mgr.RenewServiceFromPeer(ctx, testAccountID, testPeerID, resp.Domain)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Reaper should skip it because the re-check sees a fresh timestamp
|
||||||
|
mgr.exposeReaper.reapExpiredExposes(ctx)
|
||||||
|
|
||||||
|
_, err = testStore.GetServiceByDomain(ctx, resp.Domain)
|
||||||
|
require.NoError(t, err, "renewed service should survive reaping")
|
||||||
|
}
|
||||||
|
|
||||||
|
// expireEphemeralService backdates meta_last_renewed_at to force expiration.
|
||||||
|
func expireEphemeralService(t *testing.T, s store.Store, accountID, domain string) {
|
||||||
|
t.Helper()
|
||||||
|
svc, err := s.GetServiceByDomain(context.Background(), domain)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
expired := time.Now().Add(-2 * exposeTTL)
|
||||||
|
svc.Meta.LastRenewedAt = &expired
|
||||||
|
err = s.UpdateService(context.Background(), svc)
|
||||||
|
require.NoError(t, err)
|
||||||
|
}
|
||||||
@@ -0,0 +1,928 @@
|
|||||||
|
package manager
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"math/rand/v2"
|
||||||
|
"slices"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
|
||||||
|
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||||
|
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/sessionkey"
|
||||||
|
"github.com/netbirdio/netbird/management/server/account"
|
||||||
|
"github.com/netbirdio/netbird/management/server/activity"
|
||||||
|
"github.com/netbirdio/netbird/management/server/permissions"
|
||||||
|
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||||
|
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||||
|
"github.com/netbirdio/netbird/management/server/store"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/status"
|
||||||
|
)
|
||||||
|
|
||||||
|
const unknownHostPlaceholder = "unknown"
|
||||||
|
|
||||||
|
// ClusterDeriver derives the proxy cluster from a domain.
|
||||||
|
type ClusterDeriver interface {
|
||||||
|
DeriveClusterFromDomain(ctx context.Context, accountID, domain string) (string, error)
|
||||||
|
GetClusterDomains() []string
|
||||||
|
}
|
||||||
|
|
||||||
|
type Manager struct {
|
||||||
|
store store.Store
|
||||||
|
accountManager account.Manager
|
||||||
|
permissionsManager permissions.Manager
|
||||||
|
proxyController proxy.Controller
|
||||||
|
clusterDeriver ClusterDeriver
|
||||||
|
exposeReaper *exposeReaper
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewManager creates a new service manager.
|
||||||
|
func NewManager(store store.Store, accountManager account.Manager, permissionsManager permissions.Manager, proxyController proxy.Controller, clusterDeriver ClusterDeriver) *Manager {
|
||||||
|
mgr := &Manager{
|
||||||
|
store: store,
|
||||||
|
accountManager: accountManager,
|
||||||
|
permissionsManager: permissionsManager,
|
||||||
|
proxyController: proxyController,
|
||||||
|
clusterDeriver: clusterDeriver,
|
||||||
|
}
|
||||||
|
mgr.exposeReaper = &exposeReaper{manager: mgr}
|
||||||
|
return mgr
|
||||||
|
}
|
||||||
|
|
||||||
|
// StartExposeReaper starts the background goroutine that reaps expired ephemeral services.
|
||||||
|
func (m *Manager) StartExposeReaper(ctx context.Context) {
|
||||||
|
m.exposeReaper.StartExposeReaper(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) GetAllServices(ctx context.Context, accountID, userID string) ([]*service.Service, error) {
|
||||||
|
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Read)
|
||||||
|
if err != nil {
|
||||||
|
return nil, status.NewPermissionValidationError(err)
|
||||||
|
}
|
||||||
|
if !ok {
|
||||||
|
return nil, status.NewPermissionDeniedError()
|
||||||
|
}
|
||||||
|
|
||||||
|
services, err := m.store.GetAccountServices(ctx, store.LockingStrengthNone, accountID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to get services: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, service := range services {
|
||||||
|
err = m.replaceHostByLookup(ctx, accountID, service)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return services, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) replaceHostByLookup(ctx context.Context, accountID string, s *service.Service) error {
|
||||||
|
for _, target := range s.Targets {
|
||||||
|
switch target.TargetType {
|
||||||
|
case service.TargetTypePeer:
|
||||||
|
peer, err := m.store.GetPeerByID(ctx, store.LockingStrengthNone, accountID, target.TargetId)
|
||||||
|
if err != nil {
|
||||||
|
log.WithContext(ctx).Warnf("failed to get peer by id %s for service %s: %v", target.TargetId, s.ID, err)
|
||||||
|
target.Host = unknownHostPlaceholder
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
target.Host = peer.IP.String()
|
||||||
|
case service.TargetTypeHost:
|
||||||
|
resource, err := m.store.GetNetworkResourceByID(ctx, store.LockingStrengthNone, accountID, target.TargetId)
|
||||||
|
if err != nil {
|
||||||
|
log.WithContext(ctx).Warnf("failed to get resource by id %s for service %s: %v", target.TargetId, s.ID, err)
|
||||||
|
target.Host = unknownHostPlaceholder
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
target.Host = resource.Prefix.Addr().String()
|
||||||
|
case service.TargetTypeDomain:
|
||||||
|
resource, err := m.store.GetNetworkResourceByID(ctx, store.LockingStrengthNone, accountID, target.TargetId)
|
||||||
|
if err != nil {
|
||||||
|
log.WithContext(ctx).Warnf("failed to get resource by id %s for service %s: %v", target.TargetId, s.ID, err)
|
||||||
|
target.Host = unknownHostPlaceholder
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
target.Host = resource.Domain
|
||||||
|
case service.TargetTypeSubnet:
|
||||||
|
// For subnets we do not do any lookups on the resource
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("unknown target type: %s", target.TargetType)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) GetService(ctx context.Context, accountID, userID, serviceID string) (*service.Service, error) {
|
||||||
|
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Read)
|
||||||
|
if err != nil {
|
||||||
|
return nil, status.NewPermissionValidationError(err)
|
||||||
|
}
|
||||||
|
if !ok {
|
||||||
|
return nil, status.NewPermissionDeniedError()
|
||||||
|
}
|
||||||
|
|
||||||
|
service, err := m.store.GetServiceByID(ctx, store.LockingStrengthNone, accountID, serviceID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to get service: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = m.replaceHostByLookup(ctx, accountID, service)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err)
|
||||||
|
}
|
||||||
|
return service, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateService validates permissions, initializes and transactionally persists a
// new service, records an activity event, resolves target hosts, pushes the new
// mapping to the service's proxy cluster, and refreshes account peers.
// The returned service has target hosts resolved to concrete addresses.
func (m *Manager) CreateService(ctx context.Context, accountID, userID string, s *service.Service) (*service.Service, error) {
	ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Create)
	if err != nil {
		return nil, status.NewPermissionValidationError(err)
	}
	if !ok {
		return nil, status.NewPermissionDeniedError()
	}

	// Assigns cluster, account ID, record IDs, hashed secrets, and session keys.
	if err := m.initializeServiceForCreate(ctx, accountID, s); err != nil {
		return nil, err
	}

	// Persists together with domain-uniqueness and target-reference checks
	// in a single transaction.
	if err := m.persistNewService(ctx, accountID, s); err != nil {
		return nil, err
	}

	m.accountManager.StoreEvent(ctx, userID, s.ID, accountID, activity.ServiceCreated, s.EventMeta())

	// Resolve target hosts (peer IP / resource address) before pushing to the proxy.
	err = m.replaceHostByLookup(ctx, accountID, s)
	if err != nil {
		return nil, fmt.Errorf("failed to replace host by lookup for service %s: %w", s.ID, err)
	}

	m.proxyController.SendServiceUpdateToCluster(ctx, accountID, s.ToProtoMapping(service.Create, "", m.proxyController.GetOIDCValidationConfig()), s.ProxyCluster)

	m.accountManager.UpdateAccountPeers(ctx, accountID)

	return s, nil
}
|
||||||
|
|
||||||
|
// initializeServiceForCreate prepares an in-memory service for first persistence:
// derives its proxy cluster (when a deriver is configured), assigns the account ID,
// initializes record fields, hashes auth secrets, and generates a session key pair.
func (m *Manager) initializeServiceForCreate(ctx context.Context, accountID string, service *service.Service) error {
	if m.clusterDeriver != nil {
		proxyCluster, err := m.clusterDeriver.DeriveClusterFromDomain(ctx, accountID, service.Domain)
		if err != nil {
			// NOTE(review): the log text claims updates will broadcast to all proxy
			// servers, but the error return below aborts creation instead — confirm
			// which behavior is intended.
			log.WithError(err).Warnf("could not derive cluster from domain %s, updates will broadcast to all proxy servers", service.Domain)
			return status.Errorf(status.PreconditionFailed, "could not derive cluster from domain %s: %v", service.Domain, err)
		}
		service.ProxyCluster = proxyCluster
	}

	service.AccountID = accountID
	service.InitNewRecord()

	if err := service.Auth.HashSecrets(); err != nil {
		return fmt.Errorf("hash secrets: %w", err)
	}

	// Per-service key pair stored alongside the record; presumably used for
	// proxy session signing/verification — TODO confirm against sessionkey usage.
	keyPair, err := sessionkey.GenerateKeyPair()
	if err != nil {
		return fmt.Errorf("generate session keys: %w", err)
	}
	service.SessionPrivateKey = keyPair.PrivateKey
	service.SessionPublicKey = keyPair.PublicKey

	return nil
}
|
||||||
|
|
||||||
|
func (m *Manager) persistNewService(ctx context.Context, accountID string, service *service.Service) error {
|
||||||
|
return m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||||
|
if err := m.checkDomainAvailable(ctx, transaction, service.Domain, ""); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := validateTargetReferences(ctx, transaction, accountID, service.Targets); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := transaction.CreateService(ctx, service); err != nil {
|
||||||
|
return fmt.Errorf("failed to create service: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// persistNewEphemeralService creates an ephemeral service inside a single transaction
// that also enforces the duplicate and per-peer limit checks atomically.
// The count and exists queries use FOR UPDATE locking to serialize concurrent creates
// for the same peer, preventing the per-peer limit from being bypassed.
func (m *Manager) persistNewEphemeralService(ctx context.Context, accountID, peerID string, svc *service.Service) error {
	return m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
		// Lock the peer row to serialize concurrent creates for the same peer.
		// Without this, when no ephemeral rows exist yet, FOR UPDATE on the services
		// table returns no rows and acquires no locks, allowing concurrent inserts
		// to bypass the per-peer limit.
		if _, err := transaction.GetPeerByID(ctx, store.LockingStrengthUpdate, accountID, peerID); err != nil {
			return fmt.Errorf("lock peer row: %w", err)
		}

		// At most one active expose session per (peer, domain) pair.
		exists, err := transaction.EphemeralServiceExists(ctx, store.LockingStrengthUpdate, accountID, peerID, svc.Domain)
		if err != nil {
			return fmt.Errorf("check existing expose: %w", err)
		}
		if exists {
			return status.Errorf(status.AlreadyExists, "peer already has an active expose session for this domain")
		}

		// Enforce the per-peer expose cap under the same locks.
		count, err := transaction.CountEphemeralServicesByPeer(ctx, store.LockingStrengthUpdate, accountID, peerID)
		if err != nil {
			return fmt.Errorf("count peer exposes: %w", err)
		}
		if count >= int64(maxExposesPerPeer) {
			return status.Errorf(status.PreconditionFailed, "peer has reached the maximum number of active expose sessions (%d)", maxExposesPerPeer)
		}

		// Same uniqueness and reference checks as regular service creation.
		if err := m.checkDomainAvailable(ctx, transaction, svc.Domain, ""); err != nil {
			return err
		}

		if err := validateTargetReferences(ctx, transaction, accountID, svc.Targets); err != nil {
			return err
		}

		if err := transaction.CreateService(ctx, svc); err != nil {
			return fmt.Errorf("create service: %w", err)
		}

		return nil
	})
}
|
||||||
|
|
||||||
|
func (m *Manager) checkDomainAvailable(ctx context.Context, transaction store.Store, domain, excludeServiceID string) error {
|
||||||
|
existingService, err := transaction.GetServiceByDomain(ctx, domain)
|
||||||
|
if err != nil {
|
||||||
|
if sErr, ok := status.FromError(err); !ok || sErr.Type() != status.NotFound {
|
||||||
|
return fmt.Errorf("failed to check existing service: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if existingService != nil && existingService.ID != excludeServiceID {
|
||||||
|
return status.Errorf(status.AlreadyExists, "domain already taken")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateService validates permissions and persists changes to an existing service.
// Secrets submitted empty keep their stored values (see preserveExistingAuthSecrets);
// a domain change re-derives the proxy cluster. After persisting, it emits an
// activity event, re-resolves target hosts, notifies the affected proxy cluster(s),
// and refreshes account peers.
func (m *Manager) UpdateService(ctx context.Context, accountID, userID string, service *service.Service) (*service.Service, error) {
	ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Update)
	if err != nil {
		return nil, status.NewPermissionValidationError(err)
	}
	if !ok {
		return nil, status.NewPermissionDeniedError()
	}

	// Hash before persisting; empty secrets are later swapped for the stored
	// ones inside persistServiceUpdate.
	if err := service.Auth.HashSecrets(); err != nil {
		return nil, fmt.Errorf("hash secrets: %w", err)
	}

	updateInfo, err := m.persistServiceUpdate(ctx, accountID, service)
	if err != nil {
		return nil, err
	}

	m.accountManager.StoreEvent(ctx, userID, service.ID, accountID, activity.ServiceUpdated, service.EventMeta())

	if err := m.replaceHostByLookup(ctx, accountID, service); err != nil {
		return nil, fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err)
	}

	// updateInfo determines whether the proxy sees this as a create, update,
	// delete, or cross-cluster move.
	m.sendServiceUpdateNotifications(ctx, accountID, service, updateInfo)
	m.accountManager.UpdateAccountPeers(ctx, accountID)

	return service, nil
}
|
||||||
|
|
||||||
|
// serviceUpdateInfo captures pre-update state needed to decide which proxy
// notifications to send after a service update is persisted.
type serviceUpdateInfo struct {
	oldCluster            string // proxy cluster before the update
	domainChanged         bool   // submitted domain differs from the stored one
	serviceEnabledChanged bool   // the Enabled flag flipped in this update
}
|
||||||
|
|
||||||
|
func (m *Manager) persistServiceUpdate(ctx context.Context, accountID string, service *service.Service) (*serviceUpdateInfo, error) {
|
||||||
|
var updateInfo serviceUpdateInfo
|
||||||
|
|
||||||
|
err := m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||||
|
existingService, err := transaction.GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, service.ID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
updateInfo.oldCluster = existingService.ProxyCluster
|
||||||
|
updateInfo.domainChanged = existingService.Domain != service.Domain
|
||||||
|
|
||||||
|
if updateInfo.domainChanged {
|
||||||
|
if err := m.handleDomainChange(ctx, transaction, accountID, service); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
service.ProxyCluster = existingService.ProxyCluster
|
||||||
|
}
|
||||||
|
|
||||||
|
m.preserveExistingAuthSecrets(service, existingService)
|
||||||
|
m.preserveServiceMetadata(service, existingService)
|
||||||
|
updateInfo.serviceEnabledChanged = existingService.Enabled != service.Enabled
|
||||||
|
|
||||||
|
if err := validateTargetReferences(ctx, transaction, accountID, service.Targets); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := transaction.UpdateService(ctx, service); err != nil {
|
||||||
|
return fmt.Errorf("update service: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
return &updateInfo, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) handleDomainChange(ctx context.Context, transaction store.Store, accountID string, service *service.Service) error {
|
||||||
|
if err := m.checkDomainAvailable(ctx, transaction, service.Domain, service.ID); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if m.clusterDeriver != nil {
|
||||||
|
newCluster, err := m.clusterDeriver.DeriveClusterFromDomain(ctx, accountID, service.Domain)
|
||||||
|
if err != nil {
|
||||||
|
log.WithError(err).Warnf("could not derive cluster from domain %s", service.Domain)
|
||||||
|
} else {
|
||||||
|
service.ProxyCluster = newCluster
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) preserveExistingAuthSecrets(service, existingService *service.Service) {
|
||||||
|
if service.Auth.PasswordAuth != nil && service.Auth.PasswordAuth.Enabled &&
|
||||||
|
existingService.Auth.PasswordAuth != nil && existingService.Auth.PasswordAuth.Enabled &&
|
||||||
|
service.Auth.PasswordAuth.Password == "" {
|
||||||
|
service.Auth.PasswordAuth = existingService.Auth.PasswordAuth
|
||||||
|
}
|
||||||
|
|
||||||
|
if service.Auth.PinAuth != nil && service.Auth.PinAuth.Enabled &&
|
||||||
|
existingService.Auth.PinAuth != nil && existingService.Auth.PinAuth.Enabled &&
|
||||||
|
service.Auth.PinAuth.Pin == "" {
|
||||||
|
service.Auth.PinAuth = existingService.Auth.PinAuth
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// preserveServiceMetadata carries server-managed fields over from the stored
// service so a client-supplied update cannot clear or overwrite them.
func (m *Manager) preserveServiceMetadata(service, existingService *service.Service) {
	service.Meta = existingService.Meta
	service.SessionPrivateKey = existingService.SessionPrivateKey
	service.SessionPublicKey = existingService.SessionPublicKey
}
|
||||||
|
|
||||||
|
func (m *Manager) sendServiceUpdateNotifications(ctx context.Context, accountID string, s *service.Service, updateInfo *serviceUpdateInfo) {
|
||||||
|
oidcCfg := m.proxyController.GetOIDCValidationConfig()
|
||||||
|
|
||||||
|
switch {
|
||||||
|
case updateInfo.domainChanged && updateInfo.oldCluster != s.ProxyCluster:
|
||||||
|
m.proxyController.SendServiceUpdateToCluster(ctx, accountID, s.ToProtoMapping(service.Delete, "", oidcCfg), updateInfo.oldCluster)
|
||||||
|
m.proxyController.SendServiceUpdateToCluster(ctx, accountID, s.ToProtoMapping(service.Create, "", oidcCfg), s.ProxyCluster)
|
||||||
|
case !s.Enabled && updateInfo.serviceEnabledChanged:
|
||||||
|
m.proxyController.SendServiceUpdateToCluster(ctx, accountID, s.ToProtoMapping(service.Delete, "", oidcCfg), s.ProxyCluster)
|
||||||
|
case s.Enabled && updateInfo.serviceEnabledChanged:
|
||||||
|
m.proxyController.SendServiceUpdateToCluster(ctx, accountID, s.ToProtoMapping(service.Create, "", oidcCfg), s.ProxyCluster)
|
||||||
|
default:
|
||||||
|
m.proxyController.SendServiceUpdateToCluster(ctx, accountID, s.ToProtoMapping(service.Update, "", oidcCfg), s.ProxyCluster)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// validateTargetReferences checks that all target IDs reference existing peers or resources in the account.
|
||||||
|
func validateTargetReferences(ctx context.Context, transaction store.Store, accountID string, targets []*service.Target) error {
|
||||||
|
for _, target := range targets {
|
||||||
|
switch target.TargetType {
|
||||||
|
case service.TargetTypePeer:
|
||||||
|
if _, err := transaction.GetPeerByID(ctx, store.LockingStrengthShare, accountID, target.TargetId); err != nil {
|
||||||
|
if sErr, ok := status.FromError(err); ok && sErr.Type() == status.NotFound {
|
||||||
|
return status.Errorf(status.InvalidArgument, "peer target %q not found in account", target.TargetId)
|
||||||
|
}
|
||||||
|
return fmt.Errorf("look up peer target %q: %w", target.TargetId, err)
|
||||||
|
}
|
||||||
|
case service.TargetTypeHost, service.TargetTypeSubnet, service.TargetTypeDomain:
|
||||||
|
if _, err := transaction.GetNetworkResourceByID(ctx, store.LockingStrengthShare, accountID, target.TargetId); err != nil {
|
||||||
|
if sErr, ok := status.FromError(err); ok && sErr.Type() == status.NotFound {
|
||||||
|
return status.Errorf(status.InvalidArgument, "resource target %q not found in account", target.TargetId)
|
||||||
|
}
|
||||||
|
return fmt.Errorf("look up resource target %q: %w", target.TargetId, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteService validates permissions, removes a service and its target rows in a
// single transaction, then records the deletion event, tells the service's proxy
// cluster to drop the mapping, and refreshes account peers.
func (m *Manager) DeleteService(ctx context.Context, accountID, userID, serviceID string) error {
	ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Delete)
	if err != nil {
		return status.NewPermissionValidationError(err)
	}
	if !ok {
		return status.NewPermissionDeniedError()
	}

	// The service is loaded with FOR UPDATE and kept so the event/notification
	// below can still use its metadata after the row is gone.
	var s *service.Service
	err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
		var err error
		s, err = transaction.GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, serviceID)
		if err != nil {
			return err
		}

		// Child target rows are deleted explicitly before the service row.
		if err = transaction.DeleteServiceTargets(ctx, accountID, serviceID); err != nil {
			return fmt.Errorf("failed to delete targets: %w", err)
		}

		if err = transaction.DeleteService(ctx, accountID, serviceID); err != nil {
			return fmt.Errorf("failed to delete service: %w", err)
		}

		return nil
	})
	if err != nil {
		return err
	}

	m.accountManager.StoreEvent(ctx, userID, serviceID, accountID, activity.ServiceDeleted, s.EventMeta())

	m.proxyController.SendServiceUpdateToCluster(ctx, accountID, s.ToProtoMapping(service.Delete, "", m.proxyController.GetOIDCValidationConfig()), s.ProxyCluster)

	m.accountManager.UpdateAccountPeers(ctx, accountID)

	return nil
}
|
||||||
|
|
||||||
|
func (m *Manager) DeleteAllServices(ctx context.Context, accountID, userID string) error {
|
||||||
|
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Delete)
|
||||||
|
if err != nil {
|
||||||
|
return status.NewPermissionValidationError(err)
|
||||||
|
}
|
||||||
|
if !ok {
|
||||||
|
return status.NewPermissionDeniedError()
|
||||||
|
}
|
||||||
|
|
||||||
|
var services []*service.Service
|
||||||
|
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||||
|
var err error
|
||||||
|
services, err = transaction.GetAccountServices(ctx, store.LockingStrengthUpdate, accountID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, svc := range services {
|
||||||
|
if err = transaction.DeleteService(ctx, accountID, svc.ID); err != nil {
|
||||||
|
return fmt.Errorf("failed to delete service: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
oidcCfg := m.proxyController.GetOIDCValidationConfig()
|
||||||
|
|
||||||
|
for _, svc := range services {
|
||||||
|
m.accountManager.StoreEvent(ctx, userID, svc.ID, accountID, activity.ServiceDeleted, svc.EventMeta())
|
||||||
|
m.proxyController.SendServiceUpdateToCluster(ctx, accountID, svc.ToProtoMapping(service.Delete, "", oidcCfg), svc.ProxyCluster)
|
||||||
|
}
|
||||||
|
|
||||||
|
m.accountManager.UpdateAccountPeers(ctx, accountID)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetCertificateIssuedAt sets the certificate issued timestamp to the current time.
|
||||||
|
// Call this when receiving a gRPC notification that the certificate was issued.
|
||||||
|
func (m *Manager) SetCertificateIssuedAt(ctx context.Context, accountID, serviceID string) error {
|
||||||
|
return m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||||
|
service, err := transaction.GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, serviceID)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to get service: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
service.Meta.CertificateIssuedAt = &now
|
||||||
|
|
||||||
|
if err = transaction.UpdateService(ctx, service); err != nil {
|
||||||
|
return fmt.Errorf("failed to update service certificate timestamp: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetStatus updates the status of the service (e.g., "active", "tunnel_not_created", etc.)
|
||||||
|
func (m *Manager) SetStatus(ctx context.Context, accountID, serviceID string, status service.Status) error {
|
||||||
|
return m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||||
|
service, err := transaction.GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, serviceID)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to get service: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
service.Meta.Status = string(status)
|
||||||
|
|
||||||
|
if err = transaction.UpdateService(ctx, service); err != nil {
|
||||||
|
return fmt.Errorf("failed to update service status: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) ReloadService(ctx context.Context, accountID, serviceID string) error {
|
||||||
|
s, err := m.store.GetServiceByID(ctx, store.LockingStrengthNone, accountID, serviceID)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to get service: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = m.replaceHostByLookup(ctx, accountID, s)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to replace host by lookup for service %s: %w", s.ID, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
m.proxyController.SendServiceUpdateToCluster(ctx, accountID, s.ToProtoMapping(service.Update, "", m.proxyController.GetOIDCValidationConfig()), s.ProxyCluster)
|
||||||
|
|
||||||
|
m.accountManager.UpdateAccountPeers(ctx, accountID)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) ReloadAllServicesForAccount(ctx context.Context, accountID string) error {
|
||||||
|
services, err := m.store.GetAccountServices(ctx, store.LockingStrengthNone, accountID)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to get services: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, s := range services {
|
||||||
|
err = m.replaceHostByLookup(ctx, accountID, s)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to replace host by lookup for service %s: %w", s.ID, err)
|
||||||
|
}
|
||||||
|
m.proxyController.SendServiceUpdateToCluster(ctx, accountID, s.ToProtoMapping(service.Update, "", m.proxyController.GetOIDCValidationConfig()), s.ProxyCluster)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) GetGlobalServices(ctx context.Context) ([]*service.Service, error) {
|
||||||
|
services, err := m.store.GetServices(ctx, store.LockingStrengthNone)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to get services: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, service := range services {
|
||||||
|
err = m.replaceHostByLookup(ctx, service.AccountID, service)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return services, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) GetServiceByID(ctx context.Context, accountID, serviceID string) (*service.Service, error) {
|
||||||
|
service, err := m.store.GetServiceByID(ctx, store.LockingStrengthNone, accountID, serviceID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to get service: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = m.replaceHostByLookup(ctx, accountID, service)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return service, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) GetAccountServices(ctx context.Context, accountID string) ([]*service.Service, error) {
|
||||||
|
services, err := m.store.GetAccountServices(ctx, store.LockingStrengthNone, accountID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to get services: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, service := range services {
|
||||||
|
err = m.replaceHostByLookup(ctx, accountID, service)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return services, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) GetServiceIDByTargetID(ctx context.Context, accountID string, resourceID string) (string, error) {
|
||||||
|
target, err := m.store.GetServiceTargetByTargetID(ctx, store.LockingStrengthNone, accountID, resourceID)
|
||||||
|
if err != nil {
|
||||||
|
if s, ok := status.FromError(err); ok && s.Type() == status.NotFound {
|
||||||
|
return "", nil
|
||||||
|
}
|
||||||
|
return "", fmt.Errorf("failed to get service target by resource ID: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if target == nil {
|
||||||
|
return "", nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return target.ServiceID, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// validateExposePermission checks whether the peer is allowed to use the expose feature.
|
||||||
|
// It verifies the account has peer expose enabled and that the peer belongs to an allowed group.
|
||||||
|
func (m *Manager) validateExposePermission(ctx context.Context, accountID, peerID string) error {
|
||||||
|
settings, err := m.store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
|
||||||
|
if err != nil {
|
||||||
|
log.WithContext(ctx).Errorf("failed to get account settings: %v", err)
|
||||||
|
return status.Errorf(status.Internal, "get account settings: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !settings.PeerExposeEnabled {
|
||||||
|
return status.Errorf(status.PermissionDenied, "peer expose is not enabled for this account")
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(settings.PeerExposeGroups) == 0 {
|
||||||
|
return status.Errorf(status.PermissionDenied, "no group is set for peer expose")
|
||||||
|
}
|
||||||
|
|
||||||
|
peerGroupIDs, err := m.store.GetPeerGroupIDs(ctx, store.LockingStrengthNone, accountID, peerID)
|
||||||
|
if err != nil {
|
||||||
|
log.WithContext(ctx).Errorf("failed to get peer group IDs: %v", err)
|
||||||
|
return status.Errorf(status.Internal, "get peer groups: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, pg := range peerGroupIDs {
|
||||||
|
if slices.Contains(settings.PeerExposeGroups, pg) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return status.Errorf(status.PermissionDenied, "peer is not in an allowed expose group")
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateServiceFromPeer creates a service initiated by a peer expose request.
// It validates the request, checks expose permissions, enforces the per-peer limit,
// creates the service, and tracks it for TTL-based reaping.
func (m *Manager) CreateServiceFromPeer(ctx context.Context, accountID, peerID string, req *service.ExposeServiceRequest) (*service.ExposeServiceResponse, error) {
	if err := req.Validate(); err != nil {
		return nil, status.Errorf(status.InvalidArgument, "validate expose request: %v", err)
	}

	// Account-level feature flag plus group-membership gate.
	if err := m.validateExposePermission(ctx, accountID, peerID); err != nil {
		return nil, err
	}

	serviceName, err := service.GenerateExposeName(req.NamePrefix)
	if err != nil {
		return nil, status.Errorf(status.InvalidArgument, "generate service name: %v", err)
	}

	svc := req.ToService(accountID, peerID, serviceName)
	svc.Source = service.SourceEphemeral

	// No explicit domain requested: pick a random cluster base domain.
	if svc.Domain == "" {
		domain, err := m.buildRandomDomain(svc.Name)
		if err != nil {
			return nil, fmt.Errorf("build random domain for service %s: %w", svc.Name, err)
		}
		svc.Domain = domain
	}

	// Distribution groups are supplied as names; resolve them to IDs before storing.
	if svc.Auth.BearerAuth != nil && svc.Auth.BearerAuth.Enabled {
		groupIDs, err := m.getGroupIDsFromNames(ctx, accountID, svc.Auth.BearerAuth.DistributionGroups)
		if err != nil {
			return nil, fmt.Errorf("get group ids for service %s: %w", svc.Name, err)
		}
		svc.Auth.BearerAuth.DistributionGroups = groupIDs
	}

	if err := m.initializeServiceForCreate(ctx, accountID, svc); err != nil {
		return nil, err
	}

	// Fetched so peer details can be attached to the activity event below.
	peer, err := m.store.GetPeerByID(ctx, store.LockingStrengthNone, accountID, peerID)
	if err != nil {
		return nil, err
	}

	svc.SourcePeer = peerID

	// Seed the renewal timestamp; the expose reaper removes services whose
	// LastRenewedAt falls too far behind (see exposeTTL).
	now := time.Now()
	svc.Meta.LastRenewedAt = &now

	// Transactionally enforces the per-peer limit, duplicate-session check,
	// and domain uniqueness.
	if err := m.persistNewEphemeralService(ctx, accountID, peerID, svc); err != nil {
		return nil, err
	}

	meta := addPeerInfoToEventMeta(svc.EventMeta(), peer)
	m.accountManager.StoreEvent(ctx, peerID, svc.ID, accountID, activity.PeerServiceExposed, meta)

	if err := m.replaceHostByLookup(ctx, accountID, svc); err != nil {
		return nil, fmt.Errorf("replace host by lookup for service %s: %w", svc.ID, err)
	}

	m.proxyController.SendServiceUpdateToCluster(ctx, accountID, svc.ToProtoMapping(service.Create, "", m.proxyController.GetOIDCValidationConfig()), svc.ProxyCluster)
	m.accountManager.UpdateAccountPeers(ctx, accountID)

	return &service.ExposeServiceResponse{
		ServiceName: svc.Name,
		ServiceURL:  "https://" + svc.Domain,
		Domain:      svc.Domain,
	}, nil
}
|
||||||
|
|
||||||
|
func (m *Manager) getGroupIDsFromNames(ctx context.Context, accountID string, groupNames []string) ([]string, error) {
|
||||||
|
if len(groupNames) == 0 {
|
||||||
|
return []string{}, fmt.Errorf("no group names provided")
|
||||||
|
}
|
||||||
|
groupIDs := make([]string, 0, len(groupNames))
|
||||||
|
for _, groupName := range groupNames {
|
||||||
|
g, err := m.accountManager.GetGroupByName(ctx, groupName, accountID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to get group by name %s: %w", groupName, err)
|
||||||
|
}
|
||||||
|
groupIDs = append(groupIDs, g.ID)
|
||||||
|
}
|
||||||
|
return groupIDs, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) buildRandomDomain(name string) (string, error) {
|
||||||
|
if m.clusterDeriver == nil {
|
||||||
|
return "", fmt.Errorf("unable to get random domain")
|
||||||
|
}
|
||||||
|
clusterDomains := m.clusterDeriver.GetClusterDomains()
|
||||||
|
if len(clusterDomains) == 0 {
|
||||||
|
return "", fmt.Errorf("no cluster domains found for service %s", name)
|
||||||
|
}
|
||||||
|
index := rand.IntN(len(clusterDomains))
|
||||||
|
domain := name + "." + clusterDomains[index]
|
||||||
|
return domain, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// RenewServiceFromPeer updates the DB timestamp for the peer's ephemeral service.
// It refreshes the last-renewed marker for the service identified by domain so
// the expiry sweep does not delete it while the owning peer keeps it alive.
func (m *Manager) RenewServiceFromPeer(ctx context.Context, accountID, peerID, domain string) error {
	return m.store.RenewEphemeralService(ctx, accountID, peerID, domain)
}
|
||||||
|
|
||||||
|
// StopServiceFromPeer stops a peer's active expose session by deleting the service from the DB.
|
||||||
|
func (m *Manager) StopServiceFromPeer(ctx context.Context, accountID, peerID, domain string) error {
|
||||||
|
if err := m.deleteServiceFromPeer(ctx, accountID, peerID, domain, false); err != nil {
|
||||||
|
log.WithContext(ctx).Errorf("failed to delete peer-exposed service for domain %s: %v", domain, err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// deleteServiceFromPeer deletes a peer-initiated service identified by domain.
|
||||||
|
// When expired is true, the activity is recorded as PeerServiceExposeExpired instead of PeerServiceUnexposed.
|
||||||
|
func (m *Manager) deleteServiceFromPeer(ctx context.Context, accountID, peerID, domain string, expired bool) error {
|
||||||
|
svc, err := m.lookupPeerService(ctx, accountID, peerID, domain)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
activityCode := activity.PeerServiceUnexposed
|
||||||
|
if expired {
|
||||||
|
activityCode = activity.PeerServiceExposeExpired
|
||||||
|
}
|
||||||
|
return m.deletePeerService(ctx, accountID, peerID, svc.ID, activityCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
// lookupPeerService finds a peer-initiated service by domain and validates ownership.
|
||||||
|
func (m *Manager) lookupPeerService(ctx context.Context, accountID, peerID, domain string) (*service.Service, error) {
|
||||||
|
svc, err := m.store.GetServiceByDomain(ctx, domain)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if svc.Source != service.SourceEphemeral {
|
||||||
|
return nil, status.Errorf(status.PermissionDenied, "cannot operate on API-created service via peer expose")
|
||||||
|
}
|
||||||
|
|
||||||
|
if svc.SourcePeer != peerID {
|
||||||
|
return nil, status.Errorf(status.PermissionDenied, "cannot operate on service exposed by another peer")
|
||||||
|
}
|
||||||
|
|
||||||
|
return svc, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// deletePeerService deletes an ephemeral service on behalf of its owning peer,
// records the given activity event, and propagates the removal to the proxy
// cluster and the account's peers.
//
// Ownership and source are re-validated inside the transaction under a row
// lock, so a concurrent caller cannot delete an API-created service or one
// exposed by a different peer.
func (m *Manager) deletePeerService(ctx context.Context, accountID, peerID, serviceID string, activityCode activity.Activity) error {
	var svc *service.Service
	err := m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
		var err error
		svc, err = transaction.GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, serviceID)
		if err != nil {
			return err
		}

		if svc.Source != service.SourceEphemeral {
			return status.Errorf(status.PermissionDenied, "cannot delete API-created service via peer expose")
		}

		if svc.SourcePeer != peerID {
			return status.Errorf(status.PermissionDenied, "cannot delete service exposed by another peer")
		}

		if err = transaction.DeleteService(ctx, accountID, serviceID); err != nil {
			return fmt.Errorf("delete service: %w", err)
		}

		return nil
	})
	if err != nil {
		return err
	}

	// Peer details only enrich the event metadata; a lookup failure is not fatal.
	peer, err := m.store.GetPeerByID(ctx, store.LockingStrengthNone, accountID, peerID)
	if err != nil {
		log.WithContext(ctx).Debugf("failed to get peer %s for event metadata: %v", peerID, err)
		peer = nil
	}

	meta := addPeerInfoToEventMeta(svc.EventMeta(), peer)

	m.accountManager.StoreEvent(ctx, peerID, serviceID, accountID, activityCode, meta)

	// Tell the proxy cluster the mapping is gone, then push updated network maps.
	m.proxyController.SendServiceUpdateToCluster(ctx, accountID, svc.ToProtoMapping(service.Delete, "", m.proxyController.GetOIDCValidationConfig()), svc.ProxyCluster)

	m.accountManager.UpdateAccountPeers(ctx, accountID)

	return nil
}
|
||||||
|
|
||||||
|
// deleteExpiredPeerService deletes an ephemeral service by ID after re-checking
// that it is still expired under a row lock. This prevents deleting a service
// that was renewed between the batch query and this delete, and ensures only one
// management instance processes the deletion
func (m *Manager) deleteExpiredPeerService(ctx context.Context, accountID, peerID, serviceID string) error {
	var svc *service.Service
	// deleted records whether this transaction actually removed the row; if the
	// service turned out to be renewed, no events or updates must be emitted.
	deleted := false
	err := m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
		var err error
		svc, err = transaction.GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, serviceID)
		if err != nil {
			return err
		}

		if svc.Source != service.SourceEphemeral || svc.SourcePeer != peerID {
			return status.Errorf(status.PermissionDenied, "service does not match expected ephemeral owner")
		}

		// Renewed since the batch query: still within TTL, leave it alone.
		if svc.Meta.LastRenewedAt != nil && time.Since(*svc.Meta.LastRenewedAt) <= exposeTTL {
			return nil
		}

		if err = transaction.DeleteService(ctx, accountID, serviceID); err != nil {
			return fmt.Errorf("delete service: %w", err)
		}
		deleted = true

		return nil
	})
	if err != nil {
		return err
	}

	if !deleted {
		return nil
	}

	// Peer details only enrich the event metadata; a lookup failure is not fatal.
	peer, err := m.store.GetPeerByID(ctx, store.LockingStrengthNone, accountID, peerID)
	if err != nil {
		log.WithContext(ctx).Debugf("failed to get peer %s for event metadata: %v", peerID, err)
		peer = nil
	}

	meta := addPeerInfoToEventMeta(svc.EventMeta(), peer)
	m.accountManager.StoreEvent(ctx, peerID, serviceID, accountID, activity.PeerServiceExposeExpired, meta)
	m.proxyController.SendServiceUpdateToCluster(ctx, accountID, svc.ToProtoMapping(service.Delete, "", m.proxyController.GetOIDCValidationConfig()), svc.ProxyCluster)
	m.accountManager.UpdateAccountPeers(ctx, accountID)

	return nil
}
|
||||||
|
|
||||||
|
func addPeerInfoToEventMeta(meta map[string]any, peer *nbpeer.Peer) map[string]any {
|
||||||
|
if peer == nil {
|
||||||
|
return meta
|
||||||
|
}
|
||||||
|
meta["peer_name"] = peer.Name
|
||||||
|
if peer.IP != nil {
|
||||||
|
meta["peer_ip"] = peer.IP.String()
|
||||||
|
}
|
||||||
|
return meta
|
||||||
|
}
|
||||||
File diff suppressed because it is too large
Load Diff
817
management/internals/modules/reverseproxy/service/service.go
Normal file
817
management/internals/modules/reverseproxy/service/service.go
Normal file
@@ -0,0 +1,817 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/rand"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"math/big"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"regexp"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/rs/xid"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"google.golang.org/protobuf/types/known/durationpb"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
|
||||||
|
"github.com/netbirdio/netbird/shared/hash/argon2id"
|
||||||
|
"github.com/netbirdio/netbird/util/crypt"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/proto"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Operation identifies the kind of service change propagated to the proxy cluster.
type Operation string

const (
	Create Operation = "create"
	Update Operation = "update"
	Delete Operation = "delete"
)

// Status describes the lifecycle state of a reverse-proxy service.
type Status string

const (
	StatusPending            Status = "pending"
	StatusActive             Status = "active"
	StatusTunnelNotCreated   Status = "tunnel_not_created"
	StatusCertificatePending Status = "certificate_pending"
	StatusCertificateFailed  Status = "certificate_failed"
	StatusError              Status = "error"

	// Discriminators for Target.TargetType.
	TargetTypePeer   = "peer"
	TargetTypeHost   = "host"
	TargetTypeDomain = "domain"
	TargetTypeSubnet = "subnet"

	// Service origin: created via the API (permanent) or by a peer expose session (ephemeral).
	SourcePermanent = "permanent"
	SourceEphemeral = "ephemeral"
)
|
||||||
|
|
||||||
|
// TargetOptions carries optional per-target proxying overrides; the zero
// value means "no overrides".
type TargetOptions struct {
	// SkipTLSVerify disables verification of the upstream TLS certificate.
	SkipTLSVerify bool `json:"skip_tls_verify"`
	// RequestTimeout bounds a proxied request; 0 means no per-target override.
	RequestTimeout time.Duration `json:"request_timeout,omitempty"`
	// PathRewrite selects how the request path is rewritten before forwarding.
	PathRewrite PathRewriteMode `json:"path_rewrite,omitempty"`
	// CustomHeaders are extra headers added to requests forwarded to this target.
	CustomHeaders map[string]string `gorm:"serializer:json" json:"custom_headers,omitempty"`
}

// Target is one upstream destination of a Service.
type Target struct {
	ID        uint    `gorm:"primaryKey" json:"-"`
	AccountID string  `gorm:"index:idx_target_account;not null" json:"-"`
	ServiceID string  `gorm:"index:idx_service_targets;not null" json:"-"`
	// Path optionally scopes this target to a URL path (see ToProtoMapping).
	Path *string `json:"path,omitempty"`
	Host string  `json:"host"` // the Host field is only used for subnet targets, otherwise ignored
	Port int     `gorm:"index:idx_target_port" json:"port"`
	// Protocol becomes the scheme of the upstream URL ("http"/"https").
	Protocol string `gorm:"index:idx_target_protocol" json:"protocol"`
	// TargetId identifies the entity named by TargetType (peer/host/domain/subnet).
	TargetId   string        `gorm:"index:idx_target_id" json:"target_id"`
	TargetType string        `gorm:"index:idx_target_type" json:"target_type"`
	Enabled    bool          `gorm:"index:idx_target_enabled" json:"enabled"`
	Options    TargetOptions `gorm:"embedded" json:"options"`
}
|
||||||
|
|
||||||
|
// PasswordAuthConfig gates a service behind a shared password.
type PasswordAuthConfig struct {
	Enabled bool `json:"enabled"`
	// Password holds the plaintext on input; HashSecrets replaces it with an argon2id hash.
	Password string `json:"password"`
}

// PINAuthConfig gates a service behind a numeric PIN.
type PINAuthConfig struct {
	Enabled bool `json:"enabled"`
	// Pin holds the plaintext on input; HashSecrets replaces it with an argon2id hash.
	Pin string `json:"pin"`
}

// BearerAuthConfig gates a service behind bearer tokens for members of the
// listed distribution groups.
type BearerAuthConfig struct {
	Enabled bool `json:"enabled"`
	// DistributionGroups holds group names on API input and group IDs after resolution.
	DistributionGroups []string `json:"distribution_groups,omitempty" gorm:"serializer:json"`
}

// AuthConfig aggregates the optional authentication mechanisms of a service.
// A nil field means that mechanism is not configured.
type AuthConfig struct {
	PasswordAuth *PasswordAuthConfig `json:"password_auth,omitempty" gorm:"serializer:json"`
	PinAuth      *PINAuthConfig      `json:"pin_auth,omitempty" gorm:"serializer:json"`
	BearerAuth   *BearerAuthConfig   `json:"bearer_auth,omitempty" gorm:"serializer:json"`
}
|
||||||
|
|
||||||
|
func (a *AuthConfig) HashSecrets() error {
|
||||||
|
if a.PasswordAuth != nil && a.PasswordAuth.Enabled && a.PasswordAuth.Password != "" {
|
||||||
|
hashedPassword, err := argon2id.Hash(a.PasswordAuth.Password)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("hash password: %w", err)
|
||||||
|
}
|
||||||
|
a.PasswordAuth.Password = hashedPassword
|
||||||
|
}
|
||||||
|
|
||||||
|
if a.PinAuth != nil && a.PinAuth.Enabled && a.PinAuth.Pin != "" {
|
||||||
|
hashedPin, err := argon2id.Hash(a.PinAuth.Pin)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("hash pin: %w", err)
|
||||||
|
}
|
||||||
|
a.PinAuth.Pin = hashedPin
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *AuthConfig) ClearSecrets() {
|
||||||
|
if a.PasswordAuth != nil {
|
||||||
|
a.PasswordAuth.Password = ""
|
||||||
|
}
|
||||||
|
if a.PinAuth != nil {
|
||||||
|
a.PinAuth.Pin = ""
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Meta holds lifecycle metadata persisted alongside a Service (embedded with
// the meta_ column prefix).
type Meta struct {
	CreatedAt time.Time
	// CertificateIssuedAt is set once a TLS certificate exists for the domain.
	CertificateIssuedAt *time.Time
	Status              string
	// LastRenewedAt is the keep-alive timestamp for ephemeral (peer-exposed) services.
	LastRenewedAt *time.Time
}

// Service is a reverse-proxy service record: a public Domain routed through a
// ProxyCluster to one or more Targets, optionally protected by AuthConfig.
type Service struct {
	ID           string `gorm:"primaryKey"`
	AccountID    string `gorm:"index"`
	Name         string
	Domain       string    `gorm:"type:varchar(255);uniqueIndex"`
	ProxyCluster string    `gorm:"index"`
	Targets      []*Target `gorm:"foreignKey:ServiceID;constraint:OnDelete:CASCADE"`
	Enabled      bool
	// PassHostHeader forwards the original Host header to the target (see
	// validateCustomHeaders, which rejects Host as a custom header for this reason).
	PassHostHeader   bool
	RewriteRedirects bool
	Auth             AuthConfig `gorm:"serializer:json"`
	Meta             Meta       `gorm:"embedded;embeddedPrefix:meta_"`
	// Session key pair used in the proxy Authentication payload; the public key
	// is sent to the cluster, the private key is encrypted at rest
	// (see EncryptSensitiveData).
	SessionPrivateKey string `gorm:"column:session_private_key"`
	SessionPublicKey  string `gorm:"column:session_public_key"`
	// Source distinguishes API-created (permanent) from peer-exposed (ephemeral) services.
	Source     string `gorm:"default:'permanent';index:idx_service_source_peer"`
	SourcePeer string `gorm:"index:idx_service_source_peer"`
}
|
||||||
|
|
||||||
|
func NewService(accountID, name, domain, proxyCluster string, targets []*Target, enabled bool) *Service {
|
||||||
|
for _, target := range targets {
|
||||||
|
target.AccountID = accountID
|
||||||
|
}
|
||||||
|
|
||||||
|
s := &Service{
|
||||||
|
AccountID: accountID,
|
||||||
|
Name: name,
|
||||||
|
Domain: domain,
|
||||||
|
ProxyCluster: proxyCluster,
|
||||||
|
Targets: targets,
|
||||||
|
Enabled: enabled,
|
||||||
|
}
|
||||||
|
s.InitNewRecord()
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
// InitNewRecord generates a new unique ID and resets metadata for a newly created
// Service record. This overwrites any existing ID and Meta fields and should
// only be called during initial creation, not for updates.
func (s *Service) InitNewRecord() {
	s.ID = xid.New().String()
	// Fresh services always start in the pending state.
	s.Meta = Meta{
		CreatedAt: time.Now(),
		Status:    string(StatusPending),
	}
}
|
||||||
|
|
||||||
|
// ToAPIResponse converts the service into its HTTP API representation.
// Stored secrets are cleared first (mutating s.Auth) so hashed passwords and
// PINs are never serialized to clients.
func (s *Service) ToAPIResponse() *api.Service {
	s.Auth.ClearSecrets()

	authConfig := api.ServiceAuthConfig{}

	if s.Auth.PasswordAuth != nil {
		authConfig.PasswordAuth = &api.PasswordAuthConfig{
			Enabled:  s.Auth.PasswordAuth.Enabled,
			Password: s.Auth.PasswordAuth.Password,
		}
	}

	if s.Auth.PinAuth != nil {
		authConfig.PinAuth = &api.PINAuthConfig{
			Enabled: s.Auth.PinAuth.Enabled,
			Pin:     s.Auth.PinAuth.Pin,
		}
	}

	if s.Auth.BearerAuth != nil {
		authConfig.BearerAuth = &api.BearerAuthConfig{
			Enabled:            s.Auth.BearerAuth.Enabled,
			DistributionGroups: &s.Auth.BearerAuth.DistributionGroups,
		}
	}

	// Convert internal targets to API targets
	apiTargets := make([]api.ServiceTarget, 0, len(s.Targets))
	for _, target := range s.Targets {
		st := api.ServiceTarget{
			Path:       target.Path,
			Host:       &target.Host,
			Port:       target.Port,
			Protocol:   api.ServiceTargetProtocol(target.Protocol),
			TargetId:   target.TargetId,
			TargetType: api.ServiceTargetTargetType(target.TargetType),
			Enabled:    target.Enabled,
		}
		// Options is nil when every option is zero-valued.
		st.Options = targetOptionsToAPI(target.Options)
		apiTargets = append(apiTargets, st)
	}

	meta := api.ServiceMeta{
		CreatedAt: s.Meta.CreatedAt,
		Status:    api.ServiceMetaStatus(s.Meta.Status),
	}

	if s.Meta.CertificateIssuedAt != nil {
		meta.CertificateIssuedAt = s.Meta.CertificateIssuedAt
	}

	resp := &api.Service{
		Id:               s.ID,
		Name:             s.Name,
		Domain:           s.Domain,
		Targets:          apiTargets,
		Enabled:          s.Enabled,
		PassHostHeader:   &s.PassHostHeader,
		RewriteRedirects: &s.RewriteRedirects,
		Auth:             authConfig,
		Meta:             meta,
	}

	// ProxyCluster is optional in the API; omit when unset.
	if s.ProxyCluster != "" {
		resp.ProxyCluster = &s.ProxyCluster
	}

	return resp
}
|
||||||
|
|
||||||
|
// ToProtoMapping builds the proxy-cluster mapping message for this service.
// Disabled targets are skipped; each enabled target becomes a PathMapping
// whose upstream URL is scheme://host[:port]/ (the port is omitted when it is
// the scheme default). authToken is forwarded verbatim.
// NOTE(review): oidcConfig is not referenced in this body — confirm whether it
// should shape the Authentication payload or be removed from the signature.
func (s *Service) ToProtoMapping(operation Operation, authToken string, oidcConfig proxy.OIDCValidationConfig) *proto.ProxyMapping {
	pathMappings := make([]*proto.PathMapping, 0, len(s.Targets))
	for _, target := range s.Targets {
		if !target.Enabled {
			continue
		}

		// TODO: Make path prefix stripping configurable per-target.
		// Currently the matching prefix is baked into the target URL path,
		// so the proxy strips-then-re-adds it (effectively a no-op).
		targetURL := url.URL{
			Scheme: target.Protocol,
			Host:   target.Host,
			Path:   "/", // TODO: support service path
		}
		if target.Port > 0 && !isDefaultPort(target.Protocol, target.Port) {
			targetURL.Host = net.JoinHostPort(targetURL.Host, strconv.Itoa(target.Port))
		}

		// The match path defaults to "/" when the target has no explicit path.
		path := "/"
		if target.Path != nil {
			path = *target.Path
		}

		pm := &proto.PathMapping{
			Path:   path,
			Target: targetURL.String(),
		}

		pm.Options = targetOptionsToProto(target.Options)
		pathMappings = append(pathMappings, pm)
	}

	// Sessions are signed against the service's public key and capped at 24h.
	auth := &proto.Authentication{
		SessionKey:           s.SessionPublicKey,
		MaxSessionAgeSeconds: int64((time.Hour * 24).Seconds()),
	}

	if s.Auth.PasswordAuth != nil && s.Auth.PasswordAuth.Enabled {
		auth.Password = true
	}

	if s.Auth.PinAuth != nil && s.Auth.PinAuth.Enabled {
		auth.Pin = true
	}

	if s.Auth.BearerAuth != nil && s.Auth.BearerAuth.Enabled {
		auth.Oidc = true
	}

	return &proto.ProxyMapping{
		Type:             operationToProtoType(operation),
		Id:               s.ID,
		Domain:           s.Domain,
		Path:             pathMappings,
		AuthToken:        authToken,
		Auth:             auth,
		AccountId:        s.AccountID,
		PassHostHeader:   s.PassHostHeader,
		RewriteRedirects: s.RewriteRedirects,
	}
}
|
||||||
|
|
||||||
|
// operationToProtoType maps an internal Operation to its protobuf update type.
func operationToProtoType(op Operation) proto.ProxyMappingUpdateType {
	switch op {
	case Create:
		return proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED
	case Update:
		return proto.ProxyMappingUpdateType_UPDATE_TYPE_MODIFIED
	case Delete:
		return proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED
	default:
		// NOTE(review): log.Fatalf terminates the whole process on an unknown
		// enum value, making the return below unreachable — consider a
		// non-fatal log or an error return instead.
		log.Fatalf("unknown operation type: %v", op)
		return proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED
	}
}
|
||||||
|
|
||||||
|
// isDefaultPort reports whether port is the standard default for the given
// scheme (443 for https, 80 for http). Any other scheme has no default.
func isDefaultPort(scheme string, port int) bool {
	switch scheme {
	case "https":
		return port == 443
	case "http":
		return port == 80
	default:
		return false
	}
}
|
||||||
|
|
||||||
|
// PathRewriteMode controls how the request path is rewritten before forwarding.
type PathRewriteMode string

const (
	// PathRewritePreserve is the only recognized non-default mode; it maps to
	// PATH_REWRITE_PRESERVE in the proxy protocol (see pathRewriteToProto).
	PathRewritePreserve PathRewriteMode = "preserve"
)
|
||||||
|
|
||||||
|
func pathRewriteToProto(mode PathRewriteMode) proto.PathRewriteMode {
|
||||||
|
switch mode {
|
||||||
|
case PathRewritePreserve:
|
||||||
|
return proto.PathRewriteMode_PATH_REWRITE_PRESERVE
|
||||||
|
default:
|
||||||
|
return proto.PathRewriteMode_PATH_REWRITE_DEFAULT
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func targetOptionsToAPI(opts TargetOptions) *api.ServiceTargetOptions {
|
||||||
|
if !opts.SkipTLSVerify && opts.RequestTimeout == 0 && opts.PathRewrite == "" && len(opts.CustomHeaders) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
apiOpts := &api.ServiceTargetOptions{}
|
||||||
|
if opts.SkipTLSVerify {
|
||||||
|
apiOpts.SkipTlsVerify = &opts.SkipTLSVerify
|
||||||
|
}
|
||||||
|
if opts.RequestTimeout != 0 {
|
||||||
|
s := opts.RequestTimeout.String()
|
||||||
|
apiOpts.RequestTimeout = &s
|
||||||
|
}
|
||||||
|
if opts.PathRewrite != "" {
|
||||||
|
pr := api.ServiceTargetOptionsPathRewrite(opts.PathRewrite)
|
||||||
|
apiOpts.PathRewrite = &pr
|
||||||
|
}
|
||||||
|
if len(opts.CustomHeaders) > 0 {
|
||||||
|
apiOpts.CustomHeaders = &opts.CustomHeaders
|
||||||
|
}
|
||||||
|
return apiOpts
|
||||||
|
}
|
||||||
|
|
||||||
|
func targetOptionsToProto(opts TargetOptions) *proto.PathTargetOptions {
|
||||||
|
if !opts.SkipTLSVerify && opts.PathRewrite == "" && opts.RequestTimeout == 0 && len(opts.CustomHeaders) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
popts := &proto.PathTargetOptions{
|
||||||
|
SkipTlsVerify: opts.SkipTLSVerify,
|
||||||
|
PathRewrite: pathRewriteToProto(opts.PathRewrite),
|
||||||
|
CustomHeaders: opts.CustomHeaders,
|
||||||
|
}
|
||||||
|
if opts.RequestTimeout != 0 {
|
||||||
|
popts.RequestTimeout = durationpb.New(opts.RequestTimeout)
|
||||||
|
}
|
||||||
|
return popts
|
||||||
|
}
|
||||||
|
|
||||||
|
// targetOptionsFromAPI converts API-supplied target options to internal form.
// idx labels parse errors with the target's position in the request. Only
// request_timeout can fail (duration parsing); all other fields copy directly.
func targetOptionsFromAPI(idx int, o *api.ServiceTargetOptions) (TargetOptions, error) {
	var opts TargetOptions
	if o.SkipTlsVerify != nil {
		opts.SkipTLSVerify = *o.SkipTlsVerify
	}
	if o.RequestTimeout != nil {
		// Accepts Go duration syntax, e.g. "30s" or "1m30s".
		d, err := time.ParseDuration(*o.RequestTimeout)
		if err != nil {
			return opts, fmt.Errorf("target %d: parse request_timeout %q: %w", idx, *o.RequestTimeout, err)
		}
		opts.RequestTimeout = d
	}
	if o.PathRewrite != nil {
		opts.PathRewrite = PathRewriteMode(*o.PathRewrite)
	}
	if o.CustomHeaders != nil {
		opts.CustomHeaders = *o.CustomHeaders
	}
	return opts, nil
}
|
||||||
|
|
||||||
|
// FromAPIRequest populates the service from an API create/update request:
// name, domain, account, targets (with per-target options), toggles, and the
// auth section. ID and Meta are left untouched — callers initialize those
// separately (see InitNewRecord). Returns an error only when a target's
// options fail to parse.
func (s *Service) FromAPIRequest(req *api.ServiceRequest, accountID string) error {
	s.Name = req.Name
	s.Domain = req.Domain
	s.AccountID = accountID

	targets := make([]*Target, 0, len(req.Targets))
	for i, apiTarget := range req.Targets {
		target := &Target{
			AccountID:  accountID,
			Path:       apiTarget.Path,
			Port:       apiTarget.Port,
			Protocol:   string(apiTarget.Protocol),
			TargetId:   apiTarget.TargetId,
			TargetType: string(apiTarget.TargetType),
			Enabled:    apiTarget.Enabled,
		}
		if apiTarget.Host != nil {
			target.Host = *apiTarget.Host
		}
		if apiTarget.Options != nil {
			opts, err := targetOptionsFromAPI(i, apiTarget.Options)
			if err != nil {
				return err
			}
			target.Options = opts
		}
		targets = append(targets, target)
	}
	s.Targets = targets

	s.Enabled = req.Enabled

	// Optional booleans: only override the current values when present.
	if req.PassHostHeader != nil {
		s.PassHostHeader = *req.PassHostHeader
	}

	if req.RewriteRedirects != nil {
		s.RewriteRedirects = *req.RewriteRedirects
	}

	if req.Auth.PasswordAuth != nil {
		s.Auth.PasswordAuth = &PasswordAuthConfig{
			Enabled:  req.Auth.PasswordAuth.Enabled,
			Password: req.Auth.PasswordAuth.Password,
		}
	}

	if req.Auth.PinAuth != nil {
		s.Auth.PinAuth = &PINAuthConfig{
			Enabled: req.Auth.PinAuth.Enabled,
			Pin:     req.Auth.PinAuth.Pin,
		}
	}

	if req.Auth.BearerAuth != nil {
		bearerAuth := &BearerAuthConfig{
			Enabled: req.Auth.BearerAuth.Enabled,
		}
		if req.Auth.BearerAuth.DistributionGroups != nil {
			bearerAuth.DistributionGroups = *req.Auth.BearerAuth.DistributionGroups
		}
		s.Auth.BearerAuth = bearerAuth
	}

	return nil
}
|
||||||
|
|
||||||
|
// Validate checks the service's invariants before persistence: a non-empty
// name of at most 255 characters, a non-empty domain, at least one target,
// and per-target type, id, and option validity. The first violation found is
// returned as the error.
func (s *Service) Validate() error {
	if s.Name == "" {
		return errors.New("service name is required")
	}
	if len(s.Name) > 255 {
		return errors.New("service name exceeds maximum length of 255 characters")
	}

	if s.Domain == "" {
		return errors.New("service domain is required")
	}

	if len(s.Targets) == 0 {
		return errors.New("at least one target is required")
	}

	for i, target := range s.Targets {
		switch target.TargetType {
		case TargetTypePeer, TargetTypeHost, TargetTypeDomain:
			// host field will be ignored
		case TargetTypeSubnet:
			// Subnet targets carry the destination address in Host.
			if target.Host == "" {
				return fmt.Errorf("target %d has empty host but target_type is %q", i, target.TargetType)
			}
		default:
			return fmt.Errorf("target %d has invalid target_type %q", i, target.TargetType)
		}
		if target.TargetId == "" {
			return fmt.Errorf("target %d has empty target_id", i)
		}
		if err := validateTargetOptions(i, &target.Options); err != nil {
			return err
		}
	}

	return nil
}
|
||||||
|
|
||||||
|
// Limits enforced on per-target options and custom headers.
const (
	maxRequestTimeout = 5 * time.Minute
	maxCustomHeaders  = 16
	maxHeaderKeyLen   = 128
	maxHeaderValueLen = 4096
)

// httpHeaderNameRe matches valid HTTP header field names per RFC 7230 token definition.
var httpHeaderNameRe = regexp.MustCompile(`^[!#$%&'*+\-.^_` + "`" + `|~0-9A-Za-z]+$`)

// hopByHopHeaders are headers that must not be set as custom headers
// because they are connection-level and stripped by the proxy.
// Keys are in http.CanonicalHeaderKey form (hence "Te", not "TE").
var hopByHopHeaders = map[string]struct{}{
	"Connection":          {},
	"Keep-Alive":          {},
	"Proxy-Authenticate":  {},
	"Proxy-Authorization": {},
	"Proxy-Connection":    {},
	"Te":                  {},
	"Trailer":             {},
	"Transfer-Encoding":   {},
	"Upgrade":             {},
}

// reservedHeaders are set authoritatively by the proxy or control HTTP framing
// and cannot be overridden. Keys are in http.CanonicalHeaderKey form.
var reservedHeaders = map[string]struct{}{
	"Content-Length":    {},
	"Content-Type":      {},
	"Cookie":            {},
	"Forwarded":         {},
	"X-Forwarded-For":   {},
	"X-Forwarded-Host":  {},
	"X-Forwarded-Port":  {},
	"X-Forwarded-Proto": {},
	"X-Real-Ip":         {},
}
|
||||||
|
|
||||||
|
// validateTargetOptions checks a single target's options: a known
// path_rewrite mode, a positive request_timeout no greater than
// maxRequestTimeout, and valid custom headers. idx identifies the target in
// error messages.
func validateTargetOptions(idx int, opts *TargetOptions) error {
	if opts.PathRewrite != "" && opts.PathRewrite != PathRewritePreserve {
		return fmt.Errorf("target %d: unknown path_rewrite mode %q", idx, opts.PathRewrite)
	}

	// Zero means "not set" and is always valid; only non-zero values are range-checked.
	if opts.RequestTimeout != 0 {
		if opts.RequestTimeout <= 0 {
			return fmt.Errorf("target %d: request_timeout must be positive", idx)
		}
		if opts.RequestTimeout > maxRequestTimeout {
			return fmt.Errorf("target %d: request_timeout exceeds maximum of %s", idx, maxRequestTimeout)
		}
	}

	if err := validateCustomHeaders(idx, opts.CustomHeaders); err != nil {
		return err
	}

	return nil
}
|
||||||
|
|
||||||
|
// validateCustomHeaders validates the custom-header map for target idx:
// bounded count, RFC 7230 token names, bounded key/value lengths, no CR/LF
// (header-injection guard), no case-insensitive duplicate keys, and no
// hop-by-hop, proxy-reserved, or Host headers. Map iteration order is random,
// so when several entries are invalid the reported one is nondeterministic.
func validateCustomHeaders(idx int, headers map[string]string) error {
	if len(headers) > maxCustomHeaders {
		return fmt.Errorf("target %d: custom_headers count %d exceeds maximum of %d", idx, len(headers), maxCustomHeaders)
	}
	// seen maps canonical header names to the original spelling, to detect
	// keys that differ only in case.
	seen := make(map[string]string, len(headers))
	for key, value := range headers {
		if !httpHeaderNameRe.MatchString(key) {
			return fmt.Errorf("target %d: custom header key %q is not a valid HTTP header name", idx, key)
		}
		if len(key) > maxHeaderKeyLen {
			return fmt.Errorf("target %d: custom header key %q exceeds maximum length of %d", idx, key, maxHeaderKeyLen)
		}
		if len(value) > maxHeaderValueLen {
			return fmt.Errorf("target %d: custom header %q value exceeds maximum length of %d", idx, key, maxHeaderValueLen)
		}
		if containsCRLF(key) || containsCRLF(value) {
			return fmt.Errorf("target %d: custom header %q contains invalid characters", idx, key)
		}
		canonical := http.CanonicalHeaderKey(key)
		if prev, ok := seen[canonical]; ok {
			return fmt.Errorf("target %d: custom header keys %q and %q collide (both canonicalize to %q)", idx, prev, key, canonical)
		}
		seen[canonical] = key
		if _, ok := hopByHopHeaders[canonical]; ok {
			return fmt.Errorf("target %d: custom header %q is a hop-by-hop header and cannot be set", idx, key)
		}
		if _, ok := reservedHeaders[canonical]; ok {
			return fmt.Errorf("target %d: custom header %q is managed by the proxy and cannot be overridden", idx, key)
		}
		if canonical == "Host" {
			return fmt.Errorf("target %d: use pass_host_header instead of setting Host as a custom header", idx)
		}
	}
	return nil
}
|
||||||
|
|
||||||
|
// containsCRLF reports whether s contains a carriage return or line feed,
// either of which would permit HTTP header injection.
func containsCRLF(s string) bool {
	return strings.IndexByte(s, '\r') >= 0 || strings.IndexByte(s, '\n') >= 0
}
|
||||||
|
|
||||||
|
// EventMeta returns the attributes attached to activity events for this service.
func (s *Service) EventMeta() map[string]any {
	return map[string]any{"name": s.Name, "domain": s.Domain, "proxy_cluster": s.ProxyCluster, "source": s.Source, "auth": s.isAuthEnabled()}
}

// isAuthEnabled reports whether any auth mechanism is configured.
// NOTE(review): this checks presence (non-nil) only, not the Enabled flags —
// a config with Enabled=false still counts; confirm this is intentional.
func (s *Service) isAuthEnabled() bool {
	return s.Auth.PasswordAuth != nil || s.Auth.PinAuth != nil || s.Auth.BearerAuth != nil
}
|
||||||
|
|
||||||
|
func (s *Service) Copy() *Service {
|
||||||
|
targets := make([]*Target, len(s.Targets))
|
||||||
|
for i, target := range s.Targets {
|
||||||
|
targetCopy := *target
|
||||||
|
if len(target.Options.CustomHeaders) > 0 {
|
||||||
|
targetCopy.Options.CustomHeaders = make(map[string]string, len(target.Options.CustomHeaders))
|
||||||
|
for k, v := range target.Options.CustomHeaders {
|
||||||
|
targetCopy.Options.CustomHeaders[k] = v
|
||||||
|
}
|
||||||
|
}
|
||||||
|
targets[i] = &targetCopy
|
||||||
|
}
|
||||||
|
|
||||||
|
return &Service{
|
||||||
|
ID: s.ID,
|
||||||
|
AccountID: s.AccountID,
|
||||||
|
Name: s.Name,
|
||||||
|
Domain: s.Domain,
|
||||||
|
ProxyCluster: s.ProxyCluster,
|
||||||
|
Targets: targets,
|
||||||
|
Enabled: s.Enabled,
|
||||||
|
PassHostHeader: s.PassHostHeader,
|
||||||
|
RewriteRedirects: s.RewriteRedirects,
|
||||||
|
Auth: s.Auth,
|
||||||
|
Meta: s.Meta,
|
||||||
|
SessionPrivateKey: s.SessionPrivateKey,
|
||||||
|
SessionPublicKey: s.SessionPublicKey,
|
||||||
|
Source: s.Source,
|
||||||
|
SourcePeer: s.SourcePeer,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Service) EncryptSensitiveData(enc *crypt.FieldEncrypt) error {
|
||||||
|
if enc == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if s.SessionPrivateKey != "" {
|
||||||
|
var err error
|
||||||
|
s.SessionPrivateKey, err = enc.Encrypt(s.SessionPrivateKey)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Service) DecryptSensitiveData(enc *crypt.FieldEncrypt) error {
|
||||||
|
if enc == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if s.SessionPrivateKey != "" {
|
||||||
|
var err error
|
||||||
|
s.SessionPrivateKey, err = enc.Decrypt(s.SessionPrivateKey)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// pinRegexp matches exactly six ASCII digits — the only accepted PIN format.
var pinRegexp = regexp.MustCompile(`^\d{6}$`)

// alphanumCharset is the alphabet for randomly generated service names:
// lowercase letters and digits, all valid inside a DNS label.
const alphanumCharset = "abcdefghijklmnopqrstuvwxyz0123456789"

// validNamePrefix constrains user-supplied name prefixes to DNS-label-safe
// strings: lowercase alphanumerics with optional interior hyphens, 1-32
// characters, never starting or ending with a hyphen.
var validNamePrefix = regexp.MustCompile(`^[a-z0-9]([a-z0-9-]{0,30}[a-z0-9])?$`)
|
||||||
|
|
||||||
|
// ExposeServiceRequest contains the parameters for creating a peer-initiated expose service.
type ExposeServiceRequest struct {
	// NamePrefix is an optional DNS-label-safe prefix for the generated service name.
	NamePrefix string
	// Port is the target port on the peer (must be 1-65535).
	Port int
	// Protocol is the upstream scheme; only "http" and "https" are accepted.
	Protocol string
	// Domain is an optional parent domain; the service name is prepended to it.
	Domain string
	// Pin, when set, must be exactly six digits and enables PIN auth.
	Pin string
	// Password, when set, enables password auth.
	Password string
	// UserGroups, when non-empty, enables bearer auth restricted to these groups.
	UserGroups []string
}
|
||||||
|
|
||||||
|
// Validate checks all fields of the expose request.
|
||||||
|
func (r *ExposeServiceRequest) Validate() error {
|
||||||
|
if r == nil {
|
||||||
|
return errors.New("request cannot be nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.Port < 1 || r.Port > 65535 {
|
||||||
|
return fmt.Errorf("port must be between 1 and 65535, got %d", r.Port)
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.Protocol != "http" && r.Protocol != "https" {
|
||||||
|
return fmt.Errorf("unsupported protocol %q: must be http or https", r.Protocol)
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.Pin != "" && !pinRegexp.MatchString(r.Pin) {
|
||||||
|
return errors.New("invalid pin: must be exactly 6 digits")
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, g := range r.UserGroups {
|
||||||
|
if g == "" {
|
||||||
|
return errors.New("user group name cannot be empty")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.NamePrefix != "" && !validNamePrefix.MatchString(r.NamePrefix) {
|
||||||
|
return fmt.Errorf("invalid name prefix %q: must be lowercase alphanumeric with optional hyphens, 1-32 characters", r.NamePrefix)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ToService builds a Service from the expose request.
|
||||||
|
func (r *ExposeServiceRequest) ToService(accountID, peerID, serviceName string) *Service {
|
||||||
|
service := &Service{
|
||||||
|
AccountID: accountID,
|
||||||
|
Name: serviceName,
|
||||||
|
Enabled: true,
|
||||||
|
Targets: []*Target{
|
||||||
|
{
|
||||||
|
AccountID: accountID,
|
||||||
|
Port: r.Port,
|
||||||
|
Protocol: r.Protocol,
|
||||||
|
TargetId: peerID,
|
||||||
|
TargetType: TargetTypePeer,
|
||||||
|
Enabled: true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.Domain != "" {
|
||||||
|
service.Domain = serviceName + "." + r.Domain
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.Pin != "" {
|
||||||
|
service.Auth.PinAuth = &PINAuthConfig{
|
||||||
|
Enabled: true,
|
||||||
|
Pin: r.Pin,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.Password != "" {
|
||||||
|
service.Auth.PasswordAuth = &PasswordAuthConfig{
|
||||||
|
Enabled: true,
|
||||||
|
Password: r.Password,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(r.UserGroups) > 0 {
|
||||||
|
service.Auth.BearerAuth = &BearerAuthConfig{
|
||||||
|
Enabled: true,
|
||||||
|
DistributionGroups: r.UserGroups,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return service
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExposeServiceResponse contains the result of a successful peer expose creation.
type ExposeServiceResponse struct {
	// ServiceName is the (possibly generated) name of the created service.
	ServiceName string
	// ServiceURL is the URL under which the exposed service is reachable.
	// NOTE(review): populated by the caller — confirm scheme/host composition there.
	ServiceURL string
	// Domain is the full public domain assigned to the service.
	Domain string
}
|
||||||
|
|
||||||
|
// GenerateExposeName generates a random service name for peer-exposed services.
|
||||||
|
// The prefix, if provided, must be a valid DNS label component (lowercase alphanumeric and hyphens).
|
||||||
|
func GenerateExposeName(prefix string) (string, error) {
|
||||||
|
if prefix != "" && !validNamePrefix.MatchString(prefix) {
|
||||||
|
return "", fmt.Errorf("invalid name prefix %q: must be lowercase alphanumeric with optional hyphens, 1-32 characters", prefix)
|
||||||
|
}
|
||||||
|
|
||||||
|
suffixLen := 12
|
||||||
|
if prefix != "" {
|
||||||
|
suffixLen = 4
|
||||||
|
}
|
||||||
|
|
||||||
|
suffix, err := randomAlphanumeric(suffixLen)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("generate random name: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if prefix == "" {
|
||||||
|
return suffix, nil
|
||||||
|
}
|
||||||
|
return prefix + "-" + suffix, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func randomAlphanumeric(n int) (string, error) {
|
||||||
|
result := make([]byte, n)
|
||||||
|
charsetLen := big.NewInt(int64(len(alphanumCharset)))
|
||||||
|
for i := range result {
|
||||||
|
idx, err := rand.Int(rand.Reader, charsetLen)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
result[i] = alphanumCharset[idx.Int64()]
|
||||||
|
}
|
||||||
|
return string(result), nil
|
||||||
|
}
|
||||||
@@ -0,0 +1,732 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
|
||||||
|
"github.com/netbirdio/netbird/shared/hash/argon2id"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/proto"
|
||||||
|
)
|
||||||
|
|
||||||
|
func validProxy() *Service {
|
||||||
|
return &Service{
|
||||||
|
Name: "test",
|
||||||
|
Domain: "example.com",
|
||||||
|
Targets: []*Target{
|
||||||
|
{TargetId: "peer-1", TargetType: TargetTypePeer, Host: "10.0.0.1", Port: 80, Protocol: "http", Enabled: true},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestValidate_Valid checks that the baseline fixture passes validation.
func TestValidate_Valid(t *testing.T) {
	require.NoError(t, validProxy().Validate())
}
|
||||||
|
|
||||||
|
// TestValidate_EmptyName verifies that a missing service name is rejected.
func TestValidate_EmptyName(t *testing.T) {
	rp := validProxy()
	rp.Name = ""
	assert.ErrorContains(t, rp.Validate(), "name is required")
}
|
||||||
|
|
||||||
|
// TestValidate_EmptyDomain verifies that a missing domain is rejected.
func TestValidate_EmptyDomain(t *testing.T) {
	rp := validProxy()
	rp.Domain = ""
	assert.ErrorContains(t, rp.Validate(), "domain is required")
}
|
||||||
|
|
||||||
|
// TestValidate_NoTargets verifies that a service without targets is rejected.
func TestValidate_NoTargets(t *testing.T) {
	rp := validProxy()
	rp.Targets = nil
	assert.ErrorContains(t, rp.Validate(), "at least one target")
}
|
||||||
|
|
||||||
|
// TestValidate_EmptyTargetId verifies that a target with no ID is rejected.
func TestValidate_EmptyTargetId(t *testing.T) {
	rp := validProxy()
	rp.Targets[0].TargetId = ""
	assert.ErrorContains(t, rp.Validate(), "empty target_id")
}
|
||||||
|
|
||||||
|
// TestValidate_InvalidTargetType verifies that an unknown target type is rejected.
func TestValidate_InvalidTargetType(t *testing.T) {
	rp := validProxy()
	rp.Targets[0].TargetType = "invalid"
	assert.ErrorContains(t, rp.Validate(), "invalid target_type")
}
|
||||||
|
|
||||||
|
// TestValidate_ResourceTarget verifies that a second, host-type target with
// its own address and protocol validates alongside the peer target.
func TestValidate_ResourceTarget(t *testing.T) {
	rp := validProxy()
	rp.Targets = append(rp.Targets, &Target{
		TargetId:   "resource-1",
		TargetType: TargetTypeHost,
		Host:       "example.org",
		Port:       443,
		Protocol:   "https",
		Enabled:    true,
	})
	require.NoError(t, rp.Validate())
}
|
||||||
|
|
||||||
|
// TestValidate_MultipleTargetsOneInvalid verifies that validation reports
// the index of the offending target together with the specific failure.
func TestValidate_MultipleTargetsOneInvalid(t *testing.T) {
	rp := validProxy()
	rp.Targets = append(rp.Targets, &Target{
		TargetId:   "",
		TargetType: TargetTypePeer,
		Host:       "10.0.0.2",
		Port:       80,
		Protocol:   "http",
		Enabled:    true,
	})
	err := rp.Validate()
	require.Error(t, err)
	assert.Contains(t, err.Error(), "target 1")
	assert.Contains(t, err.Error(), "empty target_id")
}
|
||||||
|
|
||||||
|
func TestValidateTargetOptions_PathRewrite(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
mode PathRewriteMode
|
||||||
|
wantErr string
|
||||||
|
}{
|
||||||
|
{"empty is default", "", ""},
|
||||||
|
{"preserve is valid", PathRewritePreserve, ""},
|
||||||
|
{"unknown rejected", "regex", "unknown path_rewrite mode"},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
rp := validProxy()
|
||||||
|
rp.Targets[0].Options.PathRewrite = tt.mode
|
||||||
|
err := rp.Validate()
|
||||||
|
if tt.wantErr == "" {
|
||||||
|
assert.NoError(t, err)
|
||||||
|
} else {
|
||||||
|
assert.ErrorContains(t, err, tt.wantErr)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateTargetOptions_RequestTimeout(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
timeout time.Duration
|
||||||
|
wantErr string
|
||||||
|
}{
|
||||||
|
{"valid 30s", 30 * time.Second, ""},
|
||||||
|
{"valid 2m", 2 * time.Minute, ""},
|
||||||
|
{"zero is fine", 0, ""},
|
||||||
|
{"negative", -1 * time.Second, "must be positive"},
|
||||||
|
{"exceeds max", 10 * time.Minute, "exceeds maximum"},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
rp := validProxy()
|
||||||
|
rp.Targets[0].Options.RequestTimeout = tt.timeout
|
||||||
|
err := rp.Validate()
|
||||||
|
if tt.wantErr == "" {
|
||||||
|
assert.NoError(t, err)
|
||||||
|
} else {
|
||||||
|
assert.ErrorContains(t, err, tt.wantErr)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateTargetOptions_CustomHeaders(t *testing.T) {
|
||||||
|
t.Run("valid headers", func(t *testing.T) {
|
||||||
|
rp := validProxy()
|
||||||
|
rp.Targets[0].Options.CustomHeaders = map[string]string{
|
||||||
|
"X-Custom": "value",
|
||||||
|
"X-Trace": "abc123",
|
||||||
|
}
|
||||||
|
assert.NoError(t, rp.Validate())
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("CRLF in key", func(t *testing.T) {
|
||||||
|
rp := validProxy()
|
||||||
|
rp.Targets[0].Options.CustomHeaders = map[string]string{"X-Bad\r\nKey": "value"}
|
||||||
|
assert.ErrorContains(t, rp.Validate(), "not a valid HTTP header name")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("CRLF in value", func(t *testing.T) {
|
||||||
|
rp := validProxy()
|
||||||
|
rp.Targets[0].Options.CustomHeaders = map[string]string{"X-Good": "bad\nvalue"}
|
||||||
|
assert.ErrorContains(t, rp.Validate(), "invalid characters")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("hop-by-hop header rejected", func(t *testing.T) {
|
||||||
|
for _, h := range []string{"Connection", "Transfer-Encoding", "Keep-Alive", "Upgrade", "Proxy-Connection"} {
|
||||||
|
rp := validProxy()
|
||||||
|
rp.Targets[0].Options.CustomHeaders = map[string]string{h: "value"}
|
||||||
|
assert.ErrorContains(t, rp.Validate(), "hop-by-hop", "header %q should be rejected", h)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("reserved header rejected", func(t *testing.T) {
|
||||||
|
for _, h := range []string{"X-Forwarded-For", "X-Real-IP", "X-Forwarded-Proto", "X-Forwarded-Host", "X-Forwarded-Port", "Cookie", "Forwarded", "Content-Length", "Content-Type"} {
|
||||||
|
rp := validProxy()
|
||||||
|
rp.Targets[0].Options.CustomHeaders = map[string]string{h: "value"}
|
||||||
|
assert.ErrorContains(t, rp.Validate(), "managed by the proxy", "header %q should be rejected", h)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Host header rejected", func(t *testing.T) {
|
||||||
|
rp := validProxy()
|
||||||
|
rp.Targets[0].Options.CustomHeaders = map[string]string{"Host": "evil.com"}
|
||||||
|
assert.ErrorContains(t, rp.Validate(), "pass_host_header")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("too many headers", func(t *testing.T) {
|
||||||
|
rp := validProxy()
|
||||||
|
headers := make(map[string]string, 17)
|
||||||
|
for i := range 17 {
|
||||||
|
headers[fmt.Sprintf("X-H%d", i)] = "v"
|
||||||
|
}
|
||||||
|
rp.Targets[0].Options.CustomHeaders = headers
|
||||||
|
assert.ErrorContains(t, rp.Validate(), "exceeds maximum of 16")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("key too long", func(t *testing.T) {
|
||||||
|
rp := validProxy()
|
||||||
|
rp.Targets[0].Options.CustomHeaders = map[string]string{strings.Repeat("X", 129): "v"}
|
||||||
|
assert.ErrorContains(t, rp.Validate(), "key")
|
||||||
|
assert.ErrorContains(t, rp.Validate(), "exceeds maximum length")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("value too long", func(t *testing.T) {
|
||||||
|
rp := validProxy()
|
||||||
|
rp.Targets[0].Options.CustomHeaders = map[string]string{"X-Ok": strings.Repeat("v", 4097)}
|
||||||
|
assert.ErrorContains(t, rp.Validate(), "value exceeds maximum length")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("duplicate canonical keys rejected", func(t *testing.T) {
|
||||||
|
rp := validProxy()
|
||||||
|
rp.Targets[0].Options.CustomHeaders = map[string]string{
|
||||||
|
"x-custom": "a",
|
||||||
|
"X-Custom": "b",
|
||||||
|
}
|
||||||
|
assert.ErrorContains(t, rp.Validate(), "collide")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestToProtoMapping_TargetOptions(t *testing.T) {
|
||||||
|
rp := &Service{
|
||||||
|
ID: "svc-1",
|
||||||
|
AccountID: "acc-1",
|
||||||
|
Domain: "example.com",
|
||||||
|
Targets: []*Target{
|
||||||
|
{
|
||||||
|
TargetId: "peer-1",
|
||||||
|
TargetType: TargetTypePeer,
|
||||||
|
Host: "10.0.0.1",
|
||||||
|
Port: 8080,
|
||||||
|
Protocol: "http",
|
||||||
|
Enabled: true,
|
||||||
|
Options: TargetOptions{
|
||||||
|
SkipTLSVerify: true,
|
||||||
|
RequestTimeout: 30 * time.Second,
|
||||||
|
PathRewrite: PathRewritePreserve,
|
||||||
|
CustomHeaders: map[string]string{"X-Custom": "val"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
pm := rp.ToProtoMapping(Create, "token", proxy.OIDCValidationConfig{})
|
||||||
|
require.Len(t, pm.Path, 1)
|
||||||
|
|
||||||
|
opts := pm.Path[0].Options
|
||||||
|
require.NotNil(t, opts, "options should be populated")
|
||||||
|
assert.True(t, opts.SkipTlsVerify)
|
||||||
|
assert.Equal(t, proto.PathRewriteMode_PATH_REWRITE_PRESERVE, opts.PathRewrite)
|
||||||
|
assert.Equal(t, map[string]string{"X-Custom": "val"}, opts.CustomHeaders)
|
||||||
|
require.NotNil(t, opts.RequestTimeout)
|
||||||
|
assert.Equal(t, int64(30), opts.RequestTimeout.Seconds)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestToProtoMapping_NoOptionsWhenDefault(t *testing.T) {
|
||||||
|
rp := &Service{
|
||||||
|
ID: "svc-1",
|
||||||
|
AccountID: "acc-1",
|
||||||
|
Domain: "example.com",
|
||||||
|
Targets: []*Target{
|
||||||
|
{
|
||||||
|
TargetId: "peer-1",
|
||||||
|
TargetType: TargetTypePeer,
|
||||||
|
Host: "10.0.0.1",
|
||||||
|
Port: 8080,
|
||||||
|
Protocol: "http",
|
||||||
|
Enabled: true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
pm := rp.ToProtoMapping(Create, "token", proxy.OIDCValidationConfig{})
|
||||||
|
require.Len(t, pm.Path, 1)
|
||||||
|
assert.Nil(t, pm.Path[0].Options, "options should be nil when all defaults")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestIsDefaultPort table-tests isDefaultPort: only http/80 and https/443
// count as default; everything else (including port 0) does not.
func TestIsDefaultPort(t *testing.T) {
	tests := []struct {
		scheme string
		port   int
		want   bool
	}{
		{"http", 80, true},
		{"https", 443, true},
		{"http", 443, false},
		{"https", 80, false},
		{"http", 8080, false},
		{"https", 8443, false},
		{"http", 0, false},
		{"https", 0, false},
	}
	for _, tt := range tests {
		t.Run(fmt.Sprintf("%s/%d", tt.scheme, tt.port), func(t *testing.T) {
			assert.Equal(t, tt.want, isDefaultPort(tt.scheme, tt.port))
		})
	}
}
|
||||||
|
|
||||||
|
func TestToProtoMapping_PortInTargetURL(t *testing.T) {
|
||||||
|
oidcConfig := proxy.OIDCValidationConfig{}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
protocol string
|
||||||
|
host string
|
||||||
|
port int
|
||||||
|
wantTarget string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "http with default port 80 omits port",
|
||||||
|
protocol: "http",
|
||||||
|
host: "10.0.0.1",
|
||||||
|
port: 80,
|
||||||
|
wantTarget: "http://10.0.0.1/",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "https with default port 443 omits port",
|
||||||
|
protocol: "https",
|
||||||
|
host: "10.0.0.1",
|
||||||
|
port: 443,
|
||||||
|
wantTarget: "https://10.0.0.1/",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "port 0 omits port",
|
||||||
|
protocol: "http",
|
||||||
|
host: "10.0.0.1",
|
||||||
|
port: 0,
|
||||||
|
wantTarget: "http://10.0.0.1/",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "non-default port is included",
|
||||||
|
protocol: "http",
|
||||||
|
host: "10.0.0.1",
|
||||||
|
port: 8080,
|
||||||
|
wantTarget: "http://10.0.0.1:8080/",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "https with non-default port is included",
|
||||||
|
protocol: "https",
|
||||||
|
host: "10.0.0.1",
|
||||||
|
port: 8443,
|
||||||
|
wantTarget: "https://10.0.0.1:8443/",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "http port 443 is included",
|
||||||
|
protocol: "http",
|
||||||
|
host: "10.0.0.1",
|
||||||
|
port: 443,
|
||||||
|
wantTarget: "http://10.0.0.1:443/",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "https port 80 is included",
|
||||||
|
protocol: "https",
|
||||||
|
host: "10.0.0.1",
|
||||||
|
port: 80,
|
||||||
|
wantTarget: "https://10.0.0.1:80/",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
rp := &Service{
|
||||||
|
ID: "test-id",
|
||||||
|
AccountID: "acc-1",
|
||||||
|
Domain: "example.com",
|
||||||
|
Targets: []*Target{
|
||||||
|
{
|
||||||
|
TargetId: "peer-1",
|
||||||
|
TargetType: TargetTypePeer,
|
||||||
|
Host: tt.host,
|
||||||
|
Port: tt.port,
|
||||||
|
Protocol: tt.protocol,
|
||||||
|
Enabled: true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
pm := rp.ToProtoMapping(Create, "token", oidcConfig)
|
||||||
|
require.Len(t, pm.Path, 1, "should have one path mapping")
|
||||||
|
assert.Equal(t, tt.wantTarget, pm.Path[0].Target)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestToProtoMapping_DisabledTargetSkipped(t *testing.T) {
|
||||||
|
rp := &Service{
|
||||||
|
ID: "test-id",
|
||||||
|
AccountID: "acc-1",
|
||||||
|
Domain: "example.com",
|
||||||
|
Targets: []*Target{
|
||||||
|
{TargetId: "peer-1", TargetType: TargetTypePeer, Host: "10.0.0.1", Port: 8080, Protocol: "http", Enabled: false},
|
||||||
|
{TargetId: "peer-2", TargetType: TargetTypePeer, Host: "10.0.0.2", Port: 9090, Protocol: "http", Enabled: true},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
pm := rp.ToProtoMapping(Create, "token", proxy.OIDCValidationConfig{})
|
||||||
|
require.Len(t, pm.Path, 1)
|
||||||
|
assert.Equal(t, "http://10.0.0.2:9090/", pm.Path[0].Target)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestToProtoMapping_OperationTypes(t *testing.T) {
|
||||||
|
rp := validProxy()
|
||||||
|
tests := []struct {
|
||||||
|
op Operation
|
||||||
|
want proto.ProxyMappingUpdateType
|
||||||
|
}{
|
||||||
|
{Create, proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED},
|
||||||
|
{Update, proto.ProxyMappingUpdateType_UPDATE_TYPE_MODIFIED},
|
||||||
|
{Delete, proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(string(tt.op), func(t *testing.T) {
|
||||||
|
pm := rp.ToProtoMapping(tt.op, "", proxy.OIDCValidationConfig{})
|
||||||
|
assert.Equal(t, tt.want, pm.Type)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuthConfig_HashSecrets(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
config *AuthConfig
|
||||||
|
wantErr bool
|
||||||
|
validate func(*testing.T, *AuthConfig)
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "hash password successfully",
|
||||||
|
config: &AuthConfig{
|
||||||
|
PasswordAuth: &PasswordAuthConfig{
|
||||||
|
Enabled: true,
|
||||||
|
Password: "testPassword123",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
wantErr: false,
|
||||||
|
validate: func(t *testing.T, config *AuthConfig) {
|
||||||
|
if !strings.HasPrefix(config.PasswordAuth.Password, "$argon2id$") {
|
||||||
|
t.Errorf("Password not hashed with argon2id, got: %s", config.PasswordAuth.Password)
|
||||||
|
}
|
||||||
|
// Verify the hash can be verified
|
||||||
|
if err := argon2id.Verify("testPassword123", config.PasswordAuth.Password); err != nil {
|
||||||
|
t.Errorf("Hash verification failed: %v", err)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "hash PIN successfully",
|
||||||
|
config: &AuthConfig{
|
||||||
|
PinAuth: &PINAuthConfig{
|
||||||
|
Enabled: true,
|
||||||
|
Pin: "123456",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
wantErr: false,
|
||||||
|
validate: func(t *testing.T, config *AuthConfig) {
|
||||||
|
if !strings.HasPrefix(config.PinAuth.Pin, "$argon2id$") {
|
||||||
|
t.Errorf("PIN not hashed with argon2id, got: %s", config.PinAuth.Pin)
|
||||||
|
}
|
||||||
|
// Verify the hash can be verified
|
||||||
|
if err := argon2id.Verify("123456", config.PinAuth.Pin); err != nil {
|
||||||
|
t.Errorf("Hash verification failed: %v", err)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "hash both password and PIN",
|
||||||
|
config: &AuthConfig{
|
||||||
|
PasswordAuth: &PasswordAuthConfig{
|
||||||
|
Enabled: true,
|
||||||
|
Password: "password",
|
||||||
|
},
|
||||||
|
PinAuth: &PINAuthConfig{
|
||||||
|
Enabled: true,
|
||||||
|
Pin: "9999",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
wantErr: false,
|
||||||
|
validate: func(t *testing.T, config *AuthConfig) {
|
||||||
|
if !strings.HasPrefix(config.PasswordAuth.Password, "$argon2id$") {
|
||||||
|
t.Errorf("Password not hashed with argon2id")
|
||||||
|
}
|
||||||
|
if !strings.HasPrefix(config.PinAuth.Pin, "$argon2id$") {
|
||||||
|
t.Errorf("PIN not hashed with argon2id")
|
||||||
|
}
|
||||||
|
if err := argon2id.Verify("password", config.PasswordAuth.Password); err != nil {
|
||||||
|
t.Errorf("Password hash verification failed: %v", err)
|
||||||
|
}
|
||||||
|
if err := argon2id.Verify("9999", config.PinAuth.Pin); err != nil {
|
||||||
|
t.Errorf("PIN hash verification failed: %v", err)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "skip disabled password auth",
|
||||||
|
config: &AuthConfig{
|
||||||
|
PasswordAuth: &PasswordAuthConfig{
|
||||||
|
Enabled: false,
|
||||||
|
Password: "password",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
wantErr: false,
|
||||||
|
validate: func(t *testing.T, config *AuthConfig) {
|
||||||
|
if config.PasswordAuth.Password != "password" {
|
||||||
|
t.Errorf("Disabled password auth should not be hashed")
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "skip empty password",
|
||||||
|
config: &AuthConfig{
|
||||||
|
PasswordAuth: &PasswordAuthConfig{
|
||||||
|
Enabled: true,
|
||||||
|
Password: "",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
wantErr: false,
|
||||||
|
validate: func(t *testing.T, config *AuthConfig) {
|
||||||
|
if config.PasswordAuth.Password != "" {
|
||||||
|
t.Errorf("Empty password should remain empty")
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "skip nil password auth",
|
||||||
|
config: &AuthConfig{
|
||||||
|
PasswordAuth: nil,
|
||||||
|
PinAuth: &PINAuthConfig{
|
||||||
|
Enabled: true,
|
||||||
|
Pin: "1234",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
wantErr: false,
|
||||||
|
validate: func(t *testing.T, config *AuthConfig) {
|
||||||
|
if config.PasswordAuth != nil {
|
||||||
|
t.Errorf("PasswordAuth should remain nil")
|
||||||
|
}
|
||||||
|
if !strings.HasPrefix(config.PinAuth.Pin, "$argon2id$") {
|
||||||
|
t.Errorf("PIN should still be hashed")
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
err := tt.config.HashSecrets()
|
||||||
|
if (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("HashSecrets() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if tt.validate != nil {
|
||||||
|
tt.validate(t, tt.config)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuthConfig_HashSecrets_VerifyIncorrectSecret(t *testing.T) {
|
||||||
|
config := &AuthConfig{
|
||||||
|
PasswordAuth: &PasswordAuthConfig{
|
||||||
|
Enabled: true,
|
||||||
|
Password: "correctPassword",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := config.HashSecrets(); err != nil {
|
||||||
|
t.Fatalf("HashSecrets() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify with wrong password should fail
|
||||||
|
err := argon2id.Verify("wrongPassword", config.PasswordAuth.Password)
|
||||||
|
if !errors.Is(err, argon2id.ErrMismatchedHashAndPassword) {
|
||||||
|
t.Errorf("Expected ErrMismatchedHashAndPassword, got %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuthConfig_ClearSecrets(t *testing.T) {
|
||||||
|
config := &AuthConfig{
|
||||||
|
PasswordAuth: &PasswordAuthConfig{
|
||||||
|
Enabled: true,
|
||||||
|
Password: "hashedPassword",
|
||||||
|
},
|
||||||
|
PinAuth: &PINAuthConfig{
|
||||||
|
Enabled: true,
|
||||||
|
Pin: "hashedPin",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
config.ClearSecrets()
|
||||||
|
|
||||||
|
if config.PasswordAuth.Password != "" {
|
||||||
|
t.Errorf("Password not cleared, got: %s", config.PasswordAuth.Password)
|
||||||
|
}
|
||||||
|
if config.PinAuth.Pin != "" {
|
||||||
|
t.Errorf("PIN not cleared, got: %s", config.PinAuth.Pin)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGenerateExposeName(t *testing.T) {
|
||||||
|
t.Run("no prefix generates 12-char name", func(t *testing.T) {
|
||||||
|
name, err := GenerateExposeName("")
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Len(t, name, 12)
|
||||||
|
assert.Regexp(t, `^[a-z0-9]+$`, name)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("with prefix generates prefix-XXXX", func(t *testing.T) {
|
||||||
|
name, err := GenerateExposeName("myapp")
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.True(t, strings.HasPrefix(name, "myapp-"), "name should start with prefix")
|
||||||
|
suffix := strings.TrimPrefix(name, "myapp-")
|
||||||
|
assert.Len(t, suffix, 4, "suffix should be 4 chars")
|
||||||
|
assert.Regexp(t, `^[a-z0-9]+$`, suffix)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("unique names", func(t *testing.T) {
|
||||||
|
names := make(map[string]bool)
|
||||||
|
for i := 0; i < 50; i++ {
|
||||||
|
name, err := GenerateExposeName("")
|
||||||
|
require.NoError(t, err)
|
||||||
|
names[name] = true
|
||||||
|
}
|
||||||
|
assert.Greater(t, len(names), 45, "should generate mostly unique names")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("valid prefixes", func(t *testing.T) {
|
||||||
|
validPrefixes := []string{"a", "ab", "a1", "my-app", "web-server-01", "a-b"}
|
||||||
|
for _, prefix := range validPrefixes {
|
||||||
|
name, err := GenerateExposeName(prefix)
|
||||||
|
assert.NoError(t, err, "prefix %q should be valid", prefix)
|
||||||
|
assert.True(t, strings.HasPrefix(name, prefix+"-"), "name should start with %q-", prefix)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("invalid prefixes", func(t *testing.T) {
|
||||||
|
invalidPrefixes := []string{
|
||||||
|
"-starts-with-dash",
|
||||||
|
"ends-with-dash-",
|
||||||
|
"has.dots",
|
||||||
|
"HAS-UPPER",
|
||||||
|
"has spaces",
|
||||||
|
"has/slash",
|
||||||
|
"a--",
|
||||||
|
}
|
||||||
|
for _, prefix := range invalidPrefixes {
|
||||||
|
_, err := GenerateExposeName(prefix)
|
||||||
|
assert.Error(t, err, "prefix %q should be invalid", prefix)
|
||||||
|
assert.Contains(t, err.Error(), "invalid name prefix")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExposeServiceRequest_ToService(t *testing.T) {
|
||||||
|
t.Run("basic HTTP service", func(t *testing.T) {
|
||||||
|
req := &ExposeServiceRequest{
|
||||||
|
Port: 8080,
|
||||||
|
Protocol: "http",
|
||||||
|
}
|
||||||
|
|
||||||
|
service := req.ToService("account-1", "peer-1", "mysvc")
|
||||||
|
|
||||||
|
assert.Equal(t, "account-1", service.AccountID)
|
||||||
|
assert.Equal(t, "mysvc", service.Name)
|
||||||
|
assert.True(t, service.Enabled)
|
||||||
|
assert.Empty(t, service.Domain, "domain should be empty when not specified")
|
||||||
|
require.Len(t, service.Targets, 1)
|
||||||
|
|
||||||
|
target := service.Targets[0]
|
||||||
|
assert.Equal(t, 8080, target.Port)
|
||||||
|
assert.Equal(t, "http", target.Protocol)
|
||||||
|
assert.Equal(t, "peer-1", target.TargetId)
|
||||||
|
assert.Equal(t, TargetTypePeer, target.TargetType)
|
||||||
|
assert.True(t, target.Enabled)
|
||||||
|
assert.Equal(t, "account-1", target.AccountID)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("with custom domain", func(t *testing.T) {
|
||||||
|
req := &ExposeServiceRequest{
|
||||||
|
Port: 3000,
|
||||||
|
Domain: "example.com",
|
||||||
|
}
|
||||||
|
|
||||||
|
service := req.ToService("acc", "peer", "web")
|
||||||
|
assert.Equal(t, "web.example.com", service.Domain)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("with PIN auth", func(t *testing.T) {
|
||||||
|
req := &ExposeServiceRequest{
|
||||||
|
Port: 80,
|
||||||
|
Pin: "1234",
|
||||||
|
}
|
||||||
|
|
||||||
|
service := req.ToService("acc", "peer", "svc")
|
||||||
|
require.NotNil(t, service.Auth.PinAuth)
|
||||||
|
assert.True(t, service.Auth.PinAuth.Enabled)
|
||||||
|
assert.Equal(t, "1234", service.Auth.PinAuth.Pin)
|
||||||
|
assert.Nil(t, service.Auth.PasswordAuth)
|
||||||
|
assert.Nil(t, service.Auth.BearerAuth)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("with password auth", func(t *testing.T) {
|
||||||
|
req := &ExposeServiceRequest{
|
||||||
|
Port: 80,
|
||||||
|
Password: "secret",
|
||||||
|
}
|
||||||
|
|
||||||
|
service := req.ToService("acc", "peer", "svc")
|
||||||
|
require.NotNil(t, service.Auth.PasswordAuth)
|
||||||
|
assert.True(t, service.Auth.PasswordAuth.Enabled)
|
||||||
|
assert.Equal(t, "secret", service.Auth.PasswordAuth.Password)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("with user groups (bearer auth)", func(t *testing.T) {
|
||||||
|
req := &ExposeServiceRequest{
|
||||||
|
Port: 80,
|
||||||
|
UserGroups: []string{"admins", "devs"},
|
||||||
|
}
|
||||||
|
|
||||||
|
service := req.ToService("acc", "peer", "svc")
|
||||||
|
require.NotNil(t, service.Auth.BearerAuth)
|
||||||
|
assert.True(t, service.Auth.BearerAuth.Enabled)
|
||||||
|
assert.Equal(t, []string{"admins", "devs"}, service.Auth.BearerAuth.DistributionGroups)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("with all auth types", func(t *testing.T) {
|
||||||
|
req := &ExposeServiceRequest{
|
||||||
|
Port: 443,
|
||||||
|
Domain: "myco.com",
|
||||||
|
Pin: "9999",
|
||||||
|
Password: "pass",
|
||||||
|
UserGroups: []string{"ops"},
|
||||||
|
}
|
||||||
|
|
||||||
|
service := req.ToService("acc", "peer", "full")
|
||||||
|
assert.Equal(t, "full.myco.com", service.Domain)
|
||||||
|
require.NotNil(t, service.Auth.PinAuth)
|
||||||
|
require.NotNil(t, service.Auth.PasswordAuth)
|
||||||
|
require.NotNil(t, service.Auth.BearerAuth)
|
||||||
|
})
|
||||||
|
}
|
||||||
@@ -94,7 +94,7 @@ func (s *BaseServer) EventStore() activity.Store {
|
|||||||
|
|
||||||
func (s *BaseServer) APIHandler() http.Handler {
|
func (s *BaseServer) APIHandler() http.Handler {
|
||||||
return Create(s, func() http.Handler {
|
return Create(s, func() http.Handler {
|
||||||
httpAPIHandler, err := nbhttp.NewAPIHandler(context.Background(), s.AccountManager(), s.NetworksManager(), s.ResourcesManager(), s.RoutesManager(), s.GroupsManager(), s.GeoLocationManager(), s.AuthManager(), s.Metrics(), s.IntegratedValidator(), s.ProxyController(), s.PermissionsManager(), s.PeersManager(), s.SettingsManager(), s.ZonesManager(), s.RecordsManager(), s.NetworkMapController(), s.IdpManager(), s.ReverseProxyManager(), s.ReverseProxyDomainManager(), s.AccessLogsManager(), s.ReverseProxyGRPCServer(), s.Config.ReverseProxy.TrustedHTTPProxies)
|
httpAPIHandler, err := nbhttp.NewAPIHandler(context.Background(), s.AccountManager(), s.NetworksManager(), s.ResourcesManager(), s.RoutesManager(), s.GroupsManager(), s.GeoLocationManager(), s.AuthManager(), s.Metrics(), s.IntegratedValidator(), s.ProxyController(), s.PermissionsManager(), s.PeersManager(), s.SettingsManager(), s.ZonesManager(), s.RecordsManager(), s.NetworkMapController(), s.IdpManager(), s.ServiceManager(), s.ReverseProxyDomainManager(), s.AccessLogsManager(), s.ReverseProxyGRPCServer(), s.Config.ReverseProxy.TrustedHTTPProxies)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("failed to create API handler: %v", err)
|
log.Fatalf("failed to create API handler: %v", err)
|
||||||
}
|
}
|
||||||
@@ -134,7 +134,7 @@ func (s *BaseServer) GRPCServer() *grpc.Server {
|
|||||||
if s.Config.HttpConfig.LetsEncryptDomain != "" {
|
if s.Config.HttpConfig.LetsEncryptDomain != "" {
|
||||||
certManager, err := encryption.CreateCertManager(s.Config.Datadir, s.Config.HttpConfig.LetsEncryptDomain)
|
certManager, err := encryption.CreateCertManager(s.Config.Datadir, s.Config.HttpConfig.LetsEncryptDomain)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("failed to create certificate manager: %v", err)
|
log.Fatalf("failed to create certificate service: %v", err)
|
||||||
}
|
}
|
||||||
transportCredentials := credentials.NewTLS(certManager.TLSConfig())
|
transportCredentials := credentials.NewTLS(certManager.TLSConfig())
|
||||||
gRPCOpts = append(gRPCOpts, grpc.Creds(transportCredentials))
|
gRPCOpts = append(gRPCOpts, grpc.Creds(transportCredentials))
|
||||||
@@ -152,6 +152,11 @@ func (s *BaseServer) GRPCServer() *grpc.Server {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("failed to create management server: %v", err)
|
log.Fatalf("failed to create management server: %v", err)
|
||||||
}
|
}
|
||||||
|
serviceMgr := s.ServiceManager()
|
||||||
|
srv.SetReverseProxyManager(serviceMgr)
|
||||||
|
if serviceMgr != nil {
|
||||||
|
serviceMgr.StartExposeReaper(context.Background())
|
||||||
|
}
|
||||||
mgmtProto.RegisterManagementServiceServer(gRPCAPIHandler, srv)
|
mgmtProto.RegisterManagementServiceServer(gRPCAPIHandler, srv)
|
||||||
|
|
||||||
mgmtProto.RegisterProxyServiceServer(gRPCAPIHandler, s.ReverseProxyGRPCServer())
|
mgmtProto.RegisterProxyServiceServer(gRPCAPIHandler, s.ReverseProxyGRPCServer())
|
||||||
@@ -163,9 +168,10 @@ func (s *BaseServer) GRPCServer() *grpc.Server {
|
|||||||
|
|
||||||
func (s *BaseServer) ReverseProxyGRPCServer() *nbgrpc.ProxyServiceServer {
|
func (s *BaseServer) ReverseProxyGRPCServer() *nbgrpc.ProxyServiceServer {
|
||||||
return Create(s, func() *nbgrpc.ProxyServiceServer {
|
return Create(s, func() *nbgrpc.ProxyServiceServer {
|
||||||
proxyService := nbgrpc.NewProxyServiceServer(s.AccessLogsManager(), s.ProxyTokenStore(), s.proxyOIDCConfig(), s.PeersManager(), s.UsersManager())
|
proxyService := nbgrpc.NewProxyServiceServer(s.AccessLogsManager(), s.ProxyTokenStore(), s.PKCEVerifierStore(), s.proxyOIDCConfig(), s.PeersManager(), s.UsersManager(), s.ProxyManager())
|
||||||
s.AfterInit(func(s *BaseServer) {
|
s.AfterInit(func(s *BaseServer) {
|
||||||
proxyService.SetProxyManager(s.ReverseProxyManager())
|
proxyService.SetServiceManager(s.ServiceManager())
|
||||||
|
proxyService.SetProxyController(s.ServiceProxyController())
|
||||||
})
|
})
|
||||||
return proxyService
|
return proxyService
|
||||||
})
|
})
|
||||||
@@ -188,12 +194,25 @@ func (s *BaseServer) proxyOIDCConfig() nbgrpc.ProxyOIDCConfig {
|
|||||||
|
|
||||||
func (s *BaseServer) ProxyTokenStore() *nbgrpc.OneTimeTokenStore {
|
func (s *BaseServer) ProxyTokenStore() *nbgrpc.OneTimeTokenStore {
|
||||||
return Create(s, func() *nbgrpc.OneTimeTokenStore {
|
return Create(s, func() *nbgrpc.OneTimeTokenStore {
|
||||||
tokenStore := nbgrpc.NewOneTimeTokenStore(1 * time.Minute)
|
tokenStore, err := nbgrpc.NewOneTimeTokenStore(context.Background(), 5*time.Minute, 10*time.Minute, 100)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("failed to create proxy token store: %v", err)
|
||||||
|
}
|
||||||
log.Info("One-time token store initialized for proxy authentication")
|
log.Info("One-time token store initialized for proxy authentication")
|
||||||
return tokenStore
|
return tokenStore
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *BaseServer) PKCEVerifierStore() *nbgrpc.PKCEVerifierStore {
|
||||||
|
return Create(s, func() *nbgrpc.PKCEVerifierStore {
|
||||||
|
pkceStore, err := nbgrpc.NewPKCEVerifierStore(context.Background(), 10*time.Minute, 10*time.Minute, 100)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("failed to create PKCE verifier store: %v", err)
|
||||||
|
}
|
||||||
|
return pkceStore
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
func (s *BaseServer) AccessLogsManager() accesslogs.Manager {
|
func (s *BaseServer) AccessLogsManager() accesslogs.Manager {
|
||||||
return Create(s, func() accesslogs.Manager {
|
return Create(s, func() accesslogs.Manager {
|
||||||
accessLogManager := accesslogsmanager.NewManager(s.Store(), s.PermissionsManager(), s.GeoLocationManager())
|
accessLogManager := accesslogsmanager.NewManager(s.Store(), s.PermissionsManager(), s.GeoLocationManager())
|
||||||
|
|||||||
@@ -6,6 +6,8 @@ import (
|
|||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
"github.com/netbirdio/management-integrations/integrations"
|
"github.com/netbirdio/management-integrations/integrations"
|
||||||
|
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
|
||||||
|
proxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy/manager"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
|
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
|
||||||
nmapcontroller "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
|
nmapcontroller "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
|
||||||
@@ -106,6 +108,16 @@ func (s *BaseServer) NetworkMapController() network_map.Controller {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *BaseServer) ServiceProxyController() proxy.Controller {
|
||||||
|
return Create(s, func() proxy.Controller {
|
||||||
|
controller, err := proxymanager.NewGRPCController(s.ReverseProxyGRPCServer(), s.Metrics().GetMeter())
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("failed to create service proxy controller: %v", err)
|
||||||
|
}
|
||||||
|
return controller
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
func (s *BaseServer) AccountRequestBuffer() *server.AccountRequestBuffer {
|
func (s *BaseServer) AccountRequestBuffer() *server.AccountRequestBuffer {
|
||||||
return Create(s, func() *server.AccountRequestBuffer {
|
return Create(s, func() *server.AccountRequestBuffer {
|
||||||
return server.NewAccountRequestBuffer(context.Background(), s.Store())
|
return server.NewAccountRequestBuffer(context.Background(), s.Store())
|
||||||
|
|||||||
@@ -8,9 +8,11 @@ import (
|
|||||||
|
|
||||||
"github.com/netbirdio/management-integrations/integrations"
|
"github.com/netbirdio/management-integrations/integrations"
|
||||||
"github.com/netbirdio/netbird/management/internals/modules/peers"
|
"github.com/netbirdio/netbird/management/internals/modules/peers"
|
||||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
|
|
||||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain/manager"
|
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain/manager"
|
||||||
nbreverseproxy "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/manager"
|
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
|
||||||
|
proxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy/manager"
|
||||||
|
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||||
|
nbreverseproxy "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service/manager"
|
||||||
"github.com/netbirdio/netbird/management/internals/modules/zones"
|
"github.com/netbirdio/netbird/management/internals/modules/zones"
|
||||||
zonesManager "github.com/netbirdio/netbird/management/internals/modules/zones/manager"
|
zonesManager "github.com/netbirdio/netbird/management/internals/modules/zones/manager"
|
||||||
"github.com/netbirdio/netbird/management/internals/modules/zones/records"
|
"github.com/netbirdio/netbird/management/internals/modules/zones/records"
|
||||||
@@ -99,11 +101,11 @@ func (s *BaseServer) AccountManager() account.Manager {
|
|||||||
return Create(s, func() account.Manager {
|
return Create(s, func() account.Manager {
|
||||||
accountManager, err := server.BuildManager(context.Background(), s.Config, s.Store(), s.NetworkMapController(), s.JobManager(), s.IdpManager(), s.mgmtSingleAccModeDomain, s.EventStore(), s.GeoLocationManager(), s.userDeleteFromIDPEnabled, s.IntegratedValidator(), s.Metrics(), s.ProxyController(), s.SettingsManager(), s.PermissionsManager(), s.Config.DisableDefaultPolicy)
|
accountManager, err := server.BuildManager(context.Background(), s.Config, s.Store(), s.NetworkMapController(), s.JobManager(), s.IdpManager(), s.mgmtSingleAccModeDomain, s.EventStore(), s.GeoLocationManager(), s.userDeleteFromIDPEnabled, s.IntegratedValidator(), s.Metrics(), s.ProxyController(), s.SettingsManager(), s.PermissionsManager(), s.Config.DisableDefaultPolicy)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("failed to create account manager: %v", err)
|
log.Fatalf("failed to create account service: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
s.AfterInit(func(s *BaseServer) {
|
s.AfterInit(func(s *BaseServer) {
|
||||||
accountManager.SetServiceManager(s.ReverseProxyManager())
|
accountManager.SetServiceManager(s.ServiceManager())
|
||||||
})
|
})
|
||||||
|
|
||||||
return accountManager
|
return accountManager
|
||||||
@@ -114,28 +116,28 @@ func (s *BaseServer) IdpManager() idp.Manager {
|
|||||||
return Create(s, func() idp.Manager {
|
return Create(s, func() idp.Manager {
|
||||||
var idpManager idp.Manager
|
var idpManager idp.Manager
|
||||||
var err error
|
var err error
|
||||||
// Use embedded IdP manager if embedded Dex is configured and enabled.
|
// Use embedded IdP service if embedded Dex is configured and enabled.
|
||||||
// Legacy IdpManager won't be used anymore even if configured.
|
// Legacy IdpManager won't be used anymore even if configured.
|
||||||
if s.Config.EmbeddedIdP != nil && s.Config.EmbeddedIdP.Enabled {
|
if s.Config.EmbeddedIdP != nil && s.Config.EmbeddedIdP.Enabled {
|
||||||
idpManager, err = idp.NewEmbeddedIdPManager(context.Background(), s.Config.EmbeddedIdP, s.Metrics())
|
idpManager, err = idp.NewEmbeddedIdPManager(context.Background(), s.Config.EmbeddedIdP, s.Metrics())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("failed to create embedded IDP manager: %v", err)
|
log.Fatalf("failed to create embedded IDP service: %v", err)
|
||||||
}
|
}
|
||||||
return idpManager
|
return idpManager
|
||||||
}
|
}
|
||||||
|
|
||||||
// Fall back to external IdP manager
|
// Fall back to external IdP service
|
||||||
if s.Config.IdpManagerConfig != nil {
|
if s.Config.IdpManagerConfig != nil {
|
||||||
idpManager, err = idp.NewManager(context.Background(), *s.Config.IdpManagerConfig, s.Metrics())
|
idpManager, err = idp.NewManager(context.Background(), *s.Config.IdpManagerConfig, s.Metrics())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("failed to create IDP manager: %v", err)
|
log.Fatalf("failed to create IDP service: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return idpManager
|
return idpManager
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// OAuthConfigProvider is only relevant when we have an embedded IdP manager. Otherwise must be nil
|
// OAuthConfigProvider is only relevant when we have an embedded IdP service. Otherwise must be nil
|
||||||
func (s *BaseServer) OAuthConfigProvider() idp.OAuthConfigProvider {
|
func (s *BaseServer) OAuthConfigProvider() idp.OAuthConfigProvider {
|
||||||
if s.Config.EmbeddedIdP == nil || !s.Config.EmbeddedIdP.Enabled {
|
if s.Config.EmbeddedIdP == nil || !s.Config.EmbeddedIdP.Enabled {
|
||||||
return nil
|
return nil
|
||||||
@@ -162,7 +164,7 @@ func (s *BaseServer) GroupsManager() groups.Manager {
|
|||||||
|
|
||||||
func (s *BaseServer) ResourcesManager() resources.Manager {
|
func (s *BaseServer) ResourcesManager() resources.Manager {
|
||||||
return Create(s, func() resources.Manager {
|
return Create(s, func() resources.Manager {
|
||||||
return resources.NewManager(s.Store(), s.PermissionsManager(), s.GroupsManager(), s.AccountManager(), s.ReverseProxyManager())
|
return resources.NewManager(s.Store(), s.PermissionsManager(), s.GroupsManager(), s.AccountManager(), s.ServiceManager())
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -190,15 +192,25 @@ func (s *BaseServer) RecordsManager() records.Manager {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *BaseServer) ReverseProxyManager() reverseproxy.Manager {
|
func (s *BaseServer) ServiceManager() service.Manager {
|
||||||
return Create(s, func() reverseproxy.Manager {
|
return Create(s, func() service.Manager {
|
||||||
return nbreverseproxy.NewManager(s.Store(), s.AccountManager(), s.PermissionsManager(), s.ReverseProxyGRPCServer(), s.ReverseProxyDomainManager())
|
return nbreverseproxy.NewManager(s.Store(), s.AccountManager(), s.PermissionsManager(), s.ServiceProxyController(), s.ReverseProxyDomainManager())
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *BaseServer) ProxyManager() proxy.Manager {
|
||||||
|
return Create(s, func() proxy.Manager {
|
||||||
|
manager, err := proxymanager.NewManager(s.Store(), s.Metrics().GetMeter())
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("failed to create proxy manager: %v", err)
|
||||||
|
}
|
||||||
|
return manager
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *BaseServer) ReverseProxyDomainManager() *manager.Manager {
|
func (s *BaseServer) ReverseProxyDomainManager() *manager.Manager {
|
||||||
return Create(s, func() *manager.Manager {
|
return Create(s, func() *manager.Manager {
|
||||||
m := manager.NewManager(s.Store(), s.ReverseProxyGRPCServer(), s.PermissionsManager())
|
m := manager.NewManager(s.Store(), s.ProxyManager(), s.PermissionsManager(), s.AccountManager())
|
||||||
return &m
|
return &m
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -28,9 +28,13 @@ import (
|
|||||||
"github.com/netbirdio/netbird/version"
|
"github.com/netbirdio/netbird/version"
|
||||||
)
|
)
|
||||||
|
|
||||||
// ManagementLegacyPort is the port that was used before by the Management gRPC server.
|
const (
|
||||||
// It is used for backward compatibility now.
|
// ManagementLegacyPort is the port that was used before by the Management gRPC server.
|
||||||
const ManagementLegacyPort = 33073
|
// It is used for backward compatibility now.
|
||||||
|
ManagementLegacyPort = 33073
|
||||||
|
// DefaultSelfHostedDomain is the default domain used for self-hosted fresh installs.
|
||||||
|
DefaultSelfHostedDomain = "netbird.selfhosted"
|
||||||
|
)
|
||||||
|
|
||||||
type Server interface {
|
type Server interface {
|
||||||
Start(ctx context.Context) error
|
Start(ctx context.Context) error
|
||||||
@@ -58,6 +62,7 @@ type BaseServer struct {
|
|||||||
mgmtMetricsPort int
|
mgmtMetricsPort int
|
||||||
mgmtPort int
|
mgmtPort int
|
||||||
disableLegacyManagementPort bool
|
disableLegacyManagementPort bool
|
||||||
|
autoResolveDomains bool
|
||||||
|
|
||||||
proxyAuthClose func()
|
proxyAuthClose func()
|
||||||
|
|
||||||
@@ -81,6 +86,7 @@ type Config struct {
|
|||||||
DisableMetrics bool
|
DisableMetrics bool
|
||||||
DisableGeoliteUpdate bool
|
DisableGeoliteUpdate bool
|
||||||
UserDeleteFromIDPEnabled bool
|
UserDeleteFromIDPEnabled bool
|
||||||
|
AutoResolveDomains bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewServer initializes and configures a new Server instance
|
// NewServer initializes and configures a new Server instance
|
||||||
@@ -96,6 +102,7 @@ func NewServer(cfg *Config) *BaseServer {
|
|||||||
mgmtPort: cfg.MgmtPort,
|
mgmtPort: cfg.MgmtPort,
|
||||||
disableLegacyManagementPort: cfg.DisableLegacyManagementPort,
|
disableLegacyManagementPort: cfg.DisableLegacyManagementPort,
|
||||||
mgmtMetricsPort: cfg.MgmtMetricsPort,
|
mgmtMetricsPort: cfg.MgmtMetricsPort,
|
||||||
|
autoResolveDomains: cfg.AutoResolveDomains,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -109,6 +116,10 @@ func (s *BaseServer) Start(ctx context.Context) error {
|
|||||||
s.cancel = cancel
|
s.cancel = cancel
|
||||||
s.errCh = make(chan error, 4)
|
s.errCh = make(chan error, 4)
|
||||||
|
|
||||||
|
if s.autoResolveDomains {
|
||||||
|
s.resolveDomains(srvCtx)
|
||||||
|
}
|
||||||
|
|
||||||
s.PeersManager()
|
s.PeersManager()
|
||||||
s.GeoLocationManager()
|
s.GeoLocationManager()
|
||||||
|
|
||||||
@@ -157,7 +168,7 @@ func (s *BaseServer) Start(ctx context.Context) error {
|
|||||||
|
|
||||||
// Eagerly create the gRPC server so that all AfterInit hooks are registered
|
// Eagerly create the gRPC server so that all AfterInit hooks are registered
|
||||||
// before we iterate them. Lazy creation after the loop would miss hooks
|
// before we iterate them. Lazy creation after the loop would miss hooks
|
||||||
// registered during GRPCServer() construction (e.g., SetProxyManager).
|
// registered during GRPCServer() construction (e.g., SetServiceManager).
|
||||||
s.GRPCServer()
|
s.GRPCServer()
|
||||||
|
|
||||||
for _, fn := range s.afterInit {
|
for _, fn := range s.afterInit {
|
||||||
@@ -237,7 +248,6 @@ func (s *BaseServer) Stop() error {
|
|||||||
_ = s.certManager.Listener().Close()
|
_ = s.certManager.Listener().Close()
|
||||||
}
|
}
|
||||||
s.GRPCServer().Stop()
|
s.GRPCServer().Stop()
|
||||||
s.ReverseProxyGRPCServer().Close()
|
|
||||||
if s.proxyAuthClose != nil {
|
if s.proxyAuthClose != nil {
|
||||||
s.proxyAuthClose()
|
s.proxyAuthClose()
|
||||||
s.proxyAuthClose = nil
|
s.proxyAuthClose = nil
|
||||||
@@ -381,6 +391,60 @@ func (s *BaseServer) serveGRPCWithHTTP(ctx context.Context, listener net.Listene
|
|||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// resolveDomains determines dnsDomain and mgmtSingleAccModeDomain based on store state.
|
||||||
|
// Fresh installs use the default self-hosted domain, while existing installs reuse the
|
||||||
|
// persisted account domain to keep addressing stable across config changes.
|
||||||
|
func (s *BaseServer) resolveDomains(ctx context.Context) {
|
||||||
|
st := s.Store()
|
||||||
|
|
||||||
|
setDefault := func(logMsg string, args ...any) {
|
||||||
|
if logMsg != "" {
|
||||||
|
log.WithContext(ctx).Warnf(logMsg, args...)
|
||||||
|
}
|
||||||
|
s.dnsDomain = DefaultSelfHostedDomain
|
||||||
|
s.mgmtSingleAccModeDomain = DefaultSelfHostedDomain
|
||||||
|
}
|
||||||
|
|
||||||
|
accountsCount, err := st.GetAccountsCounter(ctx)
|
||||||
|
if err != nil {
|
||||||
|
setDefault("resolve domains: failed to read accounts counter: %v; using default domain %q", err, DefaultSelfHostedDomain)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if accountsCount == 0 {
|
||||||
|
s.dnsDomain = DefaultSelfHostedDomain
|
||||||
|
s.mgmtSingleAccModeDomain = DefaultSelfHostedDomain
|
||||||
|
log.WithContext(ctx).Infof("resolve domains: fresh install detected, using default domain %q", DefaultSelfHostedDomain)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
accountID, err := st.GetAnyAccountID(ctx)
|
||||||
|
if err != nil {
|
||||||
|
setDefault("resolve domains: failed to get existing account ID: %v; using default domain %q", err, DefaultSelfHostedDomain)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if accountID == "" {
|
||||||
|
setDefault("resolve domains: empty account ID returned for existing accounts; using default domain %q", DefaultSelfHostedDomain)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
domain, _, err := st.GetAccountDomainAndCategory(ctx, store.LockingStrengthNone, accountID)
|
||||||
|
if err != nil {
|
||||||
|
setDefault("resolve domains: failed to get account domain for account %q: %v; using default domain %q", accountID, err, DefaultSelfHostedDomain)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if domain == "" {
|
||||||
|
setDefault("resolve domains: account %q has empty domain; using default domain %q", accountID, DefaultSelfHostedDomain)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
s.dnsDomain = domain
|
||||||
|
s.mgmtSingleAccModeDomain = domain
|
||||||
|
log.WithContext(ctx).Infof("resolve domains: using persisted account domain %q", domain)
|
||||||
|
}
|
||||||
|
|
||||||
func getInstallationID(ctx context.Context, store store.Store) (string, error) {
|
func getInstallationID(ctx context.Context, store store.Store) (string, error) {
|
||||||
installationID := store.GetInstallationID()
|
installationID := store.GetInstallationID()
|
||||||
if installationID != "" {
|
if installationID != "" {
|
||||||
|
|||||||
63
management/internals/server/server_resolve_domains_test.go
Normal file
63
management/internals/server/server_resolve_domains_test.go
Normal file
@@ -0,0 +1,63 @@
|
|||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/golang/mock/gomock"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
|
||||||
|
"github.com/netbirdio/netbird/management/server/store"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestResolveDomains_FreshInstallUsesDefault(t *testing.T) {
|
||||||
|
ctrl := gomock.NewController(t)
|
||||||
|
t.Cleanup(ctrl.Finish)
|
||||||
|
|
||||||
|
mockStore := store.NewMockStore(ctrl)
|
||||||
|
mockStore.EXPECT().GetAccountsCounter(gomock.Any()).Return(int64(0), nil)
|
||||||
|
|
||||||
|
srv := NewServer(&Config{NbConfig: &nbconfig.Config{}})
|
||||||
|
Inject[store.Store](srv, mockStore)
|
||||||
|
|
||||||
|
srv.resolveDomains(context.Background())
|
||||||
|
|
||||||
|
require.Equal(t, DefaultSelfHostedDomain, srv.dnsDomain)
|
||||||
|
require.Equal(t, DefaultSelfHostedDomain, srv.mgmtSingleAccModeDomain)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResolveDomains_ExistingInstallUsesPersistedDomain(t *testing.T) {
|
||||||
|
ctrl := gomock.NewController(t)
|
||||||
|
t.Cleanup(ctrl.Finish)
|
||||||
|
|
||||||
|
mockStore := store.NewMockStore(ctrl)
|
||||||
|
mockStore.EXPECT().GetAccountsCounter(gomock.Any()).Return(int64(1), nil)
|
||||||
|
mockStore.EXPECT().GetAnyAccountID(gomock.Any()).Return("acc-1", nil)
|
||||||
|
mockStore.EXPECT().GetAccountDomainAndCategory(gomock.Any(), store.LockingStrengthNone, "acc-1").Return("vpn.mycompany.com", "", nil)
|
||||||
|
|
||||||
|
srv := NewServer(&Config{NbConfig: &nbconfig.Config{}})
|
||||||
|
Inject[store.Store](srv, mockStore)
|
||||||
|
|
||||||
|
srv.resolveDomains(context.Background())
|
||||||
|
|
||||||
|
require.Equal(t, "vpn.mycompany.com", srv.dnsDomain)
|
||||||
|
require.Equal(t, "vpn.mycompany.com", srv.mgmtSingleAccModeDomain)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResolveDomains_StoreErrorFallsBackToDefault(t *testing.T) {
|
||||||
|
ctrl := gomock.NewController(t)
|
||||||
|
t.Cleanup(ctrl.Finish)
|
||||||
|
|
||||||
|
mockStore := store.NewMockStore(ctrl)
|
||||||
|
mockStore.EXPECT().GetAccountsCounter(gomock.Any()).Return(int64(0), errors.New("db failed"))
|
||||||
|
|
||||||
|
srv := NewServer(&Config{NbConfig: &nbconfig.Config{}})
|
||||||
|
Inject[store.Store](srv, mockStore)
|
||||||
|
|
||||||
|
srv.resolveDomains(context.Background())
|
||||||
|
|
||||||
|
require.Equal(t, DefaultSelfHostedDomain, srv.dnsDomain)
|
||||||
|
require.Equal(t, DefaultSelfHostedDomain, srv.mgmtSingleAccModeDomain)
|
||||||
|
}
|
||||||
192
management/internals/shared/grpc/expose_service.go
Normal file
192
management/internals/shared/grpc/expose_service.go
Normal file
@@ -0,0 +1,192 @@
|
|||||||
|
package grpc
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
|
||||||
|
pb "github.com/golang/protobuf/proto" // nolint
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
|
"google.golang.org/grpc/codes"
|
||||||
|
"google.golang.org/grpc/status"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/encryption"
|
||||||
|
rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||||
|
nbContext "github.com/netbirdio/netbird/management/server/context"
|
||||||
|
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||||
|
"github.com/netbirdio/netbird/management/server/store"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/proto"
|
||||||
|
internalStatus "github.com/netbirdio/netbird/shared/management/status"
|
||||||
|
)
|
||||||
|
|
||||||
|
// CreateExpose handles a peer request to create a new expose service.
|
||||||
|
func (s *Server) CreateExpose(ctx context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error) {
|
||||||
|
exposeReq := &proto.ExposeServiceRequest{}
|
||||||
|
peerKey, err := s.parseRequest(ctx, req, exposeReq)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
accountID, peer, err := s.authenticateExposePeer(ctx, peerKey)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// nolint:staticcheck
|
||||||
|
ctx = context.WithValue(ctx, nbContext.AccountIDKey, accountID)
|
||||||
|
|
||||||
|
reverseProxyMgr := s.getReverseProxyManager()
|
||||||
|
if reverseProxyMgr == nil {
|
||||||
|
return nil, status.Errorf(codes.Internal, "reverse proxy manager not available")
|
||||||
|
}
|
||||||
|
|
||||||
|
created, err := reverseProxyMgr.CreateServiceFromPeer(ctx, accountID, peer.ID, &rpservice.ExposeServiceRequest{
|
||||||
|
NamePrefix: exposeReq.NamePrefix,
|
||||||
|
Port: int(exposeReq.Port),
|
||||||
|
Protocol: exposeProtocolToString(exposeReq.Protocol),
|
||||||
|
Domain: exposeReq.Domain,
|
||||||
|
Pin: exposeReq.Pin,
|
||||||
|
Password: exposeReq.Password,
|
||||||
|
UserGroups: exposeReq.UserGroups,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, mapExposeError(ctx, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return s.encryptResponse(peerKey, &proto.ExposeServiceResponse{
|
||||||
|
ServiceName: created.ServiceName,
|
||||||
|
ServiceUrl: created.ServiceURL,
|
||||||
|
Domain: created.Domain,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// RenewExpose extends the TTL of an active expose session.
|
||||||
|
func (s *Server) RenewExpose(ctx context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error) {
|
||||||
|
renewReq := &proto.RenewExposeRequest{}
|
||||||
|
peerKey, err := s.parseRequest(ctx, req, renewReq)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
accountID, peer, err := s.authenticateExposePeer(ctx, peerKey)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
reverseProxyMgr := s.getReverseProxyManager()
|
||||||
|
if reverseProxyMgr == nil {
|
||||||
|
return nil, status.Errorf(codes.Internal, "reverse proxy manager not available")
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := reverseProxyMgr.RenewServiceFromPeer(ctx, accountID, peer.ID, renewReq.Domain); err != nil {
|
||||||
|
return nil, mapExposeError(ctx, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return s.encryptResponse(peerKey, &proto.RenewExposeResponse{})
|
||||||
|
}
|
||||||
|
|
||||||
|
// StopExpose terminates an active expose session.
|
||||||
|
func (s *Server) StopExpose(ctx context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error) {
|
||||||
|
stopReq := &proto.StopExposeRequest{}
|
||||||
|
peerKey, err := s.parseRequest(ctx, req, stopReq)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
accountID, peer, err := s.authenticateExposePeer(ctx, peerKey)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
reverseProxyMgr := s.getReverseProxyManager()
|
||||||
|
if reverseProxyMgr == nil {
|
||||||
|
return nil, status.Errorf(codes.Internal, "reverse proxy manager not available")
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := reverseProxyMgr.StopServiceFromPeer(ctx, accountID, peer.ID, stopReq.Domain); err != nil {
|
||||||
|
return nil, mapExposeError(ctx, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return s.encryptResponse(peerKey, &proto.StopExposeResponse{})
|
||||||
|
}
|
||||||
|
|
||||||
|
func mapExposeError(ctx context.Context, err error) error {
|
||||||
|
s, ok := internalStatus.FromError(err)
|
||||||
|
if !ok {
|
||||||
|
log.WithContext(ctx).Errorf("expose service error: %v", err)
|
||||||
|
return status.Errorf(codes.Internal, "internal error")
|
||||||
|
}
|
||||||
|
|
||||||
|
switch s.Type() {
|
||||||
|
case internalStatus.InvalidArgument:
|
||||||
|
return status.Errorf(codes.InvalidArgument, "%s", s.Message)
|
||||||
|
case internalStatus.PermissionDenied:
|
||||||
|
return status.Errorf(codes.PermissionDenied, "%s", s.Message)
|
||||||
|
case internalStatus.NotFound:
|
||||||
|
return status.Errorf(codes.NotFound, "%s", s.Message)
|
||||||
|
case internalStatus.AlreadyExists:
|
||||||
|
return status.Errorf(codes.AlreadyExists, "%s", s.Message)
|
||||||
|
case internalStatus.PreconditionFailed:
|
||||||
|
return status.Errorf(codes.ResourceExhausted, "%s", s.Message)
|
||||||
|
default:
|
||||||
|
log.WithContext(ctx).Errorf("expose service error: %v", err)
|
||||||
|
return status.Errorf(codes.Internal, "internal error")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) encryptResponse(peerKey wgtypes.Key, msg pb.Message) (*proto.EncryptedMessage, error) {
|
||||||
|
wgKey, err := s.secretsManager.GetWGKey()
|
||||||
|
if err != nil {
|
||||||
|
return nil, status.Errorf(codes.Internal, "internal error")
|
||||||
|
}
|
||||||
|
|
||||||
|
encryptedResp, err := encryption.EncryptMessage(peerKey, wgKey, msg)
|
||||||
|
if err != nil {
|
||||||
|
return nil, status.Errorf(codes.Internal, "encrypt response")
|
||||||
|
}
|
||||||
|
|
||||||
|
return &proto.EncryptedMessage{
|
||||||
|
WgPubKey: wgKey.PublicKey().String(),
|
||||||
|
Body: encryptedResp,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) authenticateExposePeer(ctx context.Context, peerKey wgtypes.Key) (string, *nbpeer.Peer, error) {
|
||||||
|
accountID, err := s.accountManager.GetAccountIDForPeerKey(ctx, peerKey.String())
|
||||||
|
if err != nil {
|
||||||
|
if errStatus, ok := internalStatus.FromError(err); ok && errStatus.Type() == internalStatus.NotFound {
|
||||||
|
return "", nil, status.Errorf(codes.PermissionDenied, "peer is not registered")
|
||||||
|
}
|
||||||
|
return "", nil, status.Errorf(codes.Internal, "lookup account for peer")
|
||||||
|
}
|
||||||
|
|
||||||
|
peer, err := s.accountManager.GetStore().GetPeerByPeerPubKey(ctx, store.LockingStrengthNone, peerKey.String())
|
||||||
|
if err != nil {
|
||||||
|
return "", nil, status.Errorf(codes.PermissionDenied, "peer is not registered")
|
||||||
|
}
|
||||||
|
|
||||||
|
return accountID, peer, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) getReverseProxyManager() rpservice.Manager {
|
||||||
|
s.reverseProxyMu.RLock()
|
||||||
|
defer s.reverseProxyMu.RUnlock()
|
||||||
|
return s.reverseProxyManager
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetReverseProxyManager sets the reverse proxy manager on the server.
|
||||||
|
func (s *Server) SetReverseProxyManager(mgr rpservice.Manager) {
|
||||||
|
s.reverseProxyMu.Lock()
|
||||||
|
defer s.reverseProxyMu.Unlock()
|
||||||
|
s.reverseProxyManager = mgr
|
||||||
|
}
|
||||||
|
|
||||||
|
func exposeProtocolToString(p proto.ExposeProtocol) string {
|
||||||
|
switch p {
|
||||||
|
case proto.ExposeProtocol_EXPOSE_HTTP:
|
||||||
|
return "http"
|
||||||
|
case proto.ExposeProtocol_EXPOSE_HTTPS:
|
||||||
|
return "https"
|
||||||
|
default:
|
||||||
|
return "http"
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,28 +1,23 @@
|
|||||||
package grpc
|
package grpc
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
|
"crypto/sha256"
|
||||||
"crypto/subtle"
|
"crypto/subtle"
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
|
"encoding/hex"
|
||||||
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"sync"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/eko/gocache/lib/v4/cache"
|
||||||
|
"github.com/eko/gocache/lib/v4/store"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
nbcache "github.com/netbirdio/netbird/management/server/cache"
|
||||||
)
|
)
|
||||||
|
|
||||||
// OneTimeTokenStore manages short-lived, single-use authentication tokens
|
|
||||||
// for proxy-to-management RPC authentication. Tokens are generated when
|
|
||||||
// a service is created and must be used exactly once by the proxy
|
|
||||||
// to authenticate a subsequent RPC call.
|
|
||||||
type OneTimeTokenStore struct {
|
|
||||||
tokens map[string]*tokenMetadata
|
|
||||||
mu sync.RWMutex
|
|
||||||
cleanup *time.Ticker
|
|
||||||
cleanupDone chan struct{}
|
|
||||||
}
|
|
||||||
|
|
||||||
// tokenMetadata stores information about a one-time token
|
|
||||||
type tokenMetadata struct {
|
type tokenMetadata struct {
|
||||||
ServiceID string
|
ServiceID string
|
||||||
AccountID string
|
AccountID string
|
||||||
@@ -30,20 +25,24 @@ type tokenMetadata struct {
|
|||||||
CreatedAt time.Time
|
CreatedAt time.Time
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewOneTimeTokenStore creates a new token store with automatic cleanup
|
// OneTimeTokenStore manages single-use authentication tokens for proxy-to-management RPC.
|
||||||
// of expired tokens. The cleanupInterval determines how often expired
|
// Supports both in-memory and Redis storage via NB_IDP_CACHE_REDIS_ADDRESS env var.
|
||||||
// tokens are removed from memory.
|
type OneTimeTokenStore struct {
|
||||||
func NewOneTimeTokenStore(cleanupInterval time.Duration) *OneTimeTokenStore {
|
cache *cache.Cache[string]
|
||||||
store := &OneTimeTokenStore{
|
ctx context.Context
|
||||||
tokens: make(map[string]*tokenMetadata),
|
}
|
||||||
cleanup: time.NewTicker(cleanupInterval),
|
|
||||||
cleanupDone: make(chan struct{}),
|
// NewOneTimeTokenStore creates a token store with automatic backend selection
|
||||||
|
func NewOneTimeTokenStore(ctx context.Context, maxTimeout, cleanupInterval time.Duration, maxConn int) (*OneTimeTokenStore, error) {
|
||||||
|
cacheStore, err := nbcache.NewStore(ctx, maxTimeout, cleanupInterval, maxConn)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to create cache store: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Start background cleanup goroutine
|
return &OneTimeTokenStore{
|
||||||
go store.cleanupExpired()
|
cache: cache.New[string](cacheStore),
|
||||||
|
ctx: ctx,
|
||||||
return store
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GenerateToken creates a new cryptographically secure one-time token
|
// GenerateToken creates a new cryptographically secure one-time token
|
||||||
@@ -52,25 +51,30 @@ func NewOneTimeTokenStore(cleanupInterval time.Duration) *OneTimeTokenStore {
|
|||||||
//
|
//
|
||||||
// Returns the generated token string or an error if random generation fails.
|
// Returns the generated token string or an error if random generation fails.
|
||||||
func (s *OneTimeTokenStore) GenerateToken(accountID, serviceID string, ttl time.Duration) (string, error) {
|
func (s *OneTimeTokenStore) GenerateToken(accountID, serviceID string, ttl time.Duration) (string, error) {
|
||||||
// Generate 32 bytes (256 bits) of cryptographically secure random data
|
|
||||||
randomBytes := make([]byte, 32)
|
randomBytes := make([]byte, 32)
|
||||||
if _, err := rand.Read(randomBytes); err != nil {
|
if _, err := rand.Read(randomBytes); err != nil {
|
||||||
return "", fmt.Errorf("failed to generate random token: %w", err)
|
return "", fmt.Errorf("failed to generate random token: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Encode as URL-safe base64 for easy transmission in gRPC
|
|
||||||
token := base64.URLEncoding.EncodeToString(randomBytes)
|
token := base64.URLEncoding.EncodeToString(randomBytes)
|
||||||
|
hashedToken := hashToken(token)
|
||||||
|
|
||||||
s.mu.Lock()
|
metadata := &tokenMetadata{
|
||||||
defer s.mu.Unlock()
|
|
||||||
|
|
||||||
s.tokens[token] = &tokenMetadata{
|
|
||||||
ServiceID: serviceID,
|
ServiceID: serviceID,
|
||||||
AccountID: accountID,
|
AccountID: accountID,
|
||||||
ExpiresAt: time.Now().Add(ttl),
|
ExpiresAt: time.Now().Add(ttl),
|
||||||
CreatedAt: time.Now(),
|
CreatedAt: time.Now(),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
metadataJSON, err := json.Marshal(metadata)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("failed to serialize token metadata: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := s.cache.Set(s.ctx, hashedToken, string(metadataJSON), store.WithExpiration(ttl)); err != nil {
|
||||||
|
return "", fmt.Errorf("failed to store token: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
log.Debugf("Generated one-time token for proxy %s in account %s (expires in %s)",
|
log.Debugf("Generated one-time token for proxy %s in account %s (expires in %s)",
|
||||||
serviceID, accountID, ttl)
|
serviceID, accountID, ttl)
|
||||||
|
|
||||||
@@ -88,80 +92,45 @@ func (s *OneTimeTokenStore) GenerateToken(accountID, serviceID string, ttl time.
|
|||||||
// - Account ID doesn't match
|
// - Account ID doesn't match
|
||||||
// - Reverse proxy ID doesn't match
|
// - Reverse proxy ID doesn't match
|
||||||
func (s *OneTimeTokenStore) ValidateAndConsume(token, accountID, serviceID string) error {
|
func (s *OneTimeTokenStore) ValidateAndConsume(token, accountID, serviceID string) error {
|
||||||
s.mu.Lock()
|
hashedToken := hashToken(token)
|
||||||
defer s.mu.Unlock()
|
|
||||||
|
|
||||||
metadata, exists := s.tokens[token]
|
metadataJSON, err := s.cache.Get(s.ctx, hashedToken)
|
||||||
if !exists {
|
if err != nil {
|
||||||
log.Warnf("Token validation failed: token not found (proxy: %s, account: %s)",
|
log.Warnf("Token validation failed: token not found (proxy: %s, account: %s)", serviceID, accountID)
|
||||||
serviceID, accountID)
|
|
||||||
return fmt.Errorf("invalid token")
|
return fmt.Errorf("invalid token")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check expiration
|
metadata := &tokenMetadata{}
|
||||||
|
if err := json.Unmarshal([]byte(metadataJSON), metadata); err != nil {
|
||||||
|
log.Warnf("Token validation failed: failed to unmarshal metadata (proxy: %s, account: %s): %v", serviceID, accountID, err)
|
||||||
|
return fmt.Errorf("invalid token metadata")
|
||||||
|
}
|
||||||
|
|
||||||
if time.Now().After(metadata.ExpiresAt) {
|
if time.Now().After(metadata.ExpiresAt) {
|
||||||
delete(s.tokens, token)
|
log.Warnf("Token validation failed: token expired (proxy: %s, account: %s)", serviceID, accountID)
|
||||||
log.Warnf("Token validation failed: token expired (proxy: %s, account: %s)",
|
|
||||||
serviceID, accountID)
|
|
||||||
return fmt.Errorf("token expired")
|
return fmt.Errorf("token expired")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Validate account ID using constant-time comparison (prevents timing attacks)
|
|
||||||
if subtle.ConstantTimeCompare([]byte(metadata.AccountID), []byte(accountID)) != 1 {
|
if subtle.ConstantTimeCompare([]byte(metadata.AccountID), []byte(accountID)) != 1 {
|
||||||
log.Warnf("Token validation failed: account ID mismatch (expected: %s, got: %s)",
|
log.Warnf("Token validation failed: account ID mismatch (expected: %s, got: %s)", metadata.AccountID, accountID)
|
||||||
metadata.AccountID, accountID)
|
|
||||||
return fmt.Errorf("account ID mismatch")
|
return fmt.Errorf("account ID mismatch")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Validate service ID using constant-time comparison
|
|
||||||
if subtle.ConstantTimeCompare([]byte(metadata.ServiceID), []byte(serviceID)) != 1 {
|
if subtle.ConstantTimeCompare([]byte(metadata.ServiceID), []byte(serviceID)) != 1 {
|
||||||
log.Warnf("Token validation failed: service ID mismatch (expected: %s, got: %s)",
|
log.Warnf("Token validation failed: service ID mismatch (expected: %s, got: %s)", metadata.ServiceID, serviceID)
|
||||||
metadata.ServiceID, serviceID)
|
|
||||||
return fmt.Errorf("service ID mismatch")
|
return fmt.Errorf("service ID mismatch")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Delete token immediately to enforce single-use
|
if err := s.cache.Delete(s.ctx, hashedToken); err != nil {
|
||||||
delete(s.tokens, token)
|
log.Warnf("Token deletion warning (proxy: %s, account: %s): %v", serviceID, accountID, err)
|
||||||
|
}
|
||||||
|
|
||||||
log.Infof("Token validated and consumed for proxy %s in account %s",
|
log.Infof("Token validated and consumed for proxy %s in account %s", serviceID, accountID)
|
||||||
serviceID, accountID)
|
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// cleanupExpired removes expired tokens in the background to prevent memory leaks
|
func hashToken(token string) string {
|
||||||
func (s *OneTimeTokenStore) cleanupExpired() {
|
hash := sha256.Sum256([]byte(token))
|
||||||
for {
|
return hex.EncodeToString(hash[:])
|
||||||
select {
|
|
||||||
case <-s.cleanup.C:
|
|
||||||
s.mu.Lock()
|
|
||||||
now := time.Now()
|
|
||||||
removed := 0
|
|
||||||
for token, metadata := range s.tokens {
|
|
||||||
if now.After(metadata.ExpiresAt) {
|
|
||||||
delete(s.tokens, token)
|
|
||||||
removed++
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if removed > 0 {
|
|
||||||
log.Debugf("Cleaned up %d expired one-time tokens", removed)
|
|
||||||
}
|
|
||||||
s.mu.Unlock()
|
|
||||||
case <-s.cleanupDone:
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Close stops the cleanup goroutine and releases resources
|
|
||||||
func (s *OneTimeTokenStore) Close() {
|
|
||||||
s.cleanup.Stop()
|
|
||||||
close(s.cleanupDone)
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetTokenCount returns the current number of tokens in the store (for debugging/metrics)
|
|
||||||
func (s *OneTimeTokenStore) GetTokenCount() int {
|
|
||||||
s.mu.RLock()
|
|
||||||
defer s.mu.RUnlock()
|
|
||||||
return len(s.tokens)
|
|
||||||
}
|
}
|
||||||
|
|||||||
61
management/internals/shared/grpc/pkce_verifier.go
Normal file
61
management/internals/shared/grpc/pkce_verifier.go
Normal file
@@ -0,0 +1,61 @@
|
|||||||
|
package grpc
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/eko/gocache/lib/v4/cache"
|
||||||
|
"github.com/eko/gocache/lib/v4/store"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
nbcache "github.com/netbirdio/netbird/management/server/cache"
|
||||||
|
)
|
||||||
|
|
||||||
|
// PKCEVerifierStore manages PKCE verifiers for OAuth flows.
|
||||||
|
// Supports both in-memory and Redis storage via NB_IDP_CACHE_REDIS_ADDRESS env var.
|
||||||
|
type PKCEVerifierStore struct {
|
||||||
|
cache *cache.Cache[string]
|
||||||
|
ctx context.Context
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewPKCEVerifierStore creates a PKCE verifier store with automatic backend selection
|
||||||
|
func NewPKCEVerifierStore(ctx context.Context, maxTimeout, cleanupInterval time.Duration, maxConn int) (*PKCEVerifierStore, error) {
|
||||||
|
cacheStore, err := nbcache.NewStore(ctx, maxTimeout, cleanupInterval, maxConn)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to create cache store: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &PKCEVerifierStore{
|
||||||
|
cache: cache.New[string](cacheStore),
|
||||||
|
ctx: ctx,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Store saves a PKCE verifier associated with an OAuth state parameter.
|
||||||
|
// The verifier is stored with the specified TTL and will be automatically deleted after expiration.
|
||||||
|
func (s *PKCEVerifierStore) Store(state, verifier string, ttl time.Duration) error {
|
||||||
|
if err := s.cache.Set(s.ctx, state, verifier, store.WithExpiration(ttl)); err != nil {
|
||||||
|
return fmt.Errorf("failed to store PKCE verifier: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debugf("Stored PKCE verifier for state (expires in %s)", ttl)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// LoadAndDelete retrieves and removes a PKCE verifier for the given state.
|
||||||
|
// Returns the verifier and true if found, or empty string and false if not found.
|
||||||
|
// This enforces single-use semantics for PKCE verifiers.
|
||||||
|
func (s *PKCEVerifierStore) LoadAndDelete(state string) (string, bool) {
|
||||||
|
verifier, err := s.cache.Get(s.ctx, state)
|
||||||
|
if err != nil {
|
||||||
|
log.Debugf("PKCE verifier not found for state")
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := s.cache.Delete(s.ctx, state); err != nil {
|
||||||
|
log.Warnf("Failed to delete PKCE verifier for state: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return verifier, true
|
||||||
|
}
|
||||||
@@ -18,14 +18,14 @@ import (
|
|||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"golang.org/x/oauth2"
|
"golang.org/x/oauth2"
|
||||||
"google.golang.org/grpc/codes"
|
"google.golang.org/grpc/codes"
|
||||||
"google.golang.org/grpc/peer"
|
|
||||||
"google.golang.org/grpc/status"
|
"google.golang.org/grpc/status"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/shared/management/domain"
|
"github.com/netbirdio/netbird/shared/management/domain"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/internals/modules/peers"
|
"github.com/netbirdio/netbird/management/internals/modules/peers"
|
||||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
|
|
||||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs"
|
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs"
|
||||||
|
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
|
||||||
|
rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/sessionkey"
|
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/sessionkey"
|
||||||
"github.com/netbirdio/netbird/management/server/types"
|
"github.com/netbirdio/netbird/management/server/types"
|
||||||
"github.com/netbirdio/netbird/management/server/users"
|
"github.com/netbirdio/netbird/management/server/users"
|
||||||
@@ -58,17 +58,17 @@ type ProxyServiceServer struct {
|
|||||||
// Map of connected proxies: proxy_id -> proxy connection
|
// Map of connected proxies: proxy_id -> proxy connection
|
||||||
connectedProxies sync.Map
|
connectedProxies sync.Map
|
||||||
|
|
||||||
// Map of cluster address -> set of proxy IDs
|
|
||||||
clusterProxies sync.Map
|
|
||||||
|
|
||||||
// Channel for broadcasting reverse proxy updates to all proxies
|
|
||||||
updatesChan chan *proto.ProxyMapping
|
|
||||||
|
|
||||||
// Manager for access logs
|
// Manager for access logs
|
||||||
accessLogManager accesslogs.Manager
|
accessLogManager accesslogs.Manager
|
||||||
|
|
||||||
// Manager for reverse proxy operations
|
// Manager for reverse proxy operations
|
||||||
reverseProxyManager reverseproxy.Manager
|
serviceManager rpservice.Manager
|
||||||
|
|
||||||
|
// ProxyController for service updates and cluster management
|
||||||
|
proxyController proxy.Controller
|
||||||
|
|
||||||
|
// Manager for proxy connections
|
||||||
|
proxyManager proxy.Manager
|
||||||
|
|
||||||
// Manager for peers
|
// Manager for peers
|
||||||
peersManager peers.Manager
|
peersManager peers.Manager
|
||||||
@@ -82,84 +82,67 @@ type ProxyServiceServer struct {
|
|||||||
// OIDC configuration for proxy authentication
|
// OIDC configuration for proxy authentication
|
||||||
oidcConfig ProxyOIDCConfig
|
oidcConfig ProxyOIDCConfig
|
||||||
|
|
||||||
// TODO: use database to store these instead?
|
// Store for PKCE verifiers
|
||||||
// pkceVerifiers stores PKCE code verifiers keyed by OAuth state.
|
pkceVerifierStore *PKCEVerifierStore
|
||||||
// Entries expire after pkceVerifierTTL to prevent unbounded growth.
|
|
||||||
pkceVerifiers sync.Map
|
|
||||||
pkceCleanupCancel context.CancelFunc
|
|
||||||
}
|
}
|
||||||
|
|
||||||
const pkceVerifierTTL = 10 * time.Minute
|
const pkceVerifierTTL = 10 * time.Minute
|
||||||
|
|
||||||
type pkceEntry struct {
|
|
||||||
verifier string
|
|
||||||
createdAt time.Time
|
|
||||||
}
|
|
||||||
|
|
||||||
// proxyConnection represents a connected proxy
|
// proxyConnection represents a connected proxy
|
||||||
type proxyConnection struct {
|
type proxyConnection struct {
|
||||||
proxyID string
|
proxyID string
|
||||||
address string
|
address string
|
||||||
stream proto.ProxyService_GetMappingUpdateServer
|
stream proto.ProxyService_GetMappingUpdateServer
|
||||||
sendChan chan *proto.ProxyMapping
|
sendChan chan *proto.GetMappingUpdateResponse
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
cancel context.CancelFunc
|
cancel context.CancelFunc
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewProxyServiceServer creates a new proxy service server.
|
// NewProxyServiceServer creates a new proxy service server.
|
||||||
func NewProxyServiceServer(accessLogMgr accesslogs.Manager, tokenStore *OneTimeTokenStore, oidcConfig ProxyOIDCConfig, peersManager peers.Manager, usersManager users.Manager) *ProxyServiceServer {
|
func NewProxyServiceServer(accessLogMgr accesslogs.Manager, tokenStore *OneTimeTokenStore, pkceStore *PKCEVerifierStore, oidcConfig ProxyOIDCConfig, peersManager peers.Manager, usersManager users.Manager, proxyMgr proxy.Manager) *ProxyServiceServer {
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx := context.Background()
|
||||||
s := &ProxyServiceServer{
|
s := &ProxyServiceServer{
|
||||||
updatesChan: make(chan *proto.ProxyMapping, 100),
|
|
||||||
accessLogManager: accessLogMgr,
|
accessLogManager: accessLogMgr,
|
||||||
oidcConfig: oidcConfig,
|
oidcConfig: oidcConfig,
|
||||||
tokenStore: tokenStore,
|
tokenStore: tokenStore,
|
||||||
|
pkceVerifierStore: pkceStore,
|
||||||
peersManager: peersManager,
|
peersManager: peersManager,
|
||||||
usersManager: usersManager,
|
usersManager: usersManager,
|
||||||
pkceCleanupCancel: cancel,
|
proxyManager: proxyMgr,
|
||||||
}
|
}
|
||||||
go s.cleanupPKCEVerifiers(ctx)
|
go s.cleanupStaleProxies(ctx)
|
||||||
return s
|
return s
|
||||||
}
|
}
|
||||||
|
|
||||||
// cleanupPKCEVerifiers periodically removes expired PKCE verifiers.
|
// cleanupStaleProxies periodically removes proxies that haven't sent heartbeat in 10 minutes
|
||||||
func (s *ProxyServiceServer) cleanupPKCEVerifiers(ctx context.Context) {
|
func (s *ProxyServiceServer) cleanupStaleProxies(ctx context.Context) {
|
||||||
ticker := time.NewTicker(pkceVerifierTTL)
|
ticker := time.NewTicker(5 * time.Minute)
|
||||||
defer ticker.Stop()
|
defer ticker.Stop()
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
return
|
return
|
||||||
case <-ticker.C:
|
case <-ticker.C:
|
||||||
now := time.Now()
|
if err := s.proxyManager.CleanupStale(ctx, 10*time.Minute); err != nil {
|
||||||
s.pkceVerifiers.Range(func(key, value any) bool {
|
log.WithContext(ctx).Debugf("Failed to cleanup stale proxies: %v", err)
|
||||||
if entry, ok := value.(pkceEntry); ok && now.Sub(entry.createdAt) > pkceVerifierTTL {
|
|
||||||
s.pkceVerifiers.Delete(key)
|
|
||||||
}
|
}
|
||||||
return true
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Close stops background goroutines.
|
func (s *ProxyServiceServer) SetServiceManager(manager rpservice.Manager) {
|
||||||
func (s *ProxyServiceServer) Close() {
|
s.serviceManager = manager
|
||||||
s.pkceCleanupCancel()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *ProxyServiceServer) SetProxyManager(manager reverseproxy.Manager) {
|
func (s *ProxyServiceServer) SetProxyController(proxyController proxy.Controller) {
|
||||||
s.reverseProxyManager = manager
|
s.proxyController = proxyController
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetMappingUpdate handles the control stream with proxy clients
|
// GetMappingUpdate handles the control stream with proxy clients
|
||||||
func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest, stream proto.ProxyService_GetMappingUpdateServer) error {
|
func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest, stream proto.ProxyService_GetMappingUpdateServer) error {
|
||||||
ctx := stream.Context()
|
ctx := stream.Context()
|
||||||
|
|
||||||
peerInfo := ""
|
peerInfo := PeerIPFromContext(ctx)
|
||||||
if p, ok := peer.FromContext(ctx); ok {
|
|
||||||
peerInfo = p.Addr.String()
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Infof("New proxy connection from %s", peerInfo)
|
log.Infof("New proxy connection from %s", peerInfo)
|
||||||
|
|
||||||
proxyID := req.GetProxyId()
|
proxyID := req.GetProxyId()
|
||||||
@@ -177,13 +160,21 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest
|
|||||||
proxyID: proxyID,
|
proxyID: proxyID,
|
||||||
address: proxyAddress,
|
address: proxyAddress,
|
||||||
stream: stream,
|
stream: stream,
|
||||||
sendChan: make(chan *proto.ProxyMapping, 100),
|
sendChan: make(chan *proto.GetMappingUpdateResponse, 100),
|
||||||
ctx: connCtx,
|
ctx: connCtx,
|
||||||
cancel: cancel,
|
cancel: cancel,
|
||||||
}
|
}
|
||||||
|
|
||||||
s.connectedProxies.Store(proxyID, conn)
|
s.connectedProxies.Store(proxyID, conn)
|
||||||
s.addToCluster(conn.address, proxyID)
|
if err := s.proxyController.RegisterProxyToCluster(ctx, conn.address, proxyID); err != nil {
|
||||||
|
log.WithContext(ctx).Warnf("Failed to register proxy %s in cluster: %v", proxyID, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Register proxy in database
|
||||||
|
if err := s.proxyManager.Connect(ctx, proxyID, proxyAddress, peerInfo); err != nil {
|
||||||
|
log.WithContext(ctx).Warnf("Failed to register proxy %s in database: %v", proxyID, err)
|
||||||
|
}
|
||||||
|
|
||||||
log.WithFields(log.Fields{
|
log.WithFields(log.Fields{
|
||||||
"proxy_id": proxyID,
|
"proxy_id": proxyID,
|
||||||
"address": proxyAddress,
|
"address": proxyAddress,
|
||||||
@@ -191,8 +182,15 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest
|
|||||||
"total_proxies": len(s.GetConnectedProxies()),
|
"total_proxies": len(s.GetConnectedProxies()),
|
||||||
}).Info("Proxy registered in cluster")
|
}).Info("Proxy registered in cluster")
|
||||||
defer func() {
|
defer func() {
|
||||||
|
if err := s.proxyManager.Disconnect(context.Background(), proxyID); err != nil {
|
||||||
|
log.Warnf("Failed to mark proxy %s as disconnected: %v", proxyID, err)
|
||||||
|
}
|
||||||
|
|
||||||
s.connectedProxies.Delete(proxyID)
|
s.connectedProxies.Delete(proxyID)
|
||||||
s.removeFromCluster(conn.address, proxyID)
|
if err := s.proxyController.UnregisterProxyFromCluster(context.Background(), conn.address, proxyID); err != nil {
|
||||||
|
log.Warnf("Failed to unregister proxy %s from cluster: %v", proxyID, err)
|
||||||
|
}
|
||||||
|
|
||||||
cancel()
|
cancel()
|
||||||
log.Infof("Proxy %s disconnected", proxyID)
|
log.Infof("Proxy %s disconnected", proxyID)
|
||||||
}()
|
}()
|
||||||
@@ -204,6 +202,9 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest
|
|||||||
errChan := make(chan error, 2)
|
errChan := make(chan error, 2)
|
||||||
go s.sender(conn, errChan)
|
go s.sender(conn, errChan)
|
||||||
|
|
||||||
|
// Start heartbeat goroutine
|
||||||
|
go s.heartbeat(connCtx, proxyID)
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case err := <-errChan:
|
case err := <-errChan:
|
||||||
return fmt.Errorf("send update to proxy %s: %w", proxyID, err)
|
return fmt.Errorf("send update to proxy %s: %w", proxyID, err)
|
||||||
@@ -212,10 +213,27 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// heartbeat updates the proxy's last_seen timestamp every minute
|
||||||
|
func (s *ProxyServiceServer) heartbeat(ctx context.Context, proxyID string) {
|
||||||
|
ticker := time.NewTicker(1 * time.Minute)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ticker.C:
|
||||||
|
if err := s.proxyManager.Heartbeat(ctx, proxyID); err != nil {
|
||||||
|
log.WithContext(ctx).Debugf("Failed to update proxy %s heartbeat: %v", proxyID, err)
|
||||||
|
}
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// sendSnapshot sends the initial snapshot of services to the connecting proxy.
|
// sendSnapshot sends the initial snapshot of services to the connecting proxy.
|
||||||
// Only services matching the proxy's cluster address are sent.
|
// Only services matching the proxy's cluster address are sent.
|
||||||
func (s *ProxyServiceServer) sendSnapshot(ctx context.Context, conn *proxyConnection) error {
|
func (s *ProxyServiceServer) sendSnapshot(ctx context.Context, conn *proxyConnection) error {
|
||||||
services, err := s.reverseProxyManager.GetGlobalServices(ctx)
|
services, err := s.serviceManager.GetGlobalServices(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("get services from store: %w", err)
|
return fmt.Errorf("get services from store: %w", err)
|
||||||
}
|
}
|
||||||
@@ -224,7 +242,7 @@ func (s *ProxyServiceServer) sendSnapshot(ctx context.Context, conn *proxyConnec
|
|||||||
return fmt.Errorf("proxy address is invalid")
|
return fmt.Errorf("proxy address is invalid")
|
||||||
}
|
}
|
||||||
|
|
||||||
var filtered []*reverseproxy.Service
|
var filtered []*rpservice.Service
|
||||||
for _, service := range services {
|
for _, service := range services {
|
||||||
if !service.Enabled {
|
if !service.Enabled {
|
||||||
continue
|
continue
|
||||||
@@ -259,7 +277,7 @@ func (s *ProxyServiceServer) sendSnapshot(ctx context.Context, conn *proxyConnec
|
|||||||
if err := conn.stream.Send(&proto.GetMappingUpdateResponse{
|
if err := conn.stream.Send(&proto.GetMappingUpdateResponse{
|
||||||
Mapping: []*proto.ProxyMapping{
|
Mapping: []*proto.ProxyMapping{
|
||||||
service.ToProtoMapping(
|
service.ToProtoMapping(
|
||||||
reverseproxy.Create, // Initial snapshot, all records are "new" for the proxy.
|
rpservice.Create, // Initial snapshot, all records are "new" for the proxy.
|
||||||
token,
|
token,
|
||||||
s.GetOIDCValidationConfig(),
|
s.GetOIDCValidationConfig(),
|
||||||
),
|
),
|
||||||
@@ -288,7 +306,7 @@ func (s *ProxyServiceServer) sender(conn *proxyConnection, errChan chan<- error)
|
|||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case msg := <-conn.sendChan:
|
case msg := <-conn.sendChan:
|
||||||
if err := conn.stream.Send(&proto.GetMappingUpdateResponse{Mapping: []*proto.ProxyMapping{msg}}); err != nil {
|
if err := conn.stream.Send(msg); err != nil {
|
||||||
errChan <- err
|
errChan <- err
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -339,7 +357,7 @@ func (s *ProxyServiceServer) SendAccessLog(ctx context.Context, req *proto.SendA
|
|||||||
// Management should call this when services are created/updated/removed.
|
// Management should call this when services are created/updated/removed.
|
||||||
// For create/update operations a unique one-time auth token is generated per
|
// For create/update operations a unique one-time auth token is generated per
|
||||||
// proxy so that every replica can independently authenticate with management.
|
// proxy so that every replica can independently authenticate with management.
|
||||||
func (s *ProxyServiceServer) SendServiceUpdate(update *proto.ProxyMapping) {
|
func (s *ProxyServiceServer) SendServiceUpdate(update *proto.GetMappingUpdateResponse) {
|
||||||
log.Debugf("Broadcasting service update to all connected proxy servers")
|
log.Debugf("Broadcasting service update to all connected proxy servers")
|
||||||
s.connectedProxies.Range(func(key, value interface{}) bool {
|
s.connectedProxies.Range(func(key, value interface{}) bool {
|
||||||
conn := value.(*proxyConnection)
|
conn := value.(*proxyConnection)
|
||||||
@@ -349,7 +367,7 @@ func (s *ProxyServiceServer) SendServiceUpdate(update *proto.ProxyMapping) {
|
|||||||
}
|
}
|
||||||
select {
|
select {
|
||||||
case conn.sendChan <- msg:
|
case conn.sendChan <- msg:
|
||||||
log.Debugf("Sent service update with id %s to proxy server %s", update.Id, conn.proxyID)
|
log.Debugf("Sent service update to proxy server %s", conn.proxyID)
|
||||||
default:
|
default:
|
||||||
log.Warnf("Failed to send service update to proxy server %s (channel full)", conn.proxyID)
|
log.Warnf("Failed to send service update to proxy server %s (channel full)", conn.proxyID)
|
||||||
}
|
}
|
||||||
@@ -393,81 +411,75 @@ func (s *ProxyServiceServer) GetConnectedProxyURLs() []string {
|
|||||||
return urls
|
return urls
|
||||||
}
|
}
|
||||||
|
|
||||||
// addToCluster registers a proxy in a cluster.
|
|
||||||
func (s *ProxyServiceServer) addToCluster(clusterAddr, proxyID string) {
|
|
||||||
if clusterAddr == "" {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
proxySet, _ := s.clusterProxies.LoadOrStore(clusterAddr, &sync.Map{})
|
|
||||||
proxySet.(*sync.Map).Store(proxyID, struct{}{})
|
|
||||||
log.Debugf("Added proxy %s to cluster %s", proxyID, clusterAddr)
|
|
||||||
}
|
|
||||||
|
|
||||||
// removeFromCluster removes a proxy from a cluster.
|
|
||||||
func (s *ProxyServiceServer) removeFromCluster(clusterAddr, proxyID string) {
|
|
||||||
if clusterAddr == "" {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if proxySet, ok := s.clusterProxies.Load(clusterAddr); ok {
|
|
||||||
proxySet.(*sync.Map).Delete(proxyID)
|
|
||||||
log.Debugf("Removed proxy %s from cluster %s", proxyID, clusterAddr)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// SendServiceUpdateToCluster sends a service update to all proxy servers in a specific cluster.
|
// SendServiceUpdateToCluster sends a service update to all proxy servers in a specific cluster.
|
||||||
// If clusterAddr is empty, broadcasts to all connected proxy servers (backward compatibility).
|
// If clusterAddr is empty, broadcasts to all connected proxy servers (backward compatibility).
|
||||||
// For create/update operations a unique one-time auth token is generated per
|
// For create/update operations a unique one-time auth token is generated per
|
||||||
// proxy so that every replica can independently authenticate with management.
|
// proxy so that every replica can independently authenticate with management.
|
||||||
func (s *ProxyServiceServer) SendServiceUpdateToCluster(update *proto.ProxyMapping, clusterAddr string) {
|
func (s *ProxyServiceServer) SendServiceUpdateToCluster(ctx context.Context, update *proto.ProxyMapping, clusterAddr string) {
|
||||||
|
updateResponse := &proto.GetMappingUpdateResponse{
|
||||||
|
Mapping: []*proto.ProxyMapping{update},
|
||||||
|
}
|
||||||
|
|
||||||
if clusterAddr == "" {
|
if clusterAddr == "" {
|
||||||
s.SendServiceUpdate(update)
|
s.SendServiceUpdate(updateResponse)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
proxySet, ok := s.clusterProxies.Load(clusterAddr)
|
if s.proxyController == nil {
|
||||||
if !ok {
|
log.WithContext(ctx).Debugf("ProxyController not set, cannot send to cluster %s", clusterAddr)
|
||||||
log.Debugf("No proxies connected for cluster %s", clusterAddr)
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
proxyIDs := s.proxyController.GetProxiesForCluster(clusterAddr)
|
||||||
|
if len(proxyIDs) == 0 {
|
||||||
|
log.WithContext(ctx).Debugf("No proxies connected for cluster %s", clusterAddr)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Debugf("Sending service update to cluster %s", clusterAddr)
|
log.Debugf("Sending service update to cluster %s", clusterAddr)
|
||||||
proxySet.(*sync.Map).Range(func(key, _ interface{}) bool {
|
for _, proxyID := range proxyIDs {
|
||||||
proxyID := key.(string)
|
|
||||||
if connVal, ok := s.connectedProxies.Load(proxyID); ok {
|
if connVal, ok := s.connectedProxies.Load(proxyID); ok {
|
||||||
conn := connVal.(*proxyConnection)
|
conn := connVal.(*proxyConnection)
|
||||||
msg := s.perProxyMessage(update, proxyID)
|
msg := s.perProxyMessage(updateResponse, proxyID)
|
||||||
if msg == nil {
|
if msg == nil {
|
||||||
return true
|
continue
|
||||||
}
|
}
|
||||||
select {
|
select {
|
||||||
case conn.sendChan <- msg:
|
case conn.sendChan <- msg:
|
||||||
log.Debugf("Sent service update with id %s to proxy %s in cluster %s", update.Id, proxyID, clusterAddr)
|
log.WithContext(ctx).Debugf("Sent service update with id %s to proxy %s in cluster %s", update.Id, proxyID, clusterAddr)
|
||||||
default:
|
default:
|
||||||
log.Warnf("Failed to send service update to proxy %s in cluster %s (channel full)", proxyID, clusterAddr)
|
log.WithContext(ctx).Warnf("Failed to send service update to proxy %s in cluster %s (channel full)", proxyID, clusterAddr)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return true
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// perProxyMessage returns a copy of update with a fresh one-time token for
|
// perProxyMessage returns a copy of update with a fresh one-time token for
|
||||||
// create/update operations. For delete operations the original message is
|
// create/update operations. For delete operations the original mapping is
|
||||||
// returned unchanged because proxies do not need to authenticate for removal.
|
// used unchanged because proxies do not need to authenticate for removal.
|
||||||
// Returns nil if token generation fails (the proxy should be skipped).
|
// Returns nil if token generation fails (the proxy should be skipped).
|
||||||
func (s *ProxyServiceServer) perProxyMessage(update *proto.ProxyMapping, proxyID string) *proto.ProxyMapping {
|
func (s *ProxyServiceServer) perProxyMessage(update *proto.GetMappingUpdateResponse, proxyID string) *proto.GetMappingUpdateResponse {
|
||||||
if update.Type == proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED || update.AccountId == "" {
|
resp := make([]*proto.ProxyMapping, 0, len(update.Mapping))
|
||||||
return update
|
for _, mapping := range update.Mapping {
|
||||||
|
if mapping.Type == proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED {
|
||||||
|
resp = append(resp, mapping)
|
||||||
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
token, err := s.tokenStore.GenerateToken(update.AccountId, update.Id, 5*time.Minute)
|
token, err := s.tokenStore.GenerateToken(mapping.AccountId, mapping.Id, 5*time.Minute)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warnf("Failed to generate token for proxy %s: %v", proxyID, err)
|
log.Warnf("Failed to generate token for proxy %s: %v", proxyID, err)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
msg := shallowCloneMapping(update)
|
msg := shallowCloneMapping(mapping)
|
||||||
msg.AuthToken = token
|
msg.AuthToken = token
|
||||||
return msg
|
resp = append(resp, msg)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &proto.GetMappingUpdateResponse{
|
||||||
|
Mapping: resp,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// shallowCloneMapping creates a shallow copy of a ProxyMapping, reusing the
|
// shallowCloneMapping creates a shallow copy of a ProxyMapping, reusing the
|
||||||
@@ -486,35 +498,8 @@ func shallowCloneMapping(m *proto.ProxyMapping) *proto.ProxyMapping {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetAvailableClusters returns information about all connected proxy clusters.
|
|
||||||
func (s *ProxyServiceServer) GetAvailableClusters() []ClusterInfo {
|
|
||||||
clusterCounts := make(map[string]int)
|
|
||||||
s.clusterProxies.Range(func(key, value interface{}) bool {
|
|
||||||
clusterAddr := key.(string)
|
|
||||||
proxySet := value.(*sync.Map)
|
|
||||||
count := 0
|
|
||||||
proxySet.Range(func(_, _ interface{}) bool {
|
|
||||||
count++
|
|
||||||
return true
|
|
||||||
})
|
|
||||||
if count > 0 {
|
|
||||||
clusterCounts[clusterAddr] = count
|
|
||||||
}
|
|
||||||
return true
|
|
||||||
})
|
|
||||||
|
|
||||||
clusters := make([]ClusterInfo, 0, len(clusterCounts))
|
|
||||||
for addr, count := range clusterCounts {
|
|
||||||
clusters = append(clusters, ClusterInfo{
|
|
||||||
Address: addr,
|
|
||||||
ConnectedProxies: count,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
return clusters
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *ProxyServiceServer) Authenticate(ctx context.Context, req *proto.AuthenticateRequest) (*proto.AuthenticateResponse, error) {
|
func (s *ProxyServiceServer) Authenticate(ctx context.Context, req *proto.AuthenticateRequest) (*proto.AuthenticateResponse, error) {
|
||||||
service, err := s.reverseProxyManager.GetServiceByID(ctx, req.GetAccountId(), req.GetId())
|
service, err := s.serviceManager.GetServiceByID(ctx, req.GetAccountId(), req.GetId())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithContext(ctx).Debugf("failed to get service from store: %v", err)
|
log.WithContext(ctx).Debugf("failed to get service from store: %v", err)
|
||||||
return nil, status.Errorf(codes.FailedPrecondition, "get service from store: %v", err)
|
return nil, status.Errorf(codes.FailedPrecondition, "get service from store: %v", err)
|
||||||
@@ -533,7 +518,7 @@ func (s *ProxyServiceServer) Authenticate(ctx context.Context, req *proto.Authen
|
|||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *ProxyServiceServer) authenticateRequest(ctx context.Context, req *proto.AuthenticateRequest, service *reverseproxy.Service) (bool, string, proxyauth.Method) {
|
func (s *ProxyServiceServer) authenticateRequest(ctx context.Context, req *proto.AuthenticateRequest, service *rpservice.Service) (bool, string, proxyauth.Method) {
|
||||||
switch v := req.GetRequest().(type) {
|
switch v := req.GetRequest().(type) {
|
||||||
case *proto.AuthenticateRequest_Pin:
|
case *proto.AuthenticateRequest_Pin:
|
||||||
return s.authenticatePIN(ctx, req.GetId(), v, service.Auth.PinAuth)
|
return s.authenticatePIN(ctx, req.GetId(), v, service.Auth.PinAuth)
|
||||||
@@ -544,7 +529,7 @@ func (s *ProxyServiceServer) authenticateRequest(ctx context.Context, req *proto
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *ProxyServiceServer) authenticatePIN(ctx context.Context, serviceID string, req *proto.AuthenticateRequest_Pin, auth *reverseproxy.PINAuthConfig) (bool, string, proxyauth.Method) {
|
func (s *ProxyServiceServer) authenticatePIN(ctx context.Context, serviceID string, req *proto.AuthenticateRequest_Pin, auth *rpservice.PINAuthConfig) (bool, string, proxyauth.Method) {
|
||||||
if auth == nil || !auth.Enabled {
|
if auth == nil || !auth.Enabled {
|
||||||
log.WithContext(ctx).Debugf("PIN authentication attempted but not enabled for service %s", serviceID)
|
log.WithContext(ctx).Debugf("PIN authentication attempted but not enabled for service %s", serviceID)
|
||||||
return false, "", ""
|
return false, "", ""
|
||||||
@@ -558,7 +543,7 @@ func (s *ProxyServiceServer) authenticatePIN(ctx context.Context, serviceID stri
|
|||||||
return true, "pin-user", proxyauth.MethodPIN
|
return true, "pin-user", proxyauth.MethodPIN
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *ProxyServiceServer) authenticatePassword(ctx context.Context, serviceID string, req *proto.AuthenticateRequest_Password, auth *reverseproxy.PasswordAuthConfig) (bool, string, proxyauth.Method) {
|
func (s *ProxyServiceServer) authenticatePassword(ctx context.Context, serviceID string, req *proto.AuthenticateRequest_Password, auth *rpservice.PasswordAuthConfig) (bool, string, proxyauth.Method) {
|
||||||
if auth == nil || !auth.Enabled {
|
if auth == nil || !auth.Enabled {
|
||||||
log.WithContext(ctx).Debugf("password authentication attempted but not enabled for service %s", serviceID)
|
log.WithContext(ctx).Debugf("password authentication attempted but not enabled for service %s", serviceID)
|
||||||
return false, "", ""
|
return false, "", ""
|
||||||
@@ -580,7 +565,7 @@ func (s *ProxyServiceServer) logAuthenticationError(ctx context.Context, err err
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *ProxyServiceServer) generateSessionToken(ctx context.Context, authenticated bool, service *reverseproxy.Service, userId string, method proxyauth.Method) (string, error) {
|
func (s *ProxyServiceServer) generateSessionToken(ctx context.Context, authenticated bool, service *rpservice.Service, userId string, method proxyauth.Method) (string, error) {
|
||||||
if !authenticated || service.SessionPrivateKey == "" {
|
if !authenticated || service.SessionPrivateKey == "" {
|
||||||
return "", nil
|
return "", nil
|
||||||
}
|
}
|
||||||
@@ -620,7 +605,7 @@ func (s *ProxyServiceServer) SendStatusUpdate(ctx context.Context, req *proto.Se
|
|||||||
}
|
}
|
||||||
|
|
||||||
if certificateIssued {
|
if certificateIssued {
|
||||||
if err := s.reverseProxyManager.SetCertificateIssuedAt(ctx, accountID, serviceID); err != nil {
|
if err := s.serviceManager.SetCertificateIssuedAt(ctx, accountID, serviceID); err != nil {
|
||||||
log.WithContext(ctx).WithError(err).Error("failed to set certificate issued timestamp")
|
log.WithContext(ctx).WithError(err).Error("failed to set certificate issued timestamp")
|
||||||
return nil, status.Errorf(codes.Internal, "update certificate timestamp: %v", err)
|
return nil, status.Errorf(codes.Internal, "update certificate timestamp: %v", err)
|
||||||
}
|
}
|
||||||
@@ -632,7 +617,7 @@ func (s *ProxyServiceServer) SendStatusUpdate(ctx context.Context, req *proto.Se
|
|||||||
|
|
||||||
internalStatus := protoStatusToInternal(protoStatus)
|
internalStatus := protoStatusToInternal(protoStatus)
|
||||||
|
|
||||||
if err := s.reverseProxyManager.SetStatus(ctx, accountID, serviceID, internalStatus); err != nil {
|
if err := s.serviceManager.SetStatus(ctx, accountID, serviceID, internalStatus); err != nil {
|
||||||
log.WithContext(ctx).WithError(err).Error("failed to update service status")
|
log.WithContext(ctx).WithError(err).Error("failed to update service status")
|
||||||
return nil, status.Errorf(codes.Internal, "update service status: %v", err)
|
return nil, status.Errorf(codes.Internal, "update service status: %v", err)
|
||||||
}
|
}
|
||||||
@@ -647,22 +632,22 @@ func (s *ProxyServiceServer) SendStatusUpdate(ctx context.Context, req *proto.Se
|
|||||||
}
|
}
|
||||||
|
|
||||||
// protoStatusToInternal maps proto status to internal status
|
// protoStatusToInternal maps proto status to internal status
|
||||||
func protoStatusToInternal(protoStatus proto.ProxyStatus) reverseproxy.ProxyStatus {
|
func protoStatusToInternal(protoStatus proto.ProxyStatus) rpservice.Status {
|
||||||
switch protoStatus {
|
switch protoStatus {
|
||||||
case proto.ProxyStatus_PROXY_STATUS_PENDING:
|
case proto.ProxyStatus_PROXY_STATUS_PENDING:
|
||||||
return reverseproxy.StatusPending
|
return rpservice.StatusPending
|
||||||
case proto.ProxyStatus_PROXY_STATUS_ACTIVE:
|
case proto.ProxyStatus_PROXY_STATUS_ACTIVE:
|
||||||
return reverseproxy.StatusActive
|
return rpservice.StatusActive
|
||||||
case proto.ProxyStatus_PROXY_STATUS_TUNNEL_NOT_CREATED:
|
case proto.ProxyStatus_PROXY_STATUS_TUNNEL_NOT_CREATED:
|
||||||
return reverseproxy.StatusTunnelNotCreated
|
return rpservice.StatusTunnelNotCreated
|
||||||
case proto.ProxyStatus_PROXY_STATUS_CERTIFICATE_PENDING:
|
case proto.ProxyStatus_PROXY_STATUS_CERTIFICATE_PENDING:
|
||||||
return reverseproxy.StatusCertificatePending
|
return rpservice.StatusCertificatePending
|
||||||
case proto.ProxyStatus_PROXY_STATUS_CERTIFICATE_FAILED:
|
case proto.ProxyStatus_PROXY_STATUS_CERTIFICATE_FAILED:
|
||||||
return reverseproxy.StatusCertificateFailed
|
return rpservice.StatusCertificateFailed
|
||||||
case proto.ProxyStatus_PROXY_STATUS_ERROR:
|
case proto.ProxyStatus_PROXY_STATUS_ERROR:
|
||||||
return reverseproxy.StatusError
|
return rpservice.StatusError
|
||||||
default:
|
default:
|
||||||
return reverseproxy.StatusError
|
return rpservice.StatusError
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -727,7 +712,7 @@ func (s *ProxyServiceServer) GetOIDCURL(ctx context.Context, req *proto.GetOIDCU
|
|||||||
return nil, status.Errorf(codes.InvalidArgument, "parse redirect url: %v", err)
|
return nil, status.Errorf(codes.InvalidArgument, "parse redirect url: %v", err)
|
||||||
}
|
}
|
||||||
// Validate redirectURL against known service endpoints to avoid abuse of OIDC redirection.
|
// Validate redirectURL against known service endpoints to avoid abuse of OIDC redirection.
|
||||||
services, err := s.reverseProxyManager.GetAccountServices(ctx, req.GetAccountId())
|
services, err := s.serviceManager.GetAccountServices(ctx, req.GetAccountId())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithContext(ctx).Errorf("failed to get account services: %v", err)
|
log.WithContext(ctx).Errorf("failed to get account services: %v", err)
|
||||||
return nil, status.Errorf(codes.FailedPrecondition, "get account services: %v", err)
|
return nil, status.Errorf(codes.FailedPrecondition, "get account services: %v", err)
|
||||||
@@ -771,7 +756,10 @@ func (s *ProxyServiceServer) GetOIDCURL(ctx context.Context, req *proto.GetOIDCU
|
|||||||
state := fmt.Sprintf("%s|%s|%s", base64.URLEncoding.EncodeToString([]byte(redirectURL.String())), nonceB64, hmacSum)
|
state := fmt.Sprintf("%s|%s|%s", base64.URLEncoding.EncodeToString([]byte(redirectURL.String())), nonceB64, hmacSum)
|
||||||
|
|
||||||
codeVerifier := oauth2.GenerateVerifier()
|
codeVerifier := oauth2.GenerateVerifier()
|
||||||
s.pkceVerifiers.Store(state, pkceEntry{verifier: codeVerifier, createdAt: time.Now()})
|
if err := s.pkceVerifierStore.Store(state, codeVerifier, pkceVerifierTTL); err != nil {
|
||||||
|
log.WithContext(ctx).Errorf("failed to store PKCE verifier: %v", err)
|
||||||
|
return nil, status.Errorf(codes.Internal, "store PKCE verifier: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
return &proto.GetOIDCURLResponse{
|
return &proto.GetOIDCURLResponse{
|
||||||
Url: (&oauth2.Config{
|
Url: (&oauth2.Config{
|
||||||
@@ -790,8 +778,8 @@ func (s *ProxyServiceServer) GetOIDCConfig() ProxyOIDCConfig {
|
|||||||
|
|
||||||
// GetOIDCValidationConfig returns the OIDC configuration for token validation
|
// GetOIDCValidationConfig returns the OIDC configuration for token validation
|
||||||
// in the format needed by ToProtoMapping.
|
// in the format needed by ToProtoMapping.
|
||||||
func (s *ProxyServiceServer) GetOIDCValidationConfig() reverseproxy.OIDCValidationConfig {
|
func (s *ProxyServiceServer) GetOIDCValidationConfig() proxy.OIDCValidationConfig {
|
||||||
return reverseproxy.OIDCValidationConfig{
|
return proxy.OIDCValidationConfig{
|
||||||
Issuer: s.oidcConfig.Issuer,
|
Issuer: s.oidcConfig.Issuer,
|
||||||
Audiences: []string{s.oidcConfig.Audience},
|
Audiences: []string{s.oidcConfig.Audience},
|
||||||
KeysLocation: s.oidcConfig.KeysLocation,
|
KeysLocation: s.oidcConfig.KeysLocation,
|
||||||
@@ -808,18 +796,10 @@ func (s *ProxyServiceServer) generateHMAC(input string) string {
|
|||||||
// ValidateState validates the state parameter from an OAuth callback.
|
// ValidateState validates the state parameter from an OAuth callback.
|
||||||
// Returns the original redirect URL if valid, or an error if invalid.
|
// Returns the original redirect URL if valid, or an error if invalid.
|
||||||
func (s *ProxyServiceServer) ValidateState(state string) (verifier, redirectURL string, err error) {
|
func (s *ProxyServiceServer) ValidateState(state string) (verifier, redirectURL string, err error) {
|
||||||
v, ok := s.pkceVerifiers.LoadAndDelete(state)
|
verifier, ok := s.pkceVerifierStore.LoadAndDelete(state)
|
||||||
if !ok {
|
if !ok {
|
||||||
return "", "", errors.New("no verifier for state")
|
return "", "", errors.New("no verifier for state")
|
||||||
}
|
}
|
||||||
entry, ok := v.(pkceEntry)
|
|
||||||
if !ok {
|
|
||||||
return "", "", errors.New("invalid verifier for state")
|
|
||||||
}
|
|
||||||
if time.Since(entry.createdAt) > pkceVerifierTTL {
|
|
||||||
return "", "", errors.New("PKCE verifier expired")
|
|
||||||
}
|
|
||||||
verifier = entry.verifier
|
|
||||||
|
|
||||||
// State format: base64(redirectURL)|nonce|hmac(redirectURL|nonce)
|
// State format: base64(redirectURL)|nonce|hmac(redirectURL|nonce)
|
||||||
parts := strings.Split(state, "|")
|
parts := strings.Split(state, "|")
|
||||||
@@ -850,12 +830,12 @@ func (s *ProxyServiceServer) ValidateState(state string) (verifier, redirectURL
|
|||||||
// GenerateSessionToken creates a signed session JWT for the given domain and user.
|
// GenerateSessionToken creates a signed session JWT for the given domain and user.
|
||||||
func (s *ProxyServiceServer) GenerateSessionToken(ctx context.Context, domain, userID string, method proxyauth.Method) (string, error) {
|
func (s *ProxyServiceServer) GenerateSessionToken(ctx context.Context, domain, userID string, method proxyauth.Method) (string, error) {
|
||||||
// Find the service by domain to get its signing key
|
// Find the service by domain to get its signing key
|
||||||
services, err := s.reverseProxyManager.GetGlobalServices(ctx)
|
services, err := s.serviceManager.GetGlobalServices(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", fmt.Errorf("get services: %w", err)
|
return "", fmt.Errorf("get services: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
var service *reverseproxy.Service
|
var service *rpservice.Service
|
||||||
for _, svc := range services {
|
for _, svc := range services {
|
||||||
if svc.Domain == domain {
|
if svc.Domain == domain {
|
||||||
service = svc
|
service = svc
|
||||||
@@ -921,8 +901,8 @@ func (s *ProxyServiceServer) ValidateUserGroupAccess(ctx context.Context, domain
|
|||||||
return fmt.Errorf("user %s not in allowed groups for domain %s", user.Id, domain)
|
return fmt.Errorf("user %s not in allowed groups for domain %s", user.Id, domain)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *ProxyServiceServer) getAccountServiceByDomain(ctx context.Context, accountID, domain string) (*reverseproxy.Service, error) {
|
func (s *ProxyServiceServer) getAccountServiceByDomain(ctx context.Context, accountID, domain string) (*rpservice.Service, error) {
|
||||||
services, err := s.reverseProxyManager.GetAccountServices(ctx, accountID)
|
services, err := s.serviceManager.GetAccountServices(ctx, accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("get account services: %w", err)
|
return nil, fmt.Errorf("get account services: %w", err)
|
||||||
}
|
}
|
||||||
@@ -1043,8 +1023,8 @@ func (s *ProxyServiceServer) ValidateSession(ctx context.Context, req *proto.Val
|
|||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *ProxyServiceServer) getServiceByDomain(ctx context.Context, domain string) (*reverseproxy.Service, error) {
|
func (s *ProxyServiceServer) getServiceByDomain(ctx context.Context, domain string) (*rpservice.Service, error) {
|
||||||
services, err := s.reverseProxyManager.GetGlobalServices(ctx)
|
services, err := s.serviceManager.GetGlobalServices(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("get services: %w", err)
|
return nil, fmt.Errorf("get services: %w", err)
|
||||||
}
|
}
|
||||||
@@ -1058,7 +1038,7 @@ func (s *ProxyServiceServer) getServiceByDomain(ctx context.Context, domain stri
|
|||||||
return nil, fmt.Errorf("service not found for domain: %s", domain)
|
return nil, fmt.Errorf("service not found for domain: %s", domain)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *ProxyServiceServer) checkGroupAccess(service *reverseproxy.Service, user *types.User) error {
|
func (s *ProxyServiceServer) checkGroupAccess(service *rpservice.Service, user *types.User) error {
|
||||||
if service.Auth.BearerAuth == nil || !service.Auth.BearerAuth.Enabled {
|
if service.Auth.BearerAuth == nil || !service.Auth.BearerAuth.Enabled {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -107,7 +107,7 @@ func NewProxyAuthInterceptors(tokenStore proxyTokenStore) (grpc.UnaryServerInter
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (i *proxyAuthInterceptor) validateProxyToken(ctx context.Context) (*types.ProxyAccessToken, error) {
|
func (i *proxyAuthInterceptor) validateProxyToken(ctx context.Context) (*types.ProxyAccessToken, error) {
|
||||||
clientIP := peerIPFromContext(ctx)
|
clientIP := PeerIPFromContext(ctx)
|
||||||
|
|
||||||
if clientIP != "" && i.failureLimiter.isLimited(clientIP) {
|
if clientIP != "" && i.failureLimiter.isLimited(clientIP) {
|
||||||
return nil, status.Errorf(codes.ResourceExhausted, "too many failed authentication attempts")
|
return nil, status.Errorf(codes.ResourceExhausted, "too many failed authentication attempts")
|
||||||
|
|||||||
@@ -115,9 +115,9 @@ func (l *authFailureLimiter) stop() {
|
|||||||
l.cancel()
|
l.cancel()
|
||||||
}
|
}
|
||||||
|
|
||||||
// peerIPFromContext extracts the client IP from the gRPC context.
|
// PeerIPFromContext extracts the client IP from the gRPC context.
|
||||||
// Uses realip (from trusted proxy headers) first, falls back to the transport peer address.
|
// Uses realip (from trusted proxy headers) first, falls back to the transport peer address.
|
||||||
func peerIPFromContext(ctx context.Context) clientIP {
|
func PeerIPFromContext(ctx context.Context) string {
|
||||||
if addr, ok := realip.FromContext(ctx); ok {
|
if addr, ok := realip.FromContext(ctx); ok {
|
||||||
return addr.String()
|
return addr.String()
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -8,40 +8,44 @@ import (
|
|||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
|
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||||
"github.com/netbirdio/netbird/management/server/types"
|
"github.com/netbirdio/netbird/management/server/types"
|
||||||
)
|
)
|
||||||
|
|
||||||
type mockReverseProxyManager struct {
|
type mockReverseProxyManager struct {
|
||||||
proxiesByAccount map[string][]*reverseproxy.Service
|
proxiesByAccount map[string][]*service.Service
|
||||||
err error
|
err error
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *mockReverseProxyManager) GetAccountServices(ctx context.Context, accountID string) ([]*reverseproxy.Service, error) {
|
func (m *mockReverseProxyManager) DeleteAllServices(ctx context.Context, accountID, userID string) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockReverseProxyManager) GetAccountServices(ctx context.Context, accountID string) ([]*service.Service, error) {
|
||||||
if m.err != nil {
|
if m.err != nil {
|
||||||
return nil, m.err
|
return nil, m.err
|
||||||
}
|
}
|
||||||
return m.proxiesByAccount[accountID], nil
|
return m.proxiesByAccount[accountID], nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *mockReverseProxyManager) GetGlobalServices(ctx context.Context) ([]*reverseproxy.Service, error) {
|
func (m *mockReverseProxyManager) GetGlobalServices(ctx context.Context) ([]*service.Service, error) {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *mockReverseProxyManager) GetAllServices(ctx context.Context, accountID, userID string) ([]*reverseproxy.Service, error) {
|
func (m *mockReverseProxyManager) GetAllServices(ctx context.Context, accountID, userID string) ([]*service.Service, error) {
|
||||||
return []*reverseproxy.Service{}, nil
|
return []*service.Service{}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *mockReverseProxyManager) GetService(ctx context.Context, accountID, userID, reverseProxyID string) (*reverseproxy.Service, error) {
|
func (m *mockReverseProxyManager) GetService(ctx context.Context, accountID, userID, reverseProxyID string) (*service.Service, error) {
|
||||||
return &reverseproxy.Service{}, nil
|
return &service.Service{}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *mockReverseProxyManager) CreateService(ctx context.Context, accountID, userID string, rp *reverseproxy.Service) (*reverseproxy.Service, error) {
|
func (m *mockReverseProxyManager) CreateService(ctx context.Context, accountID, userID string, rp *service.Service) (*service.Service, error) {
|
||||||
return &reverseproxy.Service{}, nil
|
return &service.Service{}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *mockReverseProxyManager) UpdateService(ctx context.Context, accountID, userID string, rp *reverseproxy.Service) (*reverseproxy.Service, error) {
|
func (m *mockReverseProxyManager) UpdateService(ctx context.Context, accountID, userID string, rp *service.Service) (*service.Service, error) {
|
||||||
return &reverseproxy.Service{}, nil
|
return &service.Service{}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *mockReverseProxyManager) DeleteService(ctx context.Context, accountID, userID, reverseProxyID string) error {
|
func (m *mockReverseProxyManager) DeleteService(ctx context.Context, accountID, userID, reverseProxyID string) error {
|
||||||
@@ -52,7 +56,7 @@ func (m *mockReverseProxyManager) SetCertificateIssuedAt(ctx context.Context, ac
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *mockReverseProxyManager) SetStatus(ctx context.Context, accountID, reverseProxyID string, status reverseproxy.ProxyStatus) error {
|
func (m *mockReverseProxyManager) SetStatus(ctx context.Context, accountID, reverseProxyID string, status service.Status) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -64,14 +68,28 @@ func (m *mockReverseProxyManager) ReloadService(ctx context.Context, accountID,
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *mockReverseProxyManager) GetServiceByID(ctx context.Context, accountID, reverseProxyID string) (*reverseproxy.Service, error) {
|
func (m *mockReverseProxyManager) GetServiceByID(ctx context.Context, accountID, reverseProxyID string) (*service.Service, error) {
|
||||||
return &reverseproxy.Service{}, nil
|
return &service.Service{}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *mockReverseProxyManager) GetServiceIDByTargetID(_ context.Context, _, _ string) (string, error) {
|
func (m *mockReverseProxyManager) GetServiceIDByTargetID(_ context.Context, _, _ string) (string, error) {
|
||||||
return "", nil
|
return "", nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *mockReverseProxyManager) CreateServiceFromPeer(_ context.Context, _, _ string, _ *service.ExposeServiceRequest) (*service.ExposeServiceResponse, error) {
|
||||||
|
return &service.ExposeServiceResponse{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockReverseProxyManager) RenewServiceFromPeer(_ context.Context, _, _, _ string) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockReverseProxyManager) StopServiceFromPeer(_ context.Context, _, _, _ string) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockReverseProxyManager) StartExposeReaper(_ context.Context) {}
|
||||||
|
|
||||||
type mockUsersManager struct {
|
type mockUsersManager struct {
|
||||||
users map[string]*types.User
|
users map[string]*types.User
|
||||||
err error
|
err error
|
||||||
@@ -93,7 +111,7 @@ func TestValidateUserGroupAccess(t *testing.T) {
|
|||||||
name string
|
name string
|
||||||
domain string
|
domain string
|
||||||
userID string
|
userID string
|
||||||
proxiesByAccount map[string][]*reverseproxy.Service
|
proxiesByAccount map[string][]*service.Service
|
||||||
users map[string]*types.User
|
users map[string]*types.User
|
||||||
proxyErr error
|
proxyErr error
|
||||||
userErr error
|
userErr error
|
||||||
@@ -104,7 +122,7 @@ func TestValidateUserGroupAccess(t *testing.T) {
|
|||||||
name: "user not found",
|
name: "user not found",
|
||||||
domain: "app.example.com",
|
domain: "app.example.com",
|
||||||
userID: "unknown-user",
|
userID: "unknown-user",
|
||||||
proxiesByAccount: map[string][]*reverseproxy.Service{
|
proxiesByAccount: map[string][]*service.Service{
|
||||||
"account1": {{Domain: "app.example.com", AccountID: "account1"}},
|
"account1": {{Domain: "app.example.com", AccountID: "account1"}},
|
||||||
},
|
},
|
||||||
users: map[string]*types.User{},
|
users: map[string]*types.User{},
|
||||||
@@ -115,7 +133,7 @@ func TestValidateUserGroupAccess(t *testing.T) {
|
|||||||
name: "proxy not found in user's account",
|
name: "proxy not found in user's account",
|
||||||
domain: "app.example.com",
|
domain: "app.example.com",
|
||||||
userID: "user1",
|
userID: "user1",
|
||||||
proxiesByAccount: map[string][]*reverseproxy.Service{},
|
proxiesByAccount: map[string][]*service.Service{},
|
||||||
users: map[string]*types.User{
|
users: map[string]*types.User{
|
||||||
"user1": {Id: "user1", AccountID: "account1"},
|
"user1": {Id: "user1", AccountID: "account1"},
|
||||||
},
|
},
|
||||||
@@ -126,7 +144,7 @@ func TestValidateUserGroupAccess(t *testing.T) {
|
|||||||
name: "proxy exists in different account - not accessible",
|
name: "proxy exists in different account - not accessible",
|
||||||
domain: "app.example.com",
|
domain: "app.example.com",
|
||||||
userID: "user1",
|
userID: "user1",
|
||||||
proxiesByAccount: map[string][]*reverseproxy.Service{
|
proxiesByAccount: map[string][]*service.Service{
|
||||||
"account2": {{Domain: "app.example.com", AccountID: "account2"}},
|
"account2": {{Domain: "app.example.com", AccountID: "account2"}},
|
||||||
},
|
},
|
||||||
users: map[string]*types.User{
|
users: map[string]*types.User{
|
||||||
@@ -139,8 +157,8 @@ func TestValidateUserGroupAccess(t *testing.T) {
|
|||||||
name: "no bearer auth configured - same account allows access",
|
name: "no bearer auth configured - same account allows access",
|
||||||
domain: "app.example.com",
|
domain: "app.example.com",
|
||||||
userID: "user1",
|
userID: "user1",
|
||||||
proxiesByAccount: map[string][]*reverseproxy.Service{
|
proxiesByAccount: map[string][]*service.Service{
|
||||||
"account1": {{Domain: "app.example.com", AccountID: "account1", Auth: reverseproxy.AuthConfig{}}},
|
"account1": {{Domain: "app.example.com", AccountID: "account1", Auth: service.AuthConfig{}}},
|
||||||
},
|
},
|
||||||
users: map[string]*types.User{
|
users: map[string]*types.User{
|
||||||
"user1": {Id: "user1", AccountID: "account1"},
|
"user1": {Id: "user1", AccountID: "account1"},
|
||||||
@@ -151,12 +169,12 @@ func TestValidateUserGroupAccess(t *testing.T) {
|
|||||||
name: "bearer auth disabled - same account allows access",
|
name: "bearer auth disabled - same account allows access",
|
||||||
domain: "app.example.com",
|
domain: "app.example.com",
|
||||||
userID: "user1",
|
userID: "user1",
|
||||||
proxiesByAccount: map[string][]*reverseproxy.Service{
|
proxiesByAccount: map[string][]*service.Service{
|
||||||
"account1": {{
|
"account1": {{
|
||||||
Domain: "app.example.com",
|
Domain: "app.example.com",
|
||||||
AccountID: "account1",
|
AccountID: "account1",
|
||||||
Auth: reverseproxy.AuthConfig{
|
Auth: service.AuthConfig{
|
||||||
BearerAuth: &reverseproxy.BearerAuthConfig{Enabled: false},
|
BearerAuth: &service.BearerAuthConfig{Enabled: false},
|
||||||
},
|
},
|
||||||
}},
|
}},
|
||||||
},
|
},
|
||||||
@@ -169,12 +187,12 @@ func TestValidateUserGroupAccess(t *testing.T) {
|
|||||||
name: "bearer auth enabled but no groups configured - same account allows access",
|
name: "bearer auth enabled but no groups configured - same account allows access",
|
||||||
domain: "app.example.com",
|
domain: "app.example.com",
|
||||||
userID: "user1",
|
userID: "user1",
|
||||||
proxiesByAccount: map[string][]*reverseproxy.Service{
|
proxiesByAccount: map[string][]*service.Service{
|
||||||
"account1": {{
|
"account1": {{
|
||||||
Domain: "app.example.com",
|
Domain: "app.example.com",
|
||||||
AccountID: "account1",
|
AccountID: "account1",
|
||||||
Auth: reverseproxy.AuthConfig{
|
Auth: service.AuthConfig{
|
||||||
BearerAuth: &reverseproxy.BearerAuthConfig{
|
BearerAuth: &service.BearerAuthConfig{
|
||||||
Enabled: true,
|
Enabled: true,
|
||||||
DistributionGroups: []string{},
|
DistributionGroups: []string{},
|
||||||
},
|
},
|
||||||
@@ -190,12 +208,12 @@ func TestValidateUserGroupAccess(t *testing.T) {
|
|||||||
name: "user not in allowed groups",
|
name: "user not in allowed groups",
|
||||||
domain: "app.example.com",
|
domain: "app.example.com",
|
||||||
userID: "user1",
|
userID: "user1",
|
||||||
proxiesByAccount: map[string][]*reverseproxy.Service{
|
proxiesByAccount: map[string][]*service.Service{
|
||||||
"account1": {{
|
"account1": {{
|
||||||
Domain: "app.example.com",
|
Domain: "app.example.com",
|
||||||
AccountID: "account1",
|
AccountID: "account1",
|
||||||
Auth: reverseproxy.AuthConfig{
|
Auth: service.AuthConfig{
|
||||||
BearerAuth: &reverseproxy.BearerAuthConfig{
|
BearerAuth: &service.BearerAuthConfig{
|
||||||
Enabled: true,
|
Enabled: true,
|
||||||
DistributionGroups: []string{"group1", "group2"},
|
DistributionGroups: []string{"group1", "group2"},
|
||||||
},
|
},
|
||||||
@@ -212,12 +230,12 @@ func TestValidateUserGroupAccess(t *testing.T) {
|
|||||||
name: "user in one of the allowed groups - allow access",
|
name: "user in one of the allowed groups - allow access",
|
||||||
domain: "app.example.com",
|
domain: "app.example.com",
|
||||||
userID: "user1",
|
userID: "user1",
|
||||||
proxiesByAccount: map[string][]*reverseproxy.Service{
|
proxiesByAccount: map[string][]*service.Service{
|
||||||
"account1": {{
|
"account1": {{
|
||||||
Domain: "app.example.com",
|
Domain: "app.example.com",
|
||||||
AccountID: "account1",
|
AccountID: "account1",
|
||||||
Auth: reverseproxy.AuthConfig{
|
Auth: service.AuthConfig{
|
||||||
BearerAuth: &reverseproxy.BearerAuthConfig{
|
BearerAuth: &service.BearerAuthConfig{
|
||||||
Enabled: true,
|
Enabled: true,
|
||||||
DistributionGroups: []string{"group1", "group2"},
|
DistributionGroups: []string{"group1", "group2"},
|
||||||
},
|
},
|
||||||
@@ -233,12 +251,12 @@ func TestValidateUserGroupAccess(t *testing.T) {
|
|||||||
name: "user in all allowed groups - allow access",
|
name: "user in all allowed groups - allow access",
|
||||||
domain: "app.example.com",
|
domain: "app.example.com",
|
||||||
userID: "user1",
|
userID: "user1",
|
||||||
proxiesByAccount: map[string][]*reverseproxy.Service{
|
proxiesByAccount: map[string][]*service.Service{
|
||||||
"account1": {{
|
"account1": {{
|
||||||
Domain: "app.example.com",
|
Domain: "app.example.com",
|
||||||
AccountID: "account1",
|
AccountID: "account1",
|
||||||
Auth: reverseproxy.AuthConfig{
|
Auth: service.AuthConfig{
|
||||||
BearerAuth: &reverseproxy.BearerAuthConfig{
|
BearerAuth: &service.BearerAuthConfig{
|
||||||
Enabled: true,
|
Enabled: true,
|
||||||
DistributionGroups: []string{"group1", "group2"},
|
DistributionGroups: []string{"group1", "group2"},
|
||||||
},
|
},
|
||||||
@@ -266,10 +284,10 @@ func TestValidateUserGroupAccess(t *testing.T) {
|
|||||||
name: "multiple proxies in account - finds correct one",
|
name: "multiple proxies in account - finds correct one",
|
||||||
domain: "app2.example.com",
|
domain: "app2.example.com",
|
||||||
userID: "user1",
|
userID: "user1",
|
||||||
proxiesByAccount: map[string][]*reverseproxy.Service{
|
proxiesByAccount: map[string][]*service.Service{
|
||||||
"account1": {
|
"account1": {
|
||||||
{Domain: "app1.example.com", AccountID: "account1"},
|
{Domain: "app1.example.com", AccountID: "account1"},
|
||||||
{Domain: "app2.example.com", AccountID: "account1", Auth: reverseproxy.AuthConfig{}},
|
{Domain: "app2.example.com", AccountID: "account1", Auth: service.AuthConfig{}},
|
||||||
{Domain: "app3.example.com", AccountID: "account1"},
|
{Domain: "app3.example.com", AccountID: "account1"},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@@ -283,7 +301,7 @@ func TestValidateUserGroupAccess(t *testing.T) {
|
|||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
server := &ProxyServiceServer{
|
server := &ProxyServiceServer{
|
||||||
reverseProxyManager: &mockReverseProxyManager{
|
serviceManager: &mockReverseProxyManager{
|
||||||
proxiesByAccount: tt.proxiesByAccount,
|
proxiesByAccount: tt.proxiesByAccount,
|
||||||
err: tt.proxyErr,
|
err: tt.proxyErr,
|
||||||
},
|
},
|
||||||
@@ -310,7 +328,7 @@ func TestGetAccountProxyByDomain(t *testing.T) {
|
|||||||
name string
|
name string
|
||||||
accountID string
|
accountID string
|
||||||
domain string
|
domain string
|
||||||
proxiesByAccount map[string][]*reverseproxy.Service
|
proxiesByAccount map[string][]*service.Service
|
||||||
err error
|
err error
|
||||||
expectProxy bool
|
expectProxy bool
|
||||||
expectErr bool
|
expectErr bool
|
||||||
@@ -319,7 +337,7 @@ func TestGetAccountProxyByDomain(t *testing.T) {
|
|||||||
name: "proxy found",
|
name: "proxy found",
|
||||||
accountID: "account1",
|
accountID: "account1",
|
||||||
domain: "app.example.com",
|
domain: "app.example.com",
|
||||||
proxiesByAccount: map[string][]*reverseproxy.Service{
|
proxiesByAccount: map[string][]*service.Service{
|
||||||
"account1": {
|
"account1": {
|
||||||
{Domain: "other.example.com", AccountID: "account1"},
|
{Domain: "other.example.com", AccountID: "account1"},
|
||||||
{Domain: "app.example.com", AccountID: "account1"},
|
{Domain: "app.example.com", AccountID: "account1"},
|
||||||
@@ -332,7 +350,7 @@ func TestGetAccountProxyByDomain(t *testing.T) {
|
|||||||
name: "proxy not found in account",
|
name: "proxy not found in account",
|
||||||
accountID: "account1",
|
accountID: "account1",
|
||||||
domain: "unknown.example.com",
|
domain: "unknown.example.com",
|
||||||
proxiesByAccount: map[string][]*reverseproxy.Service{
|
proxiesByAccount: map[string][]*service.Service{
|
||||||
"account1": {{Domain: "app.example.com", AccountID: "account1"}},
|
"account1": {{Domain: "app.example.com", AccountID: "account1"}},
|
||||||
},
|
},
|
||||||
expectProxy: false,
|
expectProxy: false,
|
||||||
@@ -342,7 +360,7 @@ func TestGetAccountProxyByDomain(t *testing.T) {
|
|||||||
name: "empty proxy list for account",
|
name: "empty proxy list for account",
|
||||||
accountID: "account1",
|
accountID: "account1",
|
||||||
domain: "app.example.com",
|
domain: "app.example.com",
|
||||||
proxiesByAccount: map[string][]*reverseproxy.Service{},
|
proxiesByAccount: map[string][]*service.Service{},
|
||||||
expectProxy: false,
|
expectProxy: false,
|
||||||
expectErr: true,
|
expectErr: true,
|
||||||
},
|
},
|
||||||
@@ -360,7 +378,7 @@ func TestGetAccountProxyByDomain(t *testing.T) {
|
|||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
server := &ProxyServiceServer{
|
server := &ProxyServiceServer{
|
||||||
reverseProxyManager: &mockReverseProxyManager{
|
serviceManager: &mockReverseProxyManager{
|
||||||
proxiesByAccount: tt.proxiesByAccount,
|
proxiesByAccount: tt.proxiesByAccount,
|
||||||
err: tt.err,
|
err: tt.err,
|
||||||
},
|
},
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package grpc
|
package grpc
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"strings"
|
"strings"
|
||||||
@@ -11,13 +12,65 @@ import (
|
|||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
|
||||||
"github.com/netbirdio/netbird/shared/management/proto"
|
"github.com/netbirdio/netbird/shared/management/proto"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type testProxyController struct {
|
||||||
|
mu sync.Mutex
|
||||||
|
clusterProxies map[string]map[string]struct{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func newTestProxyController() *testProxyController {
|
||||||
|
return &testProxyController{
|
||||||
|
clusterProxies: make(map[string]map[string]struct{}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *testProxyController) SendServiceUpdateToCluster(_ context.Context, _ string, _ *proto.ProxyMapping, _ string) {
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *testProxyController) GetOIDCValidationConfig() proxy.OIDCValidationConfig {
|
||||||
|
return proxy.OIDCValidationConfig{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *testProxyController) RegisterProxyToCluster(_ context.Context, clusterAddr, proxyID string) error {
|
||||||
|
c.mu.Lock()
|
||||||
|
defer c.mu.Unlock()
|
||||||
|
if _, ok := c.clusterProxies[clusterAddr]; !ok {
|
||||||
|
c.clusterProxies[clusterAddr] = make(map[string]struct{})
|
||||||
|
}
|
||||||
|
c.clusterProxies[clusterAddr][proxyID] = struct{}{}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *testProxyController) UnregisterProxyFromCluster(_ context.Context, clusterAddr, proxyID string) error {
|
||||||
|
c.mu.Lock()
|
||||||
|
defer c.mu.Unlock()
|
||||||
|
if proxies, ok := c.clusterProxies[clusterAddr]; ok {
|
||||||
|
delete(proxies, proxyID)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *testProxyController) GetProxiesForCluster(clusterAddr string) []string {
|
||||||
|
c.mu.Lock()
|
||||||
|
defer c.mu.Unlock()
|
||||||
|
proxies, ok := c.clusterProxies[clusterAddr]
|
||||||
|
if !ok {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
result := make([]string, 0, len(proxies))
|
||||||
|
for id := range proxies {
|
||||||
|
result = append(result, id)
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
// registerFakeProxy adds a fake proxy connection to the server's internal maps
|
// registerFakeProxy adds a fake proxy connection to the server's internal maps
|
||||||
// and returns the channel where messages will be received.
|
// and returns the channel where messages will be received.
|
||||||
func registerFakeProxy(s *ProxyServiceServer, proxyID, clusterAddr string) chan *proto.ProxyMapping {
|
func registerFakeProxy(s *ProxyServiceServer, proxyID, clusterAddr string) chan *proto.GetMappingUpdateResponse {
|
||||||
ch := make(chan *proto.ProxyMapping, 10)
|
ch := make(chan *proto.GetMappingUpdateResponse, 10)
|
||||||
conn := &proxyConnection{
|
conn := &proxyConnection{
|
||||||
proxyID: proxyID,
|
proxyID: proxyID,
|
||||||
address: clusterAddr,
|
address: clusterAddr,
|
||||||
@@ -25,13 +78,12 @@ func registerFakeProxy(s *ProxyServiceServer, proxyID, clusterAddr string) chan
|
|||||||
}
|
}
|
||||||
s.connectedProxies.Store(proxyID, conn)
|
s.connectedProxies.Store(proxyID, conn)
|
||||||
|
|
||||||
proxySet, _ := s.clusterProxies.LoadOrStore(clusterAddr, &sync.Map{})
|
_ = s.proxyController.RegisterProxyToCluster(context.Background(), clusterAddr, proxyID)
|
||||||
proxySet.(*sync.Map).Store(proxyID, struct{}{})
|
|
||||||
|
|
||||||
return ch
|
return ch
|
||||||
}
|
}
|
||||||
|
|
||||||
func drainChannel(ch chan *proto.ProxyMapping) *proto.ProxyMapping {
|
func drainChannel(ch chan *proto.GetMappingUpdateResponse) *proto.GetMappingUpdateResponse {
|
||||||
select {
|
select {
|
||||||
case msg := <-ch:
|
case msg := <-ch:
|
||||||
return msg
|
return msg
|
||||||
@@ -41,24 +93,29 @@ func drainChannel(ch chan *proto.ProxyMapping) *proto.ProxyMapping {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestSendServiceUpdateToCluster_UniqueTokensPerProxy(t *testing.T) {
|
func TestSendServiceUpdateToCluster_UniqueTokensPerProxy(t *testing.T) {
|
||||||
tokenStore := NewOneTimeTokenStore(time.Hour)
|
ctx := context.Background()
|
||||||
defer tokenStore.Close()
|
tokenStore, err := NewOneTimeTokenStore(ctx, time.Hour, 10*time.Minute, 100)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
pkceStore, err := NewPKCEVerifierStore(ctx, 10*time.Minute, 10*time.Minute, 100)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
s := &ProxyServiceServer{
|
s := &ProxyServiceServer{
|
||||||
tokenStore: tokenStore,
|
tokenStore: tokenStore,
|
||||||
updatesChan: make(chan *proto.ProxyMapping, 100),
|
pkceVerifierStore: pkceStore,
|
||||||
}
|
}
|
||||||
|
s.SetProxyController(newTestProxyController())
|
||||||
|
|
||||||
const cluster = "proxy.example.com"
|
const cluster = "proxy.example.com"
|
||||||
const numProxies = 3
|
const numProxies = 3
|
||||||
|
|
||||||
channels := make([]chan *proto.ProxyMapping, numProxies)
|
channels := make([]chan *proto.GetMappingUpdateResponse, numProxies)
|
||||||
for i := range numProxies {
|
for i := range numProxies {
|
||||||
id := "proxy-" + string(rune('a'+i))
|
id := "proxy-" + string(rune('a'+i))
|
||||||
channels[i] = registerFakeProxy(s, id, cluster)
|
channels[i] = registerFakeProxy(s, id, cluster)
|
||||||
}
|
}
|
||||||
|
|
||||||
update := &proto.ProxyMapping{
|
mapping := &proto.ProxyMapping{
|
||||||
Type: proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED,
|
Type: proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED,
|
||||||
Id: "service-1",
|
Id: "service-1",
|
||||||
AccountId: "account-1",
|
AccountId: "account-1",
|
||||||
@@ -68,14 +125,16 @@ func TestSendServiceUpdateToCluster_UniqueTokensPerProxy(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
s.SendServiceUpdateToCluster(update, cluster)
|
s.SendServiceUpdateToCluster(context.Background(), mapping, cluster)
|
||||||
|
|
||||||
tokens := make([]string, numProxies)
|
tokens := make([]string, numProxies)
|
||||||
for i, ch := range channels {
|
for i, ch := range channels {
|
||||||
msg := drainChannel(ch)
|
resp := drainChannel(ch)
|
||||||
require.NotNil(t, msg, "proxy %d should receive a message", i)
|
require.NotNil(t, resp, "proxy %d should receive a message", i)
|
||||||
assert.Equal(t, update.Domain, msg.Domain)
|
require.Len(t, resp.Mapping, 1, "proxy %d should receive exactly one mapping", i)
|
||||||
assert.Equal(t, update.Id, msg.Id)
|
msg := resp.Mapping[0]
|
||||||
|
assert.Equal(t, mapping.Domain, msg.Domain)
|
||||||
|
assert.Equal(t, mapping.Id, msg.Id)
|
||||||
assert.NotEmpty(t, msg.AuthToken, "proxy %d should have a non-empty token", i)
|
assert.NotEmpty(t, msg.AuthToken, "proxy %d should have a non-empty token", i)
|
||||||
tokens[i] = msg.AuthToken
|
tokens[i] = msg.AuthToken
|
||||||
}
|
}
|
||||||
@@ -96,66 +155,84 @@ func TestSendServiceUpdateToCluster_UniqueTokensPerProxy(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestSendServiceUpdateToCluster_DeleteNoToken(t *testing.T) {
|
func TestSendServiceUpdateToCluster_DeleteNoToken(t *testing.T) {
|
||||||
tokenStore := NewOneTimeTokenStore(time.Hour)
|
ctx := context.Background()
|
||||||
defer tokenStore.Close()
|
tokenStore, err := NewOneTimeTokenStore(ctx, time.Hour, 10*time.Minute, 100)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
pkceStore, err := NewPKCEVerifierStore(ctx, 10*time.Minute, 10*time.Minute, 100)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
s := &ProxyServiceServer{
|
s := &ProxyServiceServer{
|
||||||
tokenStore: tokenStore,
|
tokenStore: tokenStore,
|
||||||
updatesChan: make(chan *proto.ProxyMapping, 100),
|
pkceVerifierStore: pkceStore,
|
||||||
}
|
}
|
||||||
|
s.SetProxyController(newTestProxyController())
|
||||||
|
|
||||||
const cluster = "proxy.example.com"
|
const cluster = "proxy.example.com"
|
||||||
ch1 := registerFakeProxy(s, "proxy-a", cluster)
|
ch1 := registerFakeProxy(s, "proxy-a", cluster)
|
||||||
ch2 := registerFakeProxy(s, "proxy-b", cluster)
|
ch2 := registerFakeProxy(s, "proxy-b", cluster)
|
||||||
|
|
||||||
update := &proto.ProxyMapping{
|
mapping := &proto.ProxyMapping{
|
||||||
Type: proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED,
|
Type: proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED,
|
||||||
Id: "service-1",
|
Id: "service-1",
|
||||||
AccountId: "account-1",
|
AccountId: "account-1",
|
||||||
Domain: "test.example.com",
|
Domain: "test.example.com",
|
||||||
}
|
}
|
||||||
|
|
||||||
s.SendServiceUpdateToCluster(update, cluster)
|
s.SendServiceUpdateToCluster(context.Background(), mapping, cluster)
|
||||||
|
|
||||||
msg1 := drainChannel(ch1)
|
resp1 := drainChannel(ch1)
|
||||||
msg2 := drainChannel(ch2)
|
resp2 := drainChannel(ch2)
|
||||||
require.NotNil(t, msg1)
|
require.NotNil(t, resp1)
|
||||||
require.NotNil(t, msg2)
|
require.NotNil(t, resp2)
|
||||||
|
require.Len(t, resp1.Mapping, 1)
|
||||||
|
require.Len(t, resp2.Mapping, 1)
|
||||||
|
|
||||||
// Delete operations should not generate tokens
|
// Delete operations should not generate tokens
|
||||||
assert.Empty(t, msg1.AuthToken)
|
assert.Empty(t, resp1.Mapping[0].AuthToken)
|
||||||
assert.Empty(t, msg2.AuthToken)
|
assert.Empty(t, resp2.Mapping[0].AuthToken)
|
||||||
|
|
||||||
// No tokens should have been created
|
|
||||||
assert.Equal(t, 0, tokenStore.GetTokenCount())
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSendServiceUpdate_UniqueTokensPerProxy(t *testing.T) {
|
func TestSendServiceUpdate_UniqueTokensPerProxy(t *testing.T) {
|
||||||
tokenStore := NewOneTimeTokenStore(time.Hour)
|
ctx := context.Background()
|
||||||
defer tokenStore.Close()
|
tokenStore, err := NewOneTimeTokenStore(ctx, time.Hour, 10*time.Minute, 100)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
pkceStore, err := NewPKCEVerifierStore(ctx, 10*time.Minute, 10*time.Minute, 100)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
s := &ProxyServiceServer{
|
s := &ProxyServiceServer{
|
||||||
tokenStore: tokenStore,
|
tokenStore: tokenStore,
|
||||||
updatesChan: make(chan *proto.ProxyMapping, 100),
|
pkceVerifierStore: pkceStore,
|
||||||
}
|
}
|
||||||
|
s.SetProxyController(newTestProxyController())
|
||||||
|
|
||||||
// Register proxies in different clusters (SendServiceUpdate broadcasts to all)
|
// Register proxies in different clusters (SendServiceUpdate broadcasts to all)
|
||||||
ch1 := registerFakeProxy(s, "proxy-a", "cluster-a")
|
ch1 := registerFakeProxy(s, "proxy-a", "cluster-a")
|
||||||
ch2 := registerFakeProxy(s, "proxy-b", "cluster-b")
|
ch2 := registerFakeProxy(s, "proxy-b", "cluster-b")
|
||||||
|
|
||||||
update := &proto.ProxyMapping{
|
mapping := &proto.ProxyMapping{
|
||||||
Type: proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED,
|
Type: proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED,
|
||||||
Id: "service-1",
|
Id: "service-1",
|
||||||
AccountId: "account-1",
|
AccountId: "account-1",
|
||||||
Domain: "test.example.com",
|
Domain: "test.example.com",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
update := &proto.GetMappingUpdateResponse{
|
||||||
|
Mapping: []*proto.ProxyMapping{mapping},
|
||||||
|
}
|
||||||
|
|
||||||
s.SendServiceUpdate(update)
|
s.SendServiceUpdate(update)
|
||||||
|
|
||||||
msg1 := drainChannel(ch1)
|
resp1 := drainChannel(ch1)
|
||||||
msg2 := drainChannel(ch2)
|
resp2 := drainChannel(ch2)
|
||||||
require.NotNil(t, msg1)
|
require.NotNil(t, resp1)
|
||||||
require.NotNil(t, msg2)
|
require.NotNil(t, resp2)
|
||||||
|
require.Len(t, resp1.Mapping, 1)
|
||||||
|
require.Len(t, resp2.Mapping, 1)
|
||||||
|
|
||||||
|
msg1 := resp1.Mapping[0]
|
||||||
|
msg2 := resp2.Mapping[0]
|
||||||
|
|
||||||
assert.NotEmpty(t, msg1.AuthToken)
|
assert.NotEmpty(t, msg1.AuthToken)
|
||||||
assert.NotEmpty(t, msg2.AuthToken)
|
assert.NotEmpty(t, msg2.AuthToken)
|
||||||
@@ -178,10 +255,15 @@ func generateState(s *ProxyServiceServer, redirectURL string) string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestOAuthState_NeverTheSame(t *testing.T) {
|
func TestOAuthState_NeverTheSame(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
pkceStore, err := NewPKCEVerifierStore(ctx, 10*time.Minute, 10*time.Minute, 100)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
s := &ProxyServiceServer{
|
s := &ProxyServiceServer{
|
||||||
oidcConfig: ProxyOIDCConfig{
|
oidcConfig: ProxyOIDCConfig{
|
||||||
HMACKey: []byte("test-hmac-key"),
|
HMACKey: []byte("test-hmac-key"),
|
||||||
},
|
},
|
||||||
|
pkceVerifierStore: pkceStore,
|
||||||
}
|
}
|
||||||
|
|
||||||
redirectURL := "https://app.example.com/callback"
|
redirectURL := "https://app.example.com/callback"
|
||||||
@@ -202,31 +284,43 @@ func TestOAuthState_NeverTheSame(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestValidateState_RejectsOldTwoPartFormat(t *testing.T) {
|
func TestValidateState_RejectsOldTwoPartFormat(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
pkceStore, err := NewPKCEVerifierStore(ctx, 10*time.Minute, 10*time.Minute, 100)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
s := &ProxyServiceServer{
|
s := &ProxyServiceServer{
|
||||||
oidcConfig: ProxyOIDCConfig{
|
oidcConfig: ProxyOIDCConfig{
|
||||||
HMACKey: []byte("test-hmac-key"),
|
HMACKey: []byte("test-hmac-key"),
|
||||||
},
|
},
|
||||||
|
pkceVerifierStore: pkceStore,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Old format had only 2 parts: base64(url)|hmac
|
// Old format had only 2 parts: base64(url)|hmac
|
||||||
s.pkceVerifiers.Store("base64url|hmac", pkceEntry{verifier: "test", createdAt: time.Now()})
|
err = s.pkceVerifierStore.Store("base64url|hmac", "test", 10*time.Minute)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
_, _, err := s.ValidateState("base64url|hmac")
|
_, _, err = s.ValidateState("base64url|hmac")
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
assert.Contains(t, err.Error(), "invalid state format")
|
assert.Contains(t, err.Error(), "invalid state format")
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestValidateState_RejectsInvalidHMAC(t *testing.T) {
|
func TestValidateState_RejectsInvalidHMAC(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
pkceStore, err := NewPKCEVerifierStore(ctx, 10*time.Minute, 10*time.Minute, 100)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
s := &ProxyServiceServer{
|
s := &ProxyServiceServer{
|
||||||
oidcConfig: ProxyOIDCConfig{
|
oidcConfig: ProxyOIDCConfig{
|
||||||
HMACKey: []byte("test-hmac-key"),
|
HMACKey: []byte("test-hmac-key"),
|
||||||
},
|
},
|
||||||
|
pkceVerifierStore: pkceStore,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Store with tampered HMAC
|
// Store with tampered HMAC
|
||||||
s.pkceVerifiers.Store("dGVzdA==|nonce|wrong-hmac", pkceEntry{verifier: "test", createdAt: time.Now()})
|
err = s.pkceVerifierStore.Store("dGVzdA==|nonce|wrong-hmac", "test", 10*time.Minute)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
_, _, err := s.ValidateState("dGVzdA==|nonce|wrong-hmac")
|
_, _, err = s.ValidateState("dGVzdA==|nonce|wrong-hmac")
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
assert.Contains(t, err.Error(), "invalid state signature")
|
assert.Contains(t, err.Error(), "invalid state signature")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -26,6 +26,7 @@ import (
|
|||||||
"github.com/netbirdio/netbird/shared/management/client/common"
|
"github.com/netbirdio/netbird/shared/management/client/common"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
|
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
|
||||||
|
rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||||
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
|
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
|
||||||
"github.com/netbirdio/netbird/management/server/idp"
|
"github.com/netbirdio/netbird/management/server/idp"
|
||||||
"github.com/netbirdio/netbird/management/server/job"
|
"github.com/netbirdio/netbird/management/server/job"
|
||||||
@@ -80,6 +81,9 @@ type Server struct {
|
|||||||
syncSem atomic.Int32
|
syncSem atomic.Int32
|
||||||
syncLimEnabled bool
|
syncLimEnabled bool
|
||||||
syncLim int32
|
syncLim int32
|
||||||
|
|
||||||
|
reverseProxyManager rpservice.Manager
|
||||||
|
reverseProxyMu sync.RWMutex
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewServer creates a new Management server
|
// NewServer creates a new Management server
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ import (
|
|||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
|
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/sessionkey"
|
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/sessionkey"
|
||||||
"github.com/netbirdio/netbird/management/server/store"
|
"github.com/netbirdio/netbird/management/server/store"
|
||||||
"github.com/netbirdio/netbird/management/server/types"
|
"github.com/netbirdio/netbird/management/server/types"
|
||||||
@@ -34,11 +34,18 @@ func setupValidateSessionTest(t *testing.T) *validateSessionTestSetup {
|
|||||||
testStore, storeCleanup, err := store.NewTestStoreFromSQL(ctx, "../../../server/testdata/auth_callback.sql", t.TempDir())
|
testStore, storeCleanup, err := store.NewTestStoreFromSQL(ctx, "../../../server/testdata/auth_callback.sql", t.TempDir())
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
proxyManager := &testValidateSessionProxyManager{store: testStore}
|
serviceManager := &testValidateSessionServiceManager{store: testStore}
|
||||||
usersManager := &testValidateSessionUsersManager{store: testStore}
|
usersManager := &testValidateSessionUsersManager{store: testStore}
|
||||||
|
proxyManager := &testValidateSessionProxyManager{}
|
||||||
|
|
||||||
proxyService := NewProxyServiceServer(nil, NewOneTimeTokenStore(time.Minute), ProxyOIDCConfig{}, nil, usersManager)
|
tokenStore, err := NewOneTimeTokenStore(ctx, time.Minute, 10*time.Minute, 100)
|
||||||
proxyService.SetProxyManager(proxyManager)
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
pkceStore, err := NewPKCEVerifierStore(ctx, 10*time.Minute, 10*time.Minute, 100)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
proxyService := NewProxyServiceServer(nil, tokenStore, pkceStore, ProxyOIDCConfig{}, nil, usersManager, proxyManager)
|
||||||
|
proxyService.SetServiceManager(serviceManager)
|
||||||
|
|
||||||
createTestProxies(t, ctx, testStore)
|
createTestProxies(t, ctx, testStore)
|
||||||
|
|
||||||
@@ -54,7 +61,7 @@ func createTestProxies(t *testing.T, ctx context.Context, testStore store.Store)
|
|||||||
|
|
||||||
pubKey, privKey := generateSessionKeyPair(t)
|
pubKey, privKey := generateSessionKeyPair(t)
|
||||||
|
|
||||||
testProxy := &reverseproxy.Service{
|
testProxy := &service.Service{
|
||||||
ID: "testProxyId",
|
ID: "testProxyId",
|
||||||
AccountID: "testAccountId",
|
AccountID: "testAccountId",
|
||||||
Name: "Test Proxy",
|
Name: "Test Proxy",
|
||||||
@@ -62,15 +69,15 @@ func createTestProxies(t *testing.T, ctx context.Context, testStore store.Store)
|
|||||||
Enabled: true,
|
Enabled: true,
|
||||||
SessionPrivateKey: privKey,
|
SessionPrivateKey: privKey,
|
||||||
SessionPublicKey: pubKey,
|
SessionPublicKey: pubKey,
|
||||||
Auth: reverseproxy.AuthConfig{
|
Auth: service.AuthConfig{
|
||||||
BearerAuth: &reverseproxy.BearerAuthConfig{
|
BearerAuth: &service.BearerAuthConfig{
|
||||||
Enabled: true,
|
Enabled: true,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
require.NoError(t, testStore.CreateService(ctx, testProxy))
|
require.NoError(t, testStore.CreateService(ctx, testProxy))
|
||||||
|
|
||||||
restrictedProxy := &reverseproxy.Service{
|
restrictedProxy := &service.Service{
|
||||||
ID: "restrictedProxyId",
|
ID: "restrictedProxyId",
|
||||||
AccountID: "testAccountId",
|
AccountID: "testAccountId",
|
||||||
Name: "Restricted Proxy",
|
Name: "Restricted Proxy",
|
||||||
@@ -78,8 +85,8 @@ func createTestProxies(t *testing.T, ctx context.Context, testStore store.Store)
|
|||||||
Enabled: true,
|
Enabled: true,
|
||||||
SessionPrivateKey: privKey,
|
SessionPrivateKey: privKey,
|
||||||
SessionPublicKey: pubKey,
|
SessionPublicKey: pubKey,
|
||||||
Auth: reverseproxy.AuthConfig{
|
Auth: service.AuthConfig{
|
||||||
BearerAuth: &reverseproxy.BearerAuthConfig{
|
BearerAuth: &service.BearerAuthConfig{
|
||||||
Enabled: true,
|
Enabled: true,
|
||||||
DistributionGroups: []string{"allowedGroupId"},
|
DistributionGroups: []string{"allowedGroupId"},
|
||||||
},
|
},
|
||||||
@@ -196,7 +203,7 @@ func TestValidateSession_ProxyNotFound(t *testing.T) {
|
|||||||
|
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.False(t, resp.Valid, "Unknown proxy should be denied")
|
assert.False(t, resp.Valid, "Unknown proxy should be denied")
|
||||||
assert.Equal(t, "proxy_not_found", resp.DeniedReason)
|
assert.Equal(t, "service_not_found", resp.DeniedReason)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestValidateSession_InvalidToken(t *testing.T) {
|
func TestValidateSession_InvalidToken(t *testing.T) {
|
||||||
@@ -239,62 +246,102 @@ func TestValidateSession_MissingToken(t *testing.T) {
|
|||||||
assert.Contains(t, resp.DeniedReason, "missing")
|
assert.Contains(t, resp.DeniedReason, "missing")
|
||||||
}
|
}
|
||||||
|
|
||||||
type testValidateSessionProxyManager struct {
|
type testValidateSessionServiceManager struct {
|
||||||
store store.Store
|
store store.Store
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *testValidateSessionProxyManager) GetAllServices(_ context.Context, _, _ string) ([]*reverseproxy.Service, error) {
|
func (m *testValidateSessionServiceManager) GetAllServices(_ context.Context, _, _ string) ([]*service.Service, error) {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *testValidateSessionProxyManager) GetService(_ context.Context, _, _, _ string) (*reverseproxy.Service, error) {
|
func (m *testValidateSessionServiceManager) GetService(_ context.Context, _, _, _ string) (*service.Service, error) {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *testValidateSessionProxyManager) CreateService(_ context.Context, _, _ string, _ *reverseproxy.Service) (*reverseproxy.Service, error) {
|
func (m *testValidateSessionServiceManager) CreateService(_ context.Context, _, _ string, _ *service.Service) (*service.Service, error) {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *testValidateSessionProxyManager) UpdateService(_ context.Context, _, _ string, _ *reverseproxy.Service) (*reverseproxy.Service, error) {
|
func (m *testValidateSessionServiceManager) UpdateService(_ context.Context, _, _ string, _ *service.Service) (*service.Service, error) {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *testValidateSessionProxyManager) DeleteService(_ context.Context, _, _, _ string) error {
|
func (m *testValidateSessionServiceManager) DeleteService(_ context.Context, _, _, _ string) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *testValidateSessionProxyManager) SetCertificateIssuedAt(_ context.Context, _, _ string) error {
|
func (m *testValidateSessionServiceManager) DeleteAllServices(_ context.Context, _, _ string) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *testValidateSessionProxyManager) SetStatus(_ context.Context, _, _ string, _ reverseproxy.ProxyStatus) error {
|
func (m *testValidateSessionServiceManager) SetCertificateIssuedAt(_ context.Context, _, _ string) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *testValidateSessionProxyManager) ReloadAllServicesForAccount(_ context.Context, _ string) error {
|
func (m *testValidateSessionServiceManager) SetStatus(_ context.Context, _, _ string, _ service.Status) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *testValidateSessionProxyManager) ReloadService(_ context.Context, _, _ string) error {
|
func (m *testValidateSessionServiceManager) ReloadAllServicesForAccount(_ context.Context, _ string) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *testValidateSessionProxyManager) GetGlobalServices(ctx context.Context) ([]*reverseproxy.Service, error) {
|
func (m *testValidateSessionServiceManager) ReloadService(_ context.Context, _, _ string) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *testValidateSessionServiceManager) GetGlobalServices(ctx context.Context) ([]*service.Service, error) {
|
||||||
return m.store.GetServices(ctx, store.LockingStrengthNone)
|
return m.store.GetServices(ctx, store.LockingStrengthNone)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *testValidateSessionProxyManager) GetServiceByID(ctx context.Context, accountID, proxyID string) (*reverseproxy.Service, error) {
|
func (m *testValidateSessionServiceManager) GetServiceByID(ctx context.Context, accountID, proxyID string) (*service.Service, error) {
|
||||||
return m.store.GetServiceByID(ctx, store.LockingStrengthNone, accountID, proxyID)
|
return m.store.GetServiceByID(ctx, store.LockingStrengthNone, accountID, proxyID)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *testValidateSessionProxyManager) GetAccountServices(ctx context.Context, accountID string) ([]*reverseproxy.Service, error) {
|
func (m *testValidateSessionServiceManager) GetAccountServices(ctx context.Context, accountID string) ([]*service.Service, error) {
|
||||||
return m.store.GetAccountServices(ctx, store.LockingStrengthNone, accountID)
|
return m.store.GetAccountServices(ctx, store.LockingStrengthNone, accountID)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *testValidateSessionProxyManager) GetServiceIDByTargetID(_ context.Context, _, _ string) (string, error) {
|
func (m *testValidateSessionServiceManager) GetServiceIDByTargetID(_ context.Context, _, _ string) (string, error) {
|
||||||
return "", nil
|
return "", nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *testValidateSessionServiceManager) CreateServiceFromPeer(_ context.Context, _, _ string, _ *service.ExposeServiceRequest) (*service.ExposeServiceResponse, error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *testValidateSessionServiceManager) RenewServiceFromPeer(_ context.Context, _, _, _ string) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *testValidateSessionServiceManager) StopServiceFromPeer(_ context.Context, _, _, _ string) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *testValidateSessionServiceManager) StartExposeReaper(_ context.Context) {}
|
||||||
|
|
||||||
|
type testValidateSessionProxyManager struct{}
|
||||||
|
|
||||||
|
func (m *testValidateSessionProxyManager) Connect(_ context.Context, _, _, _ string) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *testValidateSessionProxyManager) Disconnect(_ context.Context, _ string) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *testValidateSessionProxyManager) Heartbeat(_ context.Context, _ string) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *testValidateSessionProxyManager) GetActiveClusterAddresses(_ context.Context) ([]string, error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *testValidateSessionProxyManager) CleanupStale(_ context.Context, _ time.Duration) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
type testValidateSessionUsersManager struct {
|
type testValidateSessionUsersManager struct {
|
||||||
store store.Store
|
store store.Store
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -15,7 +15,7 @@ import (
|
|||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
|
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||||
"github.com/netbirdio/netbird/management/server/job"
|
"github.com/netbirdio/netbird/management/server/job"
|
||||||
"github.com/netbirdio/netbird/shared/auth"
|
"github.com/netbirdio/netbird/shared/auth"
|
||||||
|
|
||||||
@@ -85,7 +85,7 @@ type DefaultAccountManager struct {
|
|||||||
|
|
||||||
proxyController port_forwarding.Controller
|
proxyController port_forwarding.Controller
|
||||||
settingsManager settings.Manager
|
settingsManager settings.Manager
|
||||||
reverseProxyManager reverseproxy.Manager
|
serviceManager service.Manager
|
||||||
|
|
||||||
// config contains the management server configuration
|
// config contains the management server configuration
|
||||||
config *nbconfig.Config
|
config *nbconfig.Config
|
||||||
@@ -115,8 +115,8 @@ type DefaultAccountManager struct {
|
|||||||
|
|
||||||
var _ account.Manager = (*DefaultAccountManager)(nil)
|
var _ account.Manager = (*DefaultAccountManager)(nil)
|
||||||
|
|
||||||
func (am *DefaultAccountManager) SetServiceManager(serviceManager reverseproxy.Manager) {
|
func (am *DefaultAccountManager) SetServiceManager(serviceManager service.Manager) {
|
||||||
am.reverseProxyManager = serviceManager
|
am.serviceManager = serviceManager
|
||||||
}
|
}
|
||||||
|
|
||||||
func isUniqueConstraintError(err error) bool {
|
func isUniqueConstraintError(err error) bool {
|
||||||
@@ -376,6 +376,7 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
|
|||||||
am.handlePeerLoginExpirationSettings(ctx, oldSettings, newSettings, userID, accountID)
|
am.handlePeerLoginExpirationSettings(ctx, oldSettings, newSettings, userID, accountID)
|
||||||
am.handleGroupsPropagationSettings(ctx, oldSettings, newSettings, userID, accountID)
|
am.handleGroupsPropagationSettings(ctx, oldSettings, newSettings, userID, accountID)
|
||||||
am.handleAutoUpdateVersionSettings(ctx, oldSettings, newSettings, userID, accountID)
|
am.handleAutoUpdateVersionSettings(ctx, oldSettings, newSettings, userID, accountID)
|
||||||
|
am.handlePeerExposeSettings(ctx, oldSettings, newSettings, userID, accountID)
|
||||||
if err = am.handleInactivityExpirationSettings(ctx, oldSettings, newSettings, userID, accountID); err != nil {
|
if err = am.handleInactivityExpirationSettings(ctx, oldSettings, newSettings, userID, accountID); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -394,7 +395,7 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
|
|||||||
am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountNetworkRangeUpdated, eventMeta)
|
am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountNetworkRangeUpdated, eventMeta)
|
||||||
}
|
}
|
||||||
if reloadReverseProxy {
|
if reloadReverseProxy {
|
||||||
if err = am.reverseProxyManager.ReloadAllServicesForAccount(ctx, accountID); err != nil {
|
if err = am.serviceManager.ReloadAllServicesForAccount(ctx, accountID); err != nil {
|
||||||
log.WithContext(ctx).Warnf("failed to reload all services for account %s: %v", accountID, err)
|
log.WithContext(ctx).Warnf("failed to reload all services for account %s: %v", accountID, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -492,6 +493,21 @@ func (am *DefaultAccountManager) handleAutoUpdateVersionSettings(ctx context.Con
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (am *DefaultAccountManager) handlePeerExposeSettings(ctx context.Context, oldSettings, newSettings *types.Settings, userID, accountID string) {
|
||||||
|
oldEnabled := oldSettings.PeerExposeEnabled
|
||||||
|
newEnabled := newSettings.PeerExposeEnabled
|
||||||
|
|
||||||
|
if oldEnabled == newEnabled {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
event := activity.AccountPeerExposeEnabled
|
||||||
|
if !newEnabled {
|
||||||
|
event = activity.AccountPeerExposeDisabled
|
||||||
|
}
|
||||||
|
am.StoreEvent(ctx, userID, accountID, accountID, event, nil)
|
||||||
|
}
|
||||||
|
|
||||||
func (am *DefaultAccountManager) handleInactivityExpirationSettings(ctx context.Context, oldSettings, newSettings *types.Settings, userID, accountID string) error {
|
func (am *DefaultAccountManager) handleInactivityExpirationSettings(ctx context.Context, oldSettings, newSettings *types.Settings, userID, accountID string) error {
|
||||||
if newSettings.PeerInactivityExpirationEnabled {
|
if newSettings.PeerInactivityExpirationEnabled {
|
||||||
if oldSettings.PeerInactivityExpiration != newSettings.PeerInactivityExpiration {
|
if oldSettings.PeerInactivityExpiration != newSettings.PeerInactivityExpiration {
|
||||||
@@ -714,6 +730,11 @@ func (am *DefaultAccountManager) DeleteAccount(ctx context.Context, accountID, u
|
|||||||
return status.Errorf(status.Internal, "failed to build user infos for account %s: %v", accountID, err)
|
return status.Errorf(status.Internal, "failed to build user infos for account %s: %v", accountID, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
err = am.serviceManager.DeleteAllServices(ctx, accountID, userID)
|
||||||
|
if err != nil {
|
||||||
|
return status.Errorf(status.Internal, "failed to delete service %s: %v", accountID, err)
|
||||||
|
}
|
||||||
|
|
||||||
for _, otherUser := range account.Users {
|
for _, otherUser := range account.Users {
|
||||||
if otherUser.Id == userID {
|
if otherUser.Id == userID {
|
||||||
continue
|
continue
|
||||||
@@ -1358,9 +1379,10 @@ func (am *DefaultAccountManager) GetAccountIDFromUserAuth(ctx context.Context, u
|
|||||||
if am.singleAccountMode && am.singleAccountModeDomain != "" {
|
if am.singleAccountMode && am.singleAccountModeDomain != "" {
|
||||||
// This section is mostly related to self-hosted installations.
|
// This section is mostly related to self-hosted installations.
|
||||||
// We override incoming domain claims to group users under a single account.
|
// We override incoming domain claims to group users under a single account.
|
||||||
userAuth.Domain = am.singleAccountModeDomain
|
err := am.updateUserAuthWithSingleMode(ctx, &userAuth)
|
||||||
userAuth.DomainCategory = types.PrivateCategory
|
if err != nil {
|
||||||
log.WithContext(ctx).Debugf("overriding JWT Domain and DomainCategory claims since single account mode is enabled")
|
return "", "", err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
accountID, err := am.getAccountIDWithAuthorizationClaims(ctx, userAuth)
|
accountID, err := am.getAccountIDWithAuthorizationClaims(ctx, userAuth)
|
||||||
@@ -1393,6 +1415,35 @@ func (am *DefaultAccountManager) GetAccountIDFromUserAuth(ctx context.Context, u
|
|||||||
return accountID, user.Id, nil
|
return accountID, user.Id, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// updateUserAuthWithSingleMode modifies the userAuth with the single account domain, or if there is an existing account, with the domain of that account
|
||||||
|
func (am *DefaultAccountManager) updateUserAuthWithSingleMode(ctx context.Context, userAuth *auth.UserAuth) error {
|
||||||
|
userAuth.DomainCategory = types.PrivateCategory
|
||||||
|
userAuth.Domain = am.singleAccountModeDomain
|
||||||
|
|
||||||
|
accountID, err := am.Store.GetAnyAccountID(ctx)
|
||||||
|
if err != nil {
|
||||||
|
if e, ok := status.FromError(err); !ok || e.Type() != status.NotFound {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
log.WithContext(ctx).Debugf("using singleAccountModeDomain to override JWT Domain and DomainCategory claims in single account mode")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if accountID == "" {
|
||||||
|
log.WithContext(ctx).Debugf("using singleAccountModeDomain to override JWT Domain and DomainCategory claims in single account mode")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
domain, _, err := am.Store.GetAccountDomainAndCategory(ctx, store.LockingStrengthNone, accountID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
userAuth.Domain = domain
|
||||||
|
|
||||||
|
log.WithContext(ctx).Debugf("overriding JWT Domain and DomainCategory claims since single account mode is enabled")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// syncJWTGroups processes the JWT groups for a user, updates the account based on the groups,
|
// syncJWTGroups processes the JWT groups for a user, updates the account based on the groups,
|
||||||
// and propagates changes to peers if group propagation is enabled.
|
// and propagates changes to peers if group propagation is enabled.
|
||||||
// requires userAuth to have been ValidateAndParseToken and EnsureUserAccessByJWTGroups by the AuthManager
|
// requires userAuth to have been ValidateAndParseToken and EnsureUserAccessByJWTGroups by the AuthManager
|
||||||
|
|||||||
@@ -1,12 +1,14 @@
|
|||||||
package account
|
package account
|
||||||
|
|
||||||
|
//go:generate go run github.com/golang/mock/mockgen -package account -destination=manager_mock.go -source=./manager.go -build_flags=-mod=mod
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
|
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||||
"github.com/netbirdio/netbird/shared/auth"
|
"github.com/netbirdio/netbird/shared/auth"
|
||||||
|
|
||||||
nbdns "github.com/netbirdio/netbird/dns"
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
@@ -61,11 +63,11 @@ type Manager interface {
|
|||||||
GetPeers(ctx context.Context, accountID, userID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error)
|
GetPeers(ctx context.Context, accountID, userID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error)
|
||||||
MarkPeerConnected(ctx context.Context, peerKey string, connected bool, realIP net.IP, accountID string, syncTime time.Time) error
|
MarkPeerConnected(ctx context.Context, peerKey string, connected bool, realIP net.IP, accountID string, syncTime time.Time) error
|
||||||
DeletePeer(ctx context.Context, accountID, peerID, userID string) error
|
DeletePeer(ctx context.Context, accountID, peerID, userID string) error
|
||||||
UpdatePeer(ctx context.Context, accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error)
|
UpdatePeer(ctx context.Context, accountID, userID string, p *nbpeer.Peer) (*nbpeer.Peer, error)
|
||||||
UpdatePeerIP(ctx context.Context, accountID, userID, peerID string, newIP netip.Addr) error
|
UpdatePeerIP(ctx context.Context, accountID, userID, peerID string, newIP netip.Addr) error
|
||||||
GetNetworkMap(ctx context.Context, peerID string) (*types.NetworkMap, error)
|
GetNetworkMap(ctx context.Context, peerID string) (*types.NetworkMap, error)
|
||||||
GetPeerNetwork(ctx context.Context, peerID string) (*types.Network, error)
|
GetPeerNetwork(ctx context.Context, peerID string) (*types.Network, error)
|
||||||
AddPeer(ctx context.Context, accountID, setupKey, userID string, peer *nbpeer.Peer, temporary bool) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error)
|
AddPeer(ctx context.Context, accountID, setupKey, userID string, p *nbpeer.Peer, temporary bool) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error)
|
||||||
CreatePAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenName string, expiresIn int) (*types.PersonalAccessTokenGenerated, error)
|
CreatePAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenName string, expiresIn int) (*types.PersonalAccessTokenGenerated, error)
|
||||||
DeletePAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) error
|
DeletePAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) error
|
||||||
GetPAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) (*types.PersonalAccessToken, error)
|
GetPAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) (*types.PersonalAccessToken, error)
|
||||||
@@ -140,5 +142,5 @@ type Manager interface {
|
|||||||
CreatePeerJob(ctx context.Context, accountID, peerID, userID string, job *types.Job) error
|
CreatePeerJob(ctx context.Context, accountID, peerID, userID string, job *types.Job) error
|
||||||
GetAllPeerJobs(ctx context.Context, accountID, userID, peerID string) ([]*types.Job, error)
|
GetAllPeerJobs(ctx context.Context, accountID, userID, peerID string) ([]*types.Job, error)
|
||||||
GetPeerJobByID(ctx context.Context, accountID, userID, peerID, jobID string) (*types.Job, error)
|
GetPeerJobByID(ctx context.Context, accountID, userID, peerID, jobID string) (*types.Job, error)
|
||||||
SetServiceManager(serviceManager reverseproxy.Manager)
|
SetServiceManager(serviceManager service.Manager)
|
||||||
}
|
}
|
||||||
|
|||||||
1738
management/server/account/manager_mock.go
Normal file
1738
management/server/account/manager_mock.go
Normal file
File diff suppressed because it is too large
Load Diff
@@ -15,10 +15,12 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/golang/mock/gomock"
|
"github.com/golang/mock/gomock"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/status"
|
||||||
"github.com/prometheus/client_golang/prometheus/push"
|
"github.com/prometheus/client_golang/prometheus/push"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
"go.opentelemetry.io/otel/metric/noop"
|
||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
|
|
||||||
nbdns "github.com/netbirdio/netbird/dns"
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
@@ -27,10 +29,13 @@ import (
|
|||||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
|
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
|
||||||
"github.com/netbirdio/netbird/management/internals/modules/peers"
|
"github.com/netbirdio/netbird/management/internals/modules/peers"
|
||||||
ephemeral_manager "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager"
|
ephemeral_manager "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager"
|
||||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
|
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
|
||||||
reverseproxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/manager"
|
proxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy/manager"
|
||||||
|
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||||
|
reverseproxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service/manager"
|
||||||
"github.com/netbirdio/netbird/management/internals/modules/zones"
|
"github.com/netbirdio/netbird/management/internals/modules/zones"
|
||||||
"github.com/netbirdio/netbird/management/internals/server/config"
|
"github.com/netbirdio/netbird/management/internals/server/config"
|
||||||
|
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
|
||||||
nbAccount "github.com/netbirdio/netbird/management/server/account"
|
nbAccount "github.com/netbirdio/netbird/management/server/account"
|
||||||
"github.com/netbirdio/netbird/management/server/activity"
|
"github.com/netbirdio/netbird/management/server/activity"
|
||||||
"github.com/netbirdio/netbird/management/server/cache"
|
"github.com/netbirdio/netbird/management/server/cache"
|
||||||
@@ -1802,12 +1807,12 @@ func TestAccount_Copy(t *testing.T) {
|
|||||||
Address: "172.12.6.1/24",
|
Address: "172.12.6.1/24",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
Services: []*reverseproxy.Service{
|
Services: []*service.Service{
|
||||||
{
|
{
|
||||||
ID: "service1",
|
ID: "service1",
|
||||||
Name: "test-service",
|
Name: "test-service",
|
||||||
AccountID: "account1",
|
AccountID: "account1",
|
||||||
Targets: []*reverseproxy.Target{},
|
Targets: []*service.Target{},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
NetworkMapCache: &types.NetworkMapBuilder{},
|
NetworkMapCache: &types.NetworkMapBuilder{},
|
||||||
@@ -3112,6 +3117,12 @@ func createManager(t testing.TB) (*DefaultAccountManager, *update_channel.PeersU
|
|||||||
permissionsManager := permissions.NewManager(store)
|
permissionsManager := permissions.NewManager(store)
|
||||||
peersManager := peers.NewManager(store, permissionsManager)
|
peersManager := peers.NewManager(store, permissionsManager)
|
||||||
|
|
||||||
|
proxyManager := proxy.NewMockManager(ctrl)
|
||||||
|
proxyManager.EXPECT().
|
||||||
|
CleanupStale(gomock.Any(), gomock.Any()).
|
||||||
|
Return(nil).
|
||||||
|
AnyTimes()
|
||||||
|
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
|
||||||
updateManager := update_channel.NewPeersUpdateManager(metrics)
|
updateManager := update_channel.NewPeersUpdateManager(metrics)
|
||||||
@@ -3122,7 +3133,12 @@ func createManager(t testing.TB) (*DefaultAccountManager, *update_channel.PeersU
|
|||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
manager.SetServiceManager(reverseproxymanager.NewManager(store, manager, permissionsManager, nil, nil))
|
proxyGrpcServer := nbgrpc.NewProxyServiceServer(nil, nil, nil, nbgrpc.ProxyOIDCConfig{}, peersManager, nil, proxyManager)
|
||||||
|
proxyController, err := proxymanager.NewGRPCController(proxyGrpcServer, noop.Meter{})
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
manager.SetServiceManager(reverseproxymanager.NewManager(store, manager, permissionsManager, proxyController, nil))
|
||||||
|
|
||||||
return manager, updateManager, nil
|
return manager, updateManager, nil
|
||||||
}
|
}
|
||||||
@@ -3951,3 +3967,116 @@ func TestDefaultAccountManager_UpdateAccountSettings_NetworkRangeChange(t *testi
|
|||||||
t.Fatal("UpdateAccountSettings deadlocked when changing NetworkRange")
|
t.Fatal("UpdateAccountSettings deadlocked when changing NetworkRange")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestUpdateUserAuthWithSingleMode(t *testing.T) {
|
||||||
|
t.Run("sets defaults and overrides domain from store", func(t *testing.T) {
|
||||||
|
ctrl := gomock.NewController(t)
|
||||||
|
t.Cleanup(ctrl.Finish)
|
||||||
|
|
||||||
|
mockStore := store.NewMockStore(ctrl)
|
||||||
|
mockStore.EXPECT().
|
||||||
|
GetAnyAccountID(gomock.Any()).
|
||||||
|
Return("account-1", nil)
|
||||||
|
mockStore.EXPECT().
|
||||||
|
GetAccountDomainAndCategory(gomock.Any(), store.LockingStrengthNone, "account-1").
|
||||||
|
Return("real-domain.com", "private", nil)
|
||||||
|
|
||||||
|
am := &DefaultAccountManager{
|
||||||
|
Store: mockStore,
|
||||||
|
singleAccountModeDomain: "fallback.com",
|
||||||
|
}
|
||||||
|
|
||||||
|
userAuth := &auth.UserAuth{}
|
||||||
|
err := am.updateUserAuthWithSingleMode(context.Background(), userAuth)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, "real-domain.com", userAuth.Domain)
|
||||||
|
assert.Equal(t, types.PrivateCategory, userAuth.DomainCategory)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("falls back to singleAccountModeDomain when account ID is empty", func(t *testing.T) {
|
||||||
|
ctrl := gomock.NewController(t)
|
||||||
|
t.Cleanup(ctrl.Finish)
|
||||||
|
|
||||||
|
mockStore := store.NewMockStore(ctrl)
|
||||||
|
mockStore.EXPECT().
|
||||||
|
GetAnyAccountID(gomock.Any()).
|
||||||
|
Return("", nil)
|
||||||
|
|
||||||
|
am := &DefaultAccountManager{
|
||||||
|
Store: mockStore,
|
||||||
|
singleAccountModeDomain: "fallback.com",
|
||||||
|
}
|
||||||
|
|
||||||
|
userAuth := &auth.UserAuth{}
|
||||||
|
err := am.updateUserAuthWithSingleMode(context.Background(), userAuth)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, "fallback.com", userAuth.Domain)
|
||||||
|
assert.Equal(t, types.PrivateCategory, userAuth.DomainCategory)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("falls back to singleAccountModeDomain on NotFound error", func(t *testing.T) {
|
||||||
|
ctrl := gomock.NewController(t)
|
||||||
|
t.Cleanup(ctrl.Finish)
|
||||||
|
|
||||||
|
mockStore := store.NewMockStore(ctrl)
|
||||||
|
mockStore.EXPECT().
|
||||||
|
GetAnyAccountID(gomock.Any()).
|
||||||
|
Return("", status.Errorf(status.NotFound, "no accounts"))
|
||||||
|
|
||||||
|
am := &DefaultAccountManager{
|
||||||
|
Store: mockStore,
|
||||||
|
singleAccountModeDomain: "fallback.com",
|
||||||
|
}
|
||||||
|
|
||||||
|
userAuth := &auth.UserAuth{}
|
||||||
|
err := am.updateUserAuthWithSingleMode(context.Background(), userAuth)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, "fallback.com", userAuth.Domain)
|
||||||
|
assert.Equal(t, types.PrivateCategory, userAuth.DomainCategory)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("propagates non-NotFound error from GetAnyAccountID", func(t *testing.T) {
|
||||||
|
ctrl := gomock.NewController(t)
|
||||||
|
t.Cleanup(ctrl.Finish)
|
||||||
|
|
||||||
|
mockStore := store.NewMockStore(ctrl)
|
||||||
|
mockStore.EXPECT().
|
||||||
|
GetAnyAccountID(gomock.Any()).
|
||||||
|
Return("", status.Errorf(status.Internal, "db down"))
|
||||||
|
|
||||||
|
am := &DefaultAccountManager{
|
||||||
|
Store: mockStore,
|
||||||
|
singleAccountModeDomain: "fallback.com",
|
||||||
|
}
|
||||||
|
|
||||||
|
userAuth := &auth.UserAuth{}
|
||||||
|
err := am.updateUserAuthWithSingleMode(context.Background(), userAuth)
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "db down")
|
||||||
|
// Defaults should still be set before error path
|
||||||
|
assert.Equal(t, types.PrivateCategory, userAuth.DomainCategory)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("propagates error from GetAccountDomainAndCategory", func(t *testing.T) {
|
||||||
|
ctrl := gomock.NewController(t)
|
||||||
|
t.Cleanup(ctrl.Finish)
|
||||||
|
|
||||||
|
mockStore := store.NewMockStore(ctrl)
|
||||||
|
mockStore.EXPECT().
|
||||||
|
GetAnyAccountID(gomock.Any()).
|
||||||
|
Return("account-1", nil)
|
||||||
|
mockStore.EXPECT().
|
||||||
|
GetAccountDomainAndCategory(gomock.Any(), store.LockingStrengthNone, "account-1").
|
||||||
|
Return("", "", status.Errorf(status.Internal, "query failed"))
|
||||||
|
|
||||||
|
am := &DefaultAccountManager{
|
||||||
|
Store: mockStore,
|
||||||
|
singleAccountModeDomain: "fallback.com",
|
||||||
|
}
|
||||||
|
|
||||||
|
userAuth := &auth.UserAuth{}
|
||||||
|
err := am.updateUserAuthWithSingleMode(context.Background(), userAuth)
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "query failed")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|||||||
@@ -208,6 +208,25 @@ const (
|
|||||||
ServiceUpdated Activity = 109
|
ServiceUpdated Activity = 109
|
||||||
ServiceDeleted Activity = 110
|
ServiceDeleted Activity = 110
|
||||||
|
|
||||||
|
// PeerServiceExposed indicates that a peer exposed a service via the reverse proxy
|
||||||
|
PeerServiceExposed Activity = 111
|
||||||
|
// PeerServiceUnexposed indicates that a peer-exposed service was removed
|
||||||
|
PeerServiceUnexposed Activity = 112
|
||||||
|
// PeerServiceExposeExpired indicates that a peer-exposed service was removed due to TTL expiration
|
||||||
|
PeerServiceExposeExpired Activity = 113
|
||||||
|
|
||||||
|
// AccountPeerExposeEnabled indicates that a user enabled peer expose for the account
|
||||||
|
AccountPeerExposeEnabled Activity = 114
|
||||||
|
// AccountPeerExposeDisabled indicates that a user disabled peer expose for the account
|
||||||
|
AccountPeerExposeDisabled Activity = 115
|
||||||
|
|
||||||
|
// DomainAdded indicates that a user added a custom domain
|
||||||
|
DomainAdded Activity = 116
|
||||||
|
// DomainDeleted indicates that a user deleted a custom domain
|
||||||
|
DomainDeleted Activity = 117
|
||||||
|
// DomainValidated indicates that a custom domain was validated
|
||||||
|
DomainValidated Activity = 118
|
||||||
|
|
||||||
AccountDeleted Activity = 99999
|
AccountDeleted Activity = 99999
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -345,6 +364,17 @@ var activityMap = map[Activity]Code{
|
|||||||
ServiceCreated: {"Service created", "service.create"},
|
ServiceCreated: {"Service created", "service.create"},
|
||||||
ServiceUpdated: {"Service updated", "service.update"},
|
ServiceUpdated: {"Service updated", "service.update"},
|
||||||
ServiceDeleted: {"Service deleted", "service.delete"},
|
ServiceDeleted: {"Service deleted", "service.delete"},
|
||||||
|
|
||||||
|
PeerServiceExposed: {"Peer exposed service", "service.peer.expose"},
|
||||||
|
PeerServiceUnexposed: {"Peer unexposed service", "service.peer.unexpose"},
|
||||||
|
PeerServiceExposeExpired: {"Peer exposed service expired", "service.peer.expose.expire"},
|
||||||
|
|
||||||
|
AccountPeerExposeEnabled: {"Account peer expose enabled", "account.setting.peer.expose.enable"},
|
||||||
|
AccountPeerExposeDisabled: {"Account peer expose disabled", "account.setting.peer.expose.disable"},
|
||||||
|
|
||||||
|
DomainAdded: {"Domain added", "domain.add"},
|
||||||
|
DomainDeleted: {"Domain deleted", "domain.delete"},
|
||||||
|
DomainValidated: {"Domain validated", "domain.validate"},
|
||||||
}
|
}
|
||||||
|
|
||||||
// StringCode returns a string code of the activity
|
// StringCode returns a string code of the activity
|
||||||
|
|||||||
@@ -249,7 +249,15 @@ func initDatabase(ctx context.Context, dataDir string) (*gorm.DB, error) {
|
|||||||
|
|
||||||
switch storeEngine {
|
switch storeEngine {
|
||||||
case types.SqliteStoreEngine:
|
case types.SqliteStoreEngine:
|
||||||
dialector = sqlite.Open(filepath.Join(dataDir, eventSinkDB))
|
dbFile := eventSinkDB
|
||||||
|
if envFile, ok := os.LookupEnv("NB_ACTIVITY_EVENT_SQLITE_FILE"); ok && envFile != "" {
|
||||||
|
dbFile = envFile
|
||||||
|
}
|
||||||
|
connStr := dbFile
|
||||||
|
if !filepath.IsAbs(dbFile) {
|
||||||
|
connStr = filepath.Join(dataDir, dbFile)
|
||||||
|
}
|
||||||
|
dialector = sqlite.Open(connStr)
|
||||||
case types.PostgresStoreEngine:
|
case types.PostgresStoreEngine:
|
||||||
dsn, ok := os.LookupEnv(postgresDsnEnv)
|
dsn, ok := os.LookupEnv(postgresDsnEnv)
|
||||||
if !ok {
|
if !ok {
|
||||||
|
|||||||
@@ -425,6 +425,11 @@ func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountID, us
|
|||||||
var groupIDsToDelete []string
|
var groupIDsToDelete []string
|
||||||
var deletedGroups []*types.Group
|
var deletedGroups []*types.Group
|
||||||
|
|
||||||
|
extraSettings, err := am.settingsManager.GetExtraSettings(ctx, accountID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||||
for _, groupID := range groupIDs {
|
for _, groupID := range groupIDs {
|
||||||
group, err := transaction.GetGroupByID(ctx, store.LockingStrengthNone, accountID, groupID)
|
group, err := transaction.GetGroupByID(ctx, store.LockingStrengthNone, accountID, groupID)
|
||||||
@@ -433,7 +438,7 @@ func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountID, us
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := validateDeleteGroup(ctx, transaction, group, userID); err != nil {
|
if err = validateDeleteGroup(ctx, transaction, group, userID, extraSettings.FlowGroups); err != nil {
|
||||||
allErrors = errors.Join(allErrors, err)
|
allErrors = errors.Join(allErrors, err)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@@ -621,7 +626,7 @@ func validateNewGroup(ctx context.Context, transaction store.Store, accountID st
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func validateDeleteGroup(ctx context.Context, transaction store.Store, group *types.Group, userID string) error {
|
func validateDeleteGroup(ctx context.Context, transaction store.Store, group *types.Group, userID string, flowGroups []string) error {
|
||||||
// disable a deleting integration group if the initiator is not an admin service user
|
// disable a deleting integration group if the initiator is not an admin service user
|
||||||
if group.Issued == types.GroupIssuedIntegration {
|
if group.Issued == types.GroupIssuedIntegration {
|
||||||
executingUser, err := transaction.GetUserByUserID(ctx, store.LockingStrengthNone, userID)
|
executingUser, err := transaction.GetUserByUserID(ctx, store.LockingStrengthNone, userID)
|
||||||
@@ -641,6 +646,10 @@ func validateDeleteGroup(ctx context.Context, transaction store.Store, group *ty
|
|||||||
return &GroupLinkError{"network resource", group.Resources[0].ID}
|
return &GroupLinkError{"network resource", group.Resources[0].ID}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if slices.Contains(flowGroups, group.ID) {
|
||||||
|
return &GroupLinkError{"settings", "traffic event logging"}
|
||||||
|
}
|
||||||
|
|
||||||
if isLinked, linkedRoute := isGroupLinkedToRoute(ctx, transaction, group.AccountID, group.ID); isLinked {
|
if isLinked, linkedRoute := isGroupLinkedToRoute(ctx, transaction, group.AccountID, group.ID); isLinked {
|
||||||
return &GroupLinkError{"route", string(linkedRoute.NetID)}
|
return &GroupLinkError{"route", string(linkedRoute.NetID)}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/golang/mock/gomock"
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
@@ -26,6 +27,7 @@ import (
|
|||||||
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
|
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
|
||||||
peer2 "github.com/netbirdio/netbird/management/server/peer"
|
peer2 "github.com/netbirdio/netbird/management/server/peer"
|
||||||
"github.com/netbirdio/netbird/management/server/permissions"
|
"github.com/netbirdio/netbird/management/server/permissions"
|
||||||
|
"github.com/netbirdio/netbird/management/server/settings"
|
||||||
"github.com/netbirdio/netbird/management/server/store"
|
"github.com/netbirdio/netbird/management/server/store"
|
||||||
"github.com/netbirdio/netbird/management/server/types"
|
"github.com/netbirdio/netbird/management/server/types"
|
||||||
"github.com/netbirdio/netbird/route"
|
"github.com/netbirdio/netbird/route"
|
||||||
@@ -284,6 +286,67 @@ func TestDefaultAccountManager_DeleteGroups(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestDefaultAccountManager_DeleteGroupLinkedToFlowGroup(t *testing.T) {
|
||||||
|
am, _, err := createManager(t)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
ctrl := gomock.NewController(t)
|
||||||
|
settingsMock := settings.NewMockManager(ctrl)
|
||||||
|
settingsMock.EXPECT().
|
||||||
|
GetExtraSettings(gomock.Any(), gomock.Any()).
|
||||||
|
Return(&types.ExtraSettings{FlowGroups: []string{"grp-for-flow"}}, nil).
|
||||||
|
AnyTimes()
|
||||||
|
settingsMock.EXPECT().
|
||||||
|
UpdateExtraSettings(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).
|
||||||
|
Return(false, nil).
|
||||||
|
AnyTimes()
|
||||||
|
am.settingsManager = settingsMock
|
||||||
|
|
||||||
|
_, account, err := initTestGroupAccount(am)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
grp := &types.Group{
|
||||||
|
ID: "grp-for-flow",
|
||||||
|
AccountID: account.Id,
|
||||||
|
Name: "Group for flow",
|
||||||
|
Issued: types.GroupIssuedAPI,
|
||||||
|
Peers: make([]string, 0),
|
||||||
|
}
|
||||||
|
require.NoError(t, am.CreateGroup(context.Background(), account.Id, groupAdminUserID, grp))
|
||||||
|
|
||||||
|
err = am.DeleteGroup(context.Background(), account.Id, groupAdminUserID, "grp-for-flow")
|
||||||
|
require.Error(t, err)
|
||||||
|
|
||||||
|
var gErr *GroupLinkError
|
||||||
|
require.ErrorAs(t, err, &gErr)
|
||||||
|
assert.Equal(t, "settings", gErr.Resource)
|
||||||
|
assert.Equal(t, "traffic event logging", gErr.Name)
|
||||||
|
|
||||||
|
group, err := am.GetGroup(context.Background(), account.Id, "grp-for-flow", groupAdminUserID)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.NotNil(t, group)
|
||||||
|
|
||||||
|
regularGrp := &types.Group{
|
||||||
|
ID: "grp-regular",
|
||||||
|
AccountID: account.Id,
|
||||||
|
Name: "Regular group",
|
||||||
|
Issued: types.GroupIssuedAPI,
|
||||||
|
Peers: make([]string, 0),
|
||||||
|
}
|
||||||
|
err = am.CreateGroup(context.Background(), account.Id, groupAdminUserID, regularGrp)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
err = am.DeleteGroups(context.Background(), account.Id, groupAdminUserID, []string{"grp-for-flow", "grp-regular"})
|
||||||
|
require.Error(t, err)
|
||||||
|
|
||||||
|
group, err = am.GetGroup(context.Background(), account.Id, "grp-for-flow", groupAdminUserID)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.NotNil(t, group)
|
||||||
|
|
||||||
|
_, err = am.GetGroup(context.Background(), account.Id, "grp-regular", groupAdminUserID)
|
||||||
|
assert.Error(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
func initTestGroupAccount(am *DefaultAccountManager) (*DefaultAccountManager, *types.Account, error) {
|
func initTestGroupAccount(am *DefaultAccountManager) (*DefaultAccountManager, *types.Account, error) {
|
||||||
accountID := "testingAcc"
|
accountID := "testingAcc"
|
||||||
domain := "example.com"
|
domain := "example.com"
|
||||||
@@ -703,7 +766,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
|
|||||||
t.Run("saving group linked to network router", func(t *testing.T) {
|
t.Run("saving group linked to network router", func(t *testing.T) {
|
||||||
permissionsManager := permissions.NewManager(manager.Store)
|
permissionsManager := permissions.NewManager(manager.Store)
|
||||||
groupsManager := groups.NewManager(manager.Store, permissionsManager, manager)
|
groupsManager := groups.NewManager(manager.Store, permissionsManager, manager)
|
||||||
resourcesManager := resources.NewManager(manager.Store, permissionsManager, groupsManager, manager, manager.reverseProxyManager)
|
resourcesManager := resources.NewManager(manager.Store, permissionsManager, groupsManager, manager, manager.serviceManager)
|
||||||
routersManager := routers.NewManager(manager.Store, permissionsManager, manager)
|
routersManager := routers.NewManager(manager.Store, permissionsManager, manager)
|
||||||
networksManager := networks.NewManager(manager.Store, permissionsManager, resourcesManager, routersManager, manager)
|
networksManager := networks.NewManager(manager.Store, permissionsManager, resourcesManager, routersManager, manager)
|
||||||
|
|
||||||
|
|||||||
@@ -17,9 +17,9 @@ import (
|
|||||||
|
|
||||||
"github.com/netbirdio/netbird/management/server/types"
|
"github.com/netbirdio/netbird/management/server/types"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
|
|
||||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs"
|
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs"
|
||||||
reverseproxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/manager"
|
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||||
|
reverseproxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service/manager"
|
||||||
|
|
||||||
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
|
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
|
||||||
idpmanager "github.com/netbirdio/netbird/management/server/idp"
|
idpmanager "github.com/netbirdio/netbird/management/server/idp"
|
||||||
@@ -73,7 +73,7 @@ const (
|
|||||||
)
|
)
|
||||||
|
|
||||||
// NewAPIHandler creates the Management service HTTP API handler registering all the available endpoints.
|
// NewAPIHandler creates the Management service HTTP API handler registering all the available endpoints.
|
||||||
func NewAPIHandler(ctx context.Context, accountManager account.Manager, networksManager nbnetworks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager nbgroups.Manager, LocationManager geolocation.Geolocation, authManager auth.Manager, appMetrics telemetry.AppMetrics, integratedValidator integrated_validator.IntegratedValidator, proxyController port_forwarding.Controller, permissionsManager permissions.Manager, peersManager nbpeers.Manager, settingsManager settings.Manager, zManager zones.Manager, rManager records.Manager, networkMapController network_map.Controller, idpManager idpmanager.Manager, reverseProxyManager reverseproxy.Manager, reverseProxyDomainManager *manager.Manager, reverseProxyAccessLogsManager accesslogs.Manager, proxyGRPCServer *nbgrpc.ProxyServiceServer, trustedHTTPProxies []netip.Prefix) (http.Handler, error) {
|
func NewAPIHandler(ctx context.Context, accountManager account.Manager, networksManager nbnetworks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager nbgroups.Manager, LocationManager geolocation.Geolocation, authManager auth.Manager, appMetrics telemetry.AppMetrics, integratedValidator integrated_validator.IntegratedValidator, proxyController port_forwarding.Controller, permissionsManager permissions.Manager, peersManager nbpeers.Manager, settingsManager settings.Manager, zManager zones.Manager, rManager records.Manager, networkMapController network_map.Controller, idpManager idpmanager.Manager, serviceManager service.Manager, reverseProxyDomainManager *manager.Manager, reverseProxyAccessLogsManager accesslogs.Manager, proxyGRPCServer *nbgrpc.ProxyServiceServer, trustedHTTPProxies []netip.Prefix) (http.Handler, error) {
|
||||||
|
|
||||||
// Register bypass paths for unauthenticated endpoints
|
// Register bypass paths for unauthenticated endpoints
|
||||||
if err := bypass.AddBypassPath("/api/instance"); err != nil {
|
if err := bypass.AddBypassPath("/api/instance"); err != nil {
|
||||||
@@ -173,8 +173,8 @@ func NewAPIHandler(ctx context.Context, accountManager account.Manager, networks
|
|||||||
idp.AddEndpoints(accountManager, router)
|
idp.AddEndpoints(accountManager, router)
|
||||||
instance.AddEndpoints(instanceManager, router)
|
instance.AddEndpoints(instanceManager, router)
|
||||||
instance.AddVersionEndpoint(instanceManager, router)
|
instance.AddVersionEndpoint(instanceManager, router)
|
||||||
if reverseProxyManager != nil && reverseProxyDomainManager != nil {
|
if serviceManager != nil && reverseProxyDomainManager != nil {
|
||||||
reverseproxymanager.RegisterEndpoints(reverseProxyManager, *reverseProxyDomainManager, reverseProxyAccessLogsManager, router)
|
reverseproxymanager.RegisterEndpoints(serviceManager, *reverseProxyDomainManager, reverseProxyAccessLogsManager, router)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Register OAuth callback handler for proxy authentication
|
// Register OAuth callback handler for proxy authentication
|
||||||
|
|||||||
@@ -168,6 +168,10 @@ func (h *handler) getAllAccounts(w http.ResponseWriter, r *http.Request) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (h *handler) updateAccountRequestSettings(req api.PutApiAccountsAccountIdJSONRequestBody) (*types.Settings, error) {
|
func (h *handler) updateAccountRequestSettings(req api.PutApiAccountsAccountIdJSONRequestBody) (*types.Settings, error) {
|
||||||
|
if req.Settings.PeerExposeEnabled && len(req.Settings.PeerExposeGroups) == 0 {
|
||||||
|
return nil, status.Errorf(status.InvalidArgument, "peer expose requires at least one group")
|
||||||
|
}
|
||||||
|
|
||||||
returnSettings := &types.Settings{
|
returnSettings := &types.Settings{
|
||||||
PeerLoginExpirationEnabled: req.Settings.PeerLoginExpirationEnabled,
|
PeerLoginExpirationEnabled: req.Settings.PeerLoginExpirationEnabled,
|
||||||
PeerLoginExpiration: time.Duration(float64(time.Second.Nanoseconds()) * float64(req.Settings.PeerLoginExpiration)),
|
PeerLoginExpiration: time.Duration(float64(time.Second.Nanoseconds()) * float64(req.Settings.PeerLoginExpiration)),
|
||||||
@@ -175,6 +179,9 @@ func (h *handler) updateAccountRequestSettings(req api.PutApiAccountsAccountIdJS
|
|||||||
|
|
||||||
PeerInactivityExpirationEnabled: req.Settings.PeerInactivityExpirationEnabled,
|
PeerInactivityExpirationEnabled: req.Settings.PeerInactivityExpirationEnabled,
|
||||||
PeerInactivityExpiration: time.Duration(float64(time.Second.Nanoseconds()) * float64(req.Settings.PeerInactivityExpiration)),
|
PeerInactivityExpiration: time.Duration(float64(time.Second.Nanoseconds()) * float64(req.Settings.PeerInactivityExpiration)),
|
||||||
|
|
||||||
|
PeerExposeEnabled: req.Settings.PeerExposeEnabled,
|
||||||
|
PeerExposeGroups: req.Settings.PeerExposeGroups,
|
||||||
}
|
}
|
||||||
|
|
||||||
if req.Settings.Extra != nil {
|
if req.Settings.Extra != nil {
|
||||||
@@ -336,6 +343,8 @@ func toAccountResponse(accountID string, settings *types.Settings, meta *types.A
|
|||||||
JwtAllowGroups: &jwtAllowGroups,
|
JwtAllowGroups: &jwtAllowGroups,
|
||||||
RegularUsersViewBlocked: settings.RegularUsersViewBlocked,
|
RegularUsersViewBlocked: settings.RegularUsersViewBlocked,
|
||||||
RoutingPeerDnsResolutionEnabled: &settings.RoutingPeerDNSResolutionEnabled,
|
RoutingPeerDnsResolutionEnabled: &settings.RoutingPeerDNSResolutionEnabled,
|
||||||
|
PeerExposeEnabled: settings.PeerExposeEnabled,
|
||||||
|
PeerExposeGroups: settings.PeerExposeGroups,
|
||||||
LazyConnectionEnabled: &settings.LazyConnectionEnabled,
|
LazyConnectionEnabled: &settings.LazyConnectionEnabled,
|
||||||
DnsDomain: &settings.DNSDomain,
|
DnsDomain: &settings.DNSDomain,
|
||||||
AutoUpdateVersion: &settings.AutoUpdateVersion,
|
AutoUpdateVersion: &settings.AutoUpdateVersion,
|
||||||
|
|||||||
@@ -18,8 +18,8 @@ import (
|
|||||||
"github.com/gorilla/mux"
|
"github.com/gorilla/mux"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
|
|
||||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs"
|
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs"
|
||||||
|
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||||
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
|
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
|
||||||
"github.com/netbirdio/netbird/management/server/store"
|
"github.com/netbirdio/netbird/management/server/store"
|
||||||
"github.com/netbirdio/netbird/management/server/types"
|
"github.com/netbirdio/netbird/management/server/types"
|
||||||
@@ -190,7 +190,11 @@ func setupAuthCallbackTest(t *testing.T) *testSetup {
|
|||||||
|
|
||||||
oidcServer := newFakeOIDCServer()
|
oidcServer := newFakeOIDCServer()
|
||||||
|
|
||||||
tokenStore := nbgrpc.NewOneTimeTokenStore(time.Minute)
|
tokenStore, err := nbgrpc.NewOneTimeTokenStore(ctx, time.Minute, 10*time.Minute, 100)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
pkceStore, err := nbgrpc.NewPKCEVerifierStore(ctx, 10*time.Minute, 10*time.Minute, 100)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
usersManager := users.NewManager(testStore)
|
usersManager := users.NewManager(testStore)
|
||||||
|
|
||||||
@@ -205,12 +209,14 @@ func setupAuthCallbackTest(t *testing.T) *testSetup {
|
|||||||
proxyService := nbgrpc.NewProxyServiceServer(
|
proxyService := nbgrpc.NewProxyServiceServer(
|
||||||
&testAccessLogManager{},
|
&testAccessLogManager{},
|
||||||
tokenStore,
|
tokenStore,
|
||||||
|
pkceStore,
|
||||||
oidcConfig,
|
oidcConfig,
|
||||||
nil,
|
nil,
|
||||||
usersManager,
|
usersManager,
|
||||||
|
nil,
|
||||||
)
|
)
|
||||||
|
|
||||||
proxyService.SetProxyManager(&testServiceManager{store: testStore})
|
proxyService.SetServiceManager(&testServiceManager{store: testStore})
|
||||||
|
|
||||||
handler := NewAuthCallbackHandler(proxyService, nil)
|
handler := NewAuthCallbackHandler(proxyService, nil)
|
||||||
|
|
||||||
@@ -239,12 +245,12 @@ func createTestReverseProxies(t *testing.T, ctx context.Context, testStore store
|
|||||||
pubKey := base64.StdEncoding.EncodeToString(pub)
|
pubKey := base64.StdEncoding.EncodeToString(pub)
|
||||||
privKey := base64.StdEncoding.EncodeToString(priv)
|
privKey := base64.StdEncoding.EncodeToString(priv)
|
||||||
|
|
||||||
testProxy := &reverseproxy.Service{
|
testProxy := &service.Service{
|
||||||
ID: "testProxyId",
|
ID: "testProxyId",
|
||||||
AccountID: "testAccountId",
|
AccountID: "testAccountId",
|
||||||
Name: "Test Proxy",
|
Name: "Test Proxy",
|
||||||
Domain: "test-proxy.example.com",
|
Domain: "test-proxy.example.com",
|
||||||
Targets: []*reverseproxy.Target{{
|
Targets: []*service.Target{{
|
||||||
Path: strPtr("/"),
|
Path: strPtr("/"),
|
||||||
Host: "localhost",
|
Host: "localhost",
|
||||||
Port: 8080,
|
Port: 8080,
|
||||||
@@ -254,8 +260,8 @@ func createTestReverseProxies(t *testing.T, ctx context.Context, testStore store
|
|||||||
Enabled: true,
|
Enabled: true,
|
||||||
}},
|
}},
|
||||||
Enabled: true,
|
Enabled: true,
|
||||||
Auth: reverseproxy.AuthConfig{
|
Auth: service.AuthConfig{
|
||||||
BearerAuth: &reverseproxy.BearerAuthConfig{
|
BearerAuth: &service.BearerAuthConfig{
|
||||||
Enabled: true,
|
Enabled: true,
|
||||||
DistributionGroups: []string{"allowedGroupId"},
|
DistributionGroups: []string{"allowedGroupId"},
|
||||||
},
|
},
|
||||||
@@ -265,12 +271,12 @@ func createTestReverseProxies(t *testing.T, ctx context.Context, testStore store
|
|||||||
}
|
}
|
||||||
require.NoError(t, testStore.CreateService(ctx, testProxy))
|
require.NoError(t, testStore.CreateService(ctx, testProxy))
|
||||||
|
|
||||||
restrictedProxy := &reverseproxy.Service{
|
restrictedProxy := &service.Service{
|
||||||
ID: "restrictedProxyId",
|
ID: "restrictedProxyId",
|
||||||
AccountID: "testAccountId",
|
AccountID: "testAccountId",
|
||||||
Name: "Restricted Proxy",
|
Name: "Restricted Proxy",
|
||||||
Domain: "restricted-proxy.example.com",
|
Domain: "restricted-proxy.example.com",
|
||||||
Targets: []*reverseproxy.Target{{
|
Targets: []*service.Target{{
|
||||||
Path: strPtr("/"),
|
Path: strPtr("/"),
|
||||||
Host: "localhost",
|
Host: "localhost",
|
||||||
Port: 8080,
|
Port: 8080,
|
||||||
@@ -280,8 +286,8 @@ func createTestReverseProxies(t *testing.T, ctx context.Context, testStore store
|
|||||||
Enabled: true,
|
Enabled: true,
|
||||||
}},
|
}},
|
||||||
Enabled: true,
|
Enabled: true,
|
||||||
Auth: reverseproxy.AuthConfig{
|
Auth: service.AuthConfig{
|
||||||
BearerAuth: &reverseproxy.BearerAuthConfig{
|
BearerAuth: &service.BearerAuthConfig{
|
||||||
Enabled: true,
|
Enabled: true,
|
||||||
DistributionGroups: []string{"restrictedGroupId"},
|
DistributionGroups: []string{"restrictedGroupId"},
|
||||||
},
|
},
|
||||||
@@ -291,12 +297,12 @@ func createTestReverseProxies(t *testing.T, ctx context.Context, testStore store
|
|||||||
}
|
}
|
||||||
require.NoError(t, testStore.CreateService(ctx, restrictedProxy))
|
require.NoError(t, testStore.CreateService(ctx, restrictedProxy))
|
||||||
|
|
||||||
noAuthProxy := &reverseproxy.Service{
|
noAuthProxy := &service.Service{
|
||||||
ID: "noAuthProxyId",
|
ID: "noAuthProxyId",
|
||||||
AccountID: "testAccountId",
|
AccountID: "testAccountId",
|
||||||
Name: "No Auth Proxy",
|
Name: "No Auth Proxy",
|
||||||
Domain: "no-auth-proxy.example.com",
|
Domain: "no-auth-proxy.example.com",
|
||||||
Targets: []*reverseproxy.Target{{
|
Targets: []*service.Target{{
|
||||||
Path: strPtr("/"),
|
Path: strPtr("/"),
|
||||||
Host: "localhost",
|
Host: "localhost",
|
||||||
Port: 8080,
|
Port: 8080,
|
||||||
@@ -306,8 +312,8 @@ func createTestReverseProxies(t *testing.T, ctx context.Context, testStore store
|
|||||||
Enabled: true,
|
Enabled: true,
|
||||||
}},
|
}},
|
||||||
Enabled: true,
|
Enabled: true,
|
||||||
Auth: reverseproxy.AuthConfig{
|
Auth: service.AuthConfig{
|
||||||
BearerAuth: &reverseproxy.BearerAuthConfig{
|
BearerAuth: &service.BearerAuthConfig{
|
||||||
Enabled: false,
|
Enabled: false,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@@ -357,19 +363,23 @@ type testServiceManager struct {
|
|||||||
store store.Store
|
store store.Store
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *testServiceManager) GetAllServices(_ context.Context, _, _ string) ([]*reverseproxy.Service, error) {
|
func (m *testServiceManager) DeleteAllServices(ctx context.Context, accountID, userID string) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *testServiceManager) GetAllServices(_ context.Context, _, _ string) ([]*service.Service, error) {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *testServiceManager) GetService(_ context.Context, _, _, _ string) (*reverseproxy.Service, error) {
|
func (m *testServiceManager) GetService(_ context.Context, _, _, _ string) (*service.Service, error) {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *testServiceManager) CreateService(_ context.Context, _, _ string, _ *reverseproxy.Service) (*reverseproxy.Service, error) {
|
func (m *testServiceManager) CreateService(_ context.Context, _, _ string, _ *service.Service) (*service.Service, error) {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *testServiceManager) UpdateService(_ context.Context, _, _ string, _ *reverseproxy.Service) (*reverseproxy.Service, error) {
|
func (m *testServiceManager) UpdateService(_ context.Context, _, _ string, _ *service.Service) (*service.Service, error) {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -381,7 +391,7 @@ func (m *testServiceManager) SetCertificateIssuedAt(_ context.Context, _, _ stri
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *testServiceManager) SetStatus(_ context.Context, _, _ string, _ reverseproxy.ProxyStatus) error {
|
func (m *testServiceManager) SetStatus(_ context.Context, _, _ string, _ service.Status) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -393,15 +403,15 @@ func (m *testServiceManager) ReloadService(_ context.Context, _, _ string) error
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *testServiceManager) GetGlobalServices(ctx context.Context) ([]*reverseproxy.Service, error) {
|
func (m *testServiceManager) GetGlobalServices(ctx context.Context) ([]*service.Service, error) {
|
||||||
return m.store.GetServices(ctx, store.LockingStrengthNone)
|
return m.store.GetServices(ctx, store.LockingStrengthNone)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *testServiceManager) GetServiceByID(ctx context.Context, accountID, proxyID string) (*reverseproxy.Service, error) {
|
func (m *testServiceManager) GetServiceByID(ctx context.Context, accountID, proxyID string) (*service.Service, error) {
|
||||||
return m.store.GetServiceByID(ctx, store.LockingStrengthNone, accountID, proxyID)
|
return m.store.GetServiceByID(ctx, store.LockingStrengthNone, accountID, proxyID)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *testServiceManager) GetAccountServices(ctx context.Context, accountID string) ([]*reverseproxy.Service, error) {
|
func (m *testServiceManager) GetAccountServices(ctx context.Context, accountID string) ([]*service.Service, error) {
|
||||||
return m.store.GetAccountServices(ctx, store.LockingStrengthNone, accountID)
|
return m.store.GetAccountServices(ctx, store.LockingStrengthNone, accountID)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -409,6 +419,20 @@ func (m *testServiceManager) GetServiceIDByTargetID(_ context.Context, _, _ stri
|
|||||||
return "", nil
|
return "", nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *testServiceManager) CreateServiceFromPeer(_ context.Context, _, _ string, _ *service.ExposeServiceRequest) (*service.ExposeServiceResponse, error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *testServiceManager) RenewServiceFromPeer(_ context.Context, _, _, _ string) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *testServiceManager) StopServiceFromPeer(_ context.Context, _, _, _ string) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *testServiceManager) StartExposeReaper(_ context.Context) {}
|
||||||
|
|
||||||
func createTestState(t *testing.T, ps *nbgrpc.ProxyServiceServer, redirectURL string) string {
|
func createTestState(t *testing.T, ps *nbgrpc.ProxyServiceServer, redirectURL string) string {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
|
|
||||||
|
|||||||
@@ -9,10 +9,13 @@ import (
|
|||||||
"github.com/golang-jwt/jwt/v5"
|
"github.com/golang-jwt/jwt/v5"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
|
||||||
|
"go.opentelemetry.io/otel/metric/noop"
|
||||||
|
|
||||||
"github.com/netbirdio/management-integrations/integrations"
|
"github.com/netbirdio/management-integrations/integrations"
|
||||||
accesslogsmanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs/manager"
|
accesslogsmanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs/manager"
|
||||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain/manager"
|
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain/manager"
|
||||||
reverseproxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/manager"
|
proxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy/manager"
|
||||||
|
reverseproxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service/manager"
|
||||||
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
|
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
|
||||||
|
|
||||||
zonesManager "github.com/netbirdio/netbird/management/internals/modules/zones/manager"
|
zonesManager "github.com/netbirdio/netbird/management/internals/modules/zones/manager"
|
||||||
@@ -91,12 +94,28 @@ func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee
|
|||||||
}
|
}
|
||||||
|
|
||||||
accessLogsManager := accesslogsmanager.NewManager(store, permissionsManager, nil)
|
accessLogsManager := accesslogsmanager.NewManager(store, permissionsManager, nil)
|
||||||
proxyTokenStore := nbgrpc.NewOneTimeTokenStore(1 * time.Minute)
|
proxyTokenStore, err := nbgrpc.NewOneTimeTokenStore(ctx, 5*time.Minute, 10*time.Minute, 100)
|
||||||
proxyServiceServer := nbgrpc.NewProxyServiceServer(accessLogsManager, proxyTokenStore, nbgrpc.ProxyOIDCConfig{}, peersManager, userManager)
|
if err != nil {
|
||||||
domainManager := manager.NewManager(store, proxyServiceServer, permissionsManager)
|
t.Fatalf("Failed to create proxy token store: %v", err)
|
||||||
reverseProxyManager := reverseproxymanager.NewManager(store, am, permissionsManager, proxyServiceServer, domainManager)
|
}
|
||||||
proxyServiceServer.SetProxyManager(reverseProxyManager)
|
pkceverifierStore, err := nbgrpc.NewPKCEVerifierStore(ctx, 10*time.Minute, 10*time.Minute, 100)
|
||||||
am.SetServiceManager(reverseProxyManager)
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create PKCE verifier store: %v", err)
|
||||||
|
}
|
||||||
|
noopMeter := noop.NewMeterProvider().Meter("")
|
||||||
|
proxyMgr, err := proxymanager.NewManager(store, noopMeter)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create proxy manager: %v", err)
|
||||||
|
}
|
||||||
|
proxyServiceServer := nbgrpc.NewProxyServiceServer(accessLogsManager, proxyTokenStore, pkceverifierStore, nbgrpc.ProxyOIDCConfig{}, peersManager, userManager, proxyMgr)
|
||||||
|
domainManager := manager.NewManager(store, proxyMgr, permissionsManager, am)
|
||||||
|
serviceProxyController, err := proxymanager.NewGRPCController(proxyServiceServer, noopMeter)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create proxy controller: %v", err)
|
||||||
|
}
|
||||||
|
serviceManager := reverseproxymanager.NewManager(store, am, permissionsManager, serviceProxyController, domainManager)
|
||||||
|
proxyServiceServer.SetServiceManager(serviceManager)
|
||||||
|
am.SetServiceManager(serviceManager)
|
||||||
|
|
||||||
// @note this is required so that PAT's validate from store, but JWT's are mocked
|
// @note this is required so that PAT's validate from store, but JWT's are mocked
|
||||||
authManager := serverauth.NewManager(store, "", "", "", "", []string{}, false)
|
authManager := serverauth.NewManager(store, "", "", "", "", []string{}, false)
|
||||||
@@ -114,7 +133,7 @@ func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee
|
|||||||
customZonesManager := zonesManager.NewManager(store, am, permissionsManager, "")
|
customZonesManager := zonesManager.NewManager(store, am, permissionsManager, "")
|
||||||
zoneRecordsManager := recordsManager.NewManager(store, am, permissionsManager)
|
zoneRecordsManager := recordsManager.NewManager(store, am, permissionsManager)
|
||||||
|
|
||||||
apiHandler, err := http2.NewAPIHandler(context.Background(), am, networksManagerMock, resourcesManagerMock, routersManagerMock, groupsManagerMock, geoMock, authManagerMock, metrics, validatorMock, proxyController, permissionsManager, peersManager, settingsManager, customZonesManager, zoneRecordsManager, networkMapController, nil, reverseProxyManager, nil, nil, nil, nil)
|
apiHandler, err := http2.NewAPIHandler(context.Background(), am, networksManagerMock, resourcesManagerMock, routersManagerMock, groupsManagerMock, geoMock, authManagerMock, metrics, validatorMock, proxyController, permissionsManager, peersManager, settingsManager, customZonesManager, zoneRecordsManager, networkMapController, nil, serviceManager, nil, nil, nil, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Failed to create API handler: %v", err)
|
t.Fatalf("Failed to create API handler: %v", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -52,7 +52,7 @@ type EmbeddedIdPConfig struct {
|
|||||||
|
|
||||||
// EmbeddedStorageConfig holds storage configuration for the embedded IdP.
|
// EmbeddedStorageConfig holds storage configuration for the embedded IdP.
|
||||||
type EmbeddedStorageConfig struct {
|
type EmbeddedStorageConfig struct {
|
||||||
// Type is the storage type (currently only "sqlite3" is supported)
|
// Type is the storage type: "sqlite3" (default) or "postgres"
|
||||||
Type string
|
Type string
|
||||||
// Config contains type-specific configuration
|
// Config contains type-specific configuration
|
||||||
Config EmbeddedStorageTypeConfig
|
Config EmbeddedStorageTypeConfig
|
||||||
@@ -62,6 +62,8 @@ type EmbeddedStorageConfig struct {
|
|||||||
type EmbeddedStorageTypeConfig struct {
|
type EmbeddedStorageTypeConfig struct {
|
||||||
// File is the path to the SQLite database file (for sqlite3 type)
|
// File is the path to the SQLite database file (for sqlite3 type)
|
||||||
File string
|
File string
|
||||||
|
// DSN is the connection string for postgres
|
||||||
|
DSN string
|
||||||
}
|
}
|
||||||
|
|
||||||
// OwnerConfig represents the initial owner/admin user for the embedded IdP.
|
// OwnerConfig represents the initial owner/admin user for the embedded IdP.
|
||||||
@@ -74,6 +76,22 @@ type OwnerConfig struct {
|
|||||||
Username string
|
Username string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// buildIdpStorageConfig builds the Dex storage config map based on the storage type.
|
||||||
|
func buildIdpStorageConfig(storageType string, cfg EmbeddedStorageTypeConfig) (map[string]interface{}, error) {
|
||||||
|
switch storageType {
|
||||||
|
case "sqlite3":
|
||||||
|
return map[string]interface{}{
|
||||||
|
"file": cfg.File,
|
||||||
|
}, nil
|
||||||
|
case "postgres":
|
||||||
|
return map[string]interface{}{
|
||||||
|
"dsn": cfg.DSN,
|
||||||
|
}, nil
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("unsupported IdP storage type: %s", storageType)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// ToYAMLConfig converts EmbeddedIdPConfig to dex.YAMLConfig.
|
// ToYAMLConfig converts EmbeddedIdPConfig to dex.YAMLConfig.
|
||||||
func (c *EmbeddedIdPConfig) ToYAMLConfig() (*dex.YAMLConfig, error) {
|
func (c *EmbeddedIdPConfig) ToYAMLConfig() (*dex.YAMLConfig, error) {
|
||||||
if c.Issuer == "" {
|
if c.Issuer == "" {
|
||||||
@@ -85,6 +103,14 @@ func (c *EmbeddedIdPConfig) ToYAMLConfig() (*dex.YAMLConfig, error) {
|
|||||||
if c.Storage.Type == "sqlite3" && c.Storage.Config.File == "" {
|
if c.Storage.Type == "sqlite3" && c.Storage.Config.File == "" {
|
||||||
return nil, fmt.Errorf("storage file is required for sqlite3")
|
return nil, fmt.Errorf("storage file is required for sqlite3")
|
||||||
}
|
}
|
||||||
|
if c.Storage.Type == "postgres" && c.Storage.Config.DSN == "" {
|
||||||
|
return nil, fmt.Errorf("storage DSN is required for postgres")
|
||||||
|
}
|
||||||
|
|
||||||
|
storageConfig, err := buildIdpStorageConfig(c.Storage.Type, c.Storage.Config)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("invalid IdP storage config: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
// Build CLI redirect URIs including the device callback (both relative and absolute)
|
// Build CLI redirect URIs including the device callback (both relative and absolute)
|
||||||
cliRedirectURIs := c.CLIRedirectURIs
|
cliRedirectURIs := c.CLIRedirectURIs
|
||||||
@@ -101,9 +127,7 @@ func (c *EmbeddedIdPConfig) ToYAMLConfig() (*dex.YAMLConfig, error) {
|
|||||||
Issuer: c.Issuer,
|
Issuer: c.Issuer,
|
||||||
Storage: dex.Storage{
|
Storage: dex.Storage{
|
||||||
Type: c.Storage.Type,
|
Type: c.Storage.Type,
|
||||||
Config: map[string]interface{}{
|
Config: storageConfig,
|
||||||
"file": c.Storage.Config.File,
|
|
||||||
},
|
|
||||||
},
|
},
|
||||||
Web: dex.Web{
|
Web: dex.Web{
|
||||||
AllowedOrigins: []string{"*"},
|
AllowedOrigins: []string{"*"},
|
||||||
|
|||||||
@@ -14,6 +14,7 @@ import (
|
|||||||
|
|
||||||
"github.com/hashicorp/go-version"
|
"github.com/hashicorp/go-version"
|
||||||
"github.com/netbirdio/netbird/idp/dex"
|
"github.com/netbirdio/netbird/idp/dex"
|
||||||
|
rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/server/types"
|
"github.com/netbirdio/netbird/management/server/types"
|
||||||
@@ -51,6 +52,7 @@ type properties map[string]interface{}
|
|||||||
type DataSource interface {
|
type DataSource interface {
|
||||||
GetAllAccounts(ctx context.Context) []*types.Account
|
GetAllAccounts(ctx context.Context) []*types.Account
|
||||||
GetStoreEngine() types.Engine
|
GetStoreEngine() types.Engine
|
||||||
|
GetCustomDomainsCounts(ctx context.Context) (total int64, validated int64, err error)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ConnManager peer connection manager that holds state for current active connections
|
// ConnManager peer connection manager that holds state for current active connections
|
||||||
@@ -210,6 +212,17 @@ func (w *Worker) generateProperties(ctx context.Context) properties {
|
|||||||
rosenpassEnabled int
|
rosenpassEnabled int
|
||||||
localUsers int
|
localUsers int
|
||||||
idpUsers int
|
idpUsers int
|
||||||
|
embeddedIdpTypes map[string]int
|
||||||
|
services int
|
||||||
|
servicesEnabled int
|
||||||
|
servicesTargets int
|
||||||
|
servicesStatusActive int
|
||||||
|
servicesStatusPending int
|
||||||
|
servicesStatusError int
|
||||||
|
servicesTargetType map[string]int
|
||||||
|
servicesAuthPassword int
|
||||||
|
servicesAuthPin int
|
||||||
|
servicesAuthOIDC int
|
||||||
)
|
)
|
||||||
start := time.Now()
|
start := time.Now()
|
||||||
metricsProperties := make(properties)
|
metricsProperties := make(properties)
|
||||||
@@ -218,10 +231,14 @@ func (w *Worker) generateProperties(ctx context.Context) properties {
|
|||||||
rulesProtocol = make(map[string]int)
|
rulesProtocol = make(map[string]int)
|
||||||
rulesDirection = make(map[string]int)
|
rulesDirection = make(map[string]int)
|
||||||
activeUsersLastDay = make(map[string]struct{})
|
activeUsersLastDay = make(map[string]struct{})
|
||||||
|
embeddedIdpTypes = make(map[string]int)
|
||||||
|
servicesTargetType = make(map[string]int)
|
||||||
uptime = time.Since(w.startupTime).Seconds()
|
uptime = time.Since(w.startupTime).Seconds()
|
||||||
connections := w.connManager.GetAllConnectedPeers()
|
connections := w.connManager.GetAllConnectedPeers()
|
||||||
version = nbversion.NetbirdVersion()
|
version = nbversion.NetbirdVersion()
|
||||||
|
|
||||||
|
customDomains, customDomainsValidated, _ := w.dataSource.GetCustomDomainsCounts(ctx)
|
||||||
|
|
||||||
for _, account := range w.dataSource.GetAllAccounts(ctx) {
|
for _, account := range w.dataSource.GetAllAccounts(ctx) {
|
||||||
accounts++
|
accounts++
|
||||||
|
|
||||||
@@ -278,6 +295,8 @@ func (w *Worker) generateProperties(ctx context.Context) properties {
|
|||||||
} else {
|
} else {
|
||||||
idpUsers++
|
idpUsers++
|
||||||
}
|
}
|
||||||
|
idpType := extractIdpType(idpID)
|
||||||
|
embeddedIdpTypes[idpType]++
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -331,6 +350,37 @@ func (w *Worker) generateProperties(ctx context.Context) properties {
|
|||||||
peerActiveVersions = append(peerActiveVersions, peer.Meta.WtVersion)
|
peerActiveVersions = append(peerActiveVersions, peer.Meta.WtVersion)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
for _, service := range account.Services {
|
||||||
|
services++
|
||||||
|
if service.Enabled {
|
||||||
|
servicesEnabled++
|
||||||
|
}
|
||||||
|
servicesTargets += len(service.Targets)
|
||||||
|
|
||||||
|
switch rpservice.Status(service.Meta.Status) {
|
||||||
|
case rpservice.StatusActive:
|
||||||
|
servicesStatusActive++
|
||||||
|
case rpservice.StatusPending:
|
||||||
|
servicesStatusPending++
|
||||||
|
case rpservice.StatusError, rpservice.StatusCertificateFailed, rpservice.StatusTunnelNotCreated:
|
||||||
|
servicesStatusError++
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, target := range service.Targets {
|
||||||
|
servicesTargetType[target.TargetType]++
|
||||||
|
}
|
||||||
|
|
||||||
|
if service.Auth.PasswordAuth != nil && service.Auth.PasswordAuth.Enabled {
|
||||||
|
servicesAuthPassword++
|
||||||
|
}
|
||||||
|
if service.Auth.PinAuth != nil && service.Auth.PinAuth.Enabled {
|
||||||
|
servicesAuthPin++
|
||||||
|
}
|
||||||
|
if service.Auth.BearerAuth != nil && service.Auth.BearerAuth.Enabled {
|
||||||
|
servicesAuthOIDC++
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
minActivePeerVersion, maxActivePeerVersion := getMinMaxVersion(peerActiveVersions)
|
minActivePeerVersion, maxActivePeerVersion := getMinMaxVersion(peerActiveVersions)
|
||||||
@@ -369,6 +419,27 @@ func (w *Worker) generateProperties(ctx context.Context) properties {
|
|||||||
metricsProperties["rosenpass_enabled"] = rosenpassEnabled
|
metricsProperties["rosenpass_enabled"] = rosenpassEnabled
|
||||||
metricsProperties["local_users_count"] = localUsers
|
metricsProperties["local_users_count"] = localUsers
|
||||||
metricsProperties["idp_users_count"] = idpUsers
|
metricsProperties["idp_users_count"] = idpUsers
|
||||||
|
metricsProperties["embedded_idp_count"] = len(embeddedIdpTypes)
|
||||||
|
|
||||||
|
metricsProperties["services"] = services
|
||||||
|
metricsProperties["services_enabled"] = servicesEnabled
|
||||||
|
metricsProperties["services_targets"] = servicesTargets
|
||||||
|
metricsProperties["services_status_active"] = servicesStatusActive
|
||||||
|
metricsProperties["services_status_pending"] = servicesStatusPending
|
||||||
|
metricsProperties["services_status_error"] = servicesStatusError
|
||||||
|
metricsProperties["services_auth_password"] = servicesAuthPassword
|
||||||
|
metricsProperties["services_auth_pin"] = servicesAuthPin
|
||||||
|
metricsProperties["services_auth_oidc"] = servicesAuthOIDC
|
||||||
|
metricsProperties["custom_domains"] = customDomains
|
||||||
|
metricsProperties["custom_domains_validated"] = customDomainsValidated
|
||||||
|
|
||||||
|
for targetType, count := range servicesTargetType {
|
||||||
|
metricsProperties["services_target_type_"+targetType] = count
|
||||||
|
}
|
||||||
|
|
||||||
|
for idpType, count := range embeddedIdpTypes {
|
||||||
|
metricsProperties["embedded_idp_users_"+idpType] = count
|
||||||
|
}
|
||||||
|
|
||||||
for protocol, count := range rulesProtocol {
|
for protocol, count := range rulesProtocol {
|
||||||
metricsProperties["rules_protocol_"+protocol] = count
|
metricsProperties["rules_protocol_"+protocol] = count
|
||||||
@@ -456,6 +527,20 @@ func createPostRequest(ctx context.Context, endpoint string, payloadStr string)
|
|||||||
return req, cancel, nil
|
return req, cancel, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// extractIdpType extracts the IdP type from a Dex connector ID.
|
||||||
|
// Connector IDs are formatted as "<type>-<xid>" (e.g., "okta-abc123", "zitadel-xyz").
|
||||||
|
// Returns the type prefix, or "oidc" if no known prefix is found.
|
||||||
|
func extractIdpType(connectorID string) string {
|
||||||
|
if connectorID == "local" {
|
||||||
|
return "local"
|
||||||
|
}
|
||||||
|
idx := strings.LastIndex(connectorID, "-")
|
||||||
|
if idx <= 0 {
|
||||||
|
return "oidc"
|
||||||
|
}
|
||||||
|
return strings.ToLower(connectorID[:idx])
|
||||||
|
}
|
||||||
|
|
||||||
func getMinMaxVersion(inputList []string) (string, string) {
|
func getMinMaxVersion(inputList []string) (string, string) {
|
||||||
versions := make([]*version.Version, 0)
|
versions := make([]*version.Version, 0)
|
||||||
|
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ import (
|
|||||||
|
|
||||||
nbdns "github.com/netbirdio/netbird/dns"
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
"github.com/netbirdio/netbird/idp/dex"
|
"github.com/netbirdio/netbird/idp/dex"
|
||||||
|
rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||||
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
|
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
|
||||||
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
|
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
|
||||||
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
|
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
|
||||||
@@ -27,7 +28,8 @@ func (mockDatasource) GetAllConnectedPeers() map[string]struct{} {
|
|||||||
// GetAllAccounts returns a list of *server.Account for use in tests with predefined information
|
// GetAllAccounts returns a list of *server.Account for use in tests with predefined information
|
||||||
func (mockDatasource) GetAllAccounts(_ context.Context) []*types.Account {
|
func (mockDatasource) GetAllAccounts(_ context.Context) []*types.Account {
|
||||||
localUserID := dex.EncodeDexUserID("10", "local")
|
localUserID := dex.EncodeDexUserID("10", "local")
|
||||||
idpUserID := dex.EncodeDexUserID("20", "zitadel")
|
idpUserID := dex.EncodeDexUserID("20", "zitadel-d5uv82dra0haedlf6kv0")
|
||||||
|
oidcUserID := dex.EncodeDexUserID("30", "d6jvvp69kmnc73c9pl40")
|
||||||
return []*types.Account{
|
return []*types.Account{
|
||||||
{
|
{
|
||||||
Id: "1",
|
Id: "1",
|
||||||
@@ -115,6 +117,31 @@ func (mockDatasource) GetAllAccounts(_ context.Context) []*types.Account {
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
Services: []*rpservice.Service{
|
||||||
|
{
|
||||||
|
ID: "svc1",
|
||||||
|
Enabled: true,
|
||||||
|
Targets: []*rpservice.Target{
|
||||||
|
{TargetType: "peer"},
|
||||||
|
{TargetType: "host"},
|
||||||
|
},
|
||||||
|
Auth: rpservice.AuthConfig{
|
||||||
|
PasswordAuth: &rpservice.PasswordAuthConfig{Enabled: true},
|
||||||
|
},
|
||||||
|
Meta: rpservice.Meta{Status: string(rpservice.StatusActive)},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: "svc2",
|
||||||
|
Enabled: false,
|
||||||
|
Targets: []*rpservice.Target{
|
||||||
|
{TargetType: "domain"},
|
||||||
|
},
|
||||||
|
Auth: rpservice.AuthConfig{
|
||||||
|
BearerAuth: &rpservice.BearerAuthConfig{Enabled: true},
|
||||||
|
},
|
||||||
|
Meta: rpservice.Meta{Status: string(rpservice.StatusPending)},
|
||||||
|
},
|
||||||
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Id: "2",
|
Id: "2",
|
||||||
@@ -180,6 +207,13 @@ func (mockDatasource) GetAllAccounts(_ context.Context) []*types.Account {
|
|||||||
"1": {},
|
"1": {},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
oidcUserID: {
|
||||||
|
Id: oidcUserID,
|
||||||
|
IsServiceUser: false,
|
||||||
|
PATs: map[string]*types.PersonalAccessToken{
|
||||||
|
"1": {},
|
||||||
|
},
|
||||||
|
},
|
||||||
},
|
},
|
||||||
Networks: []*networkTypes.Network{
|
Networks: []*networkTypes.Network{
|
||||||
{
|
{
|
||||||
@@ -215,6 +249,11 @@ func (mockDatasource) GetStoreEngine() types.Engine {
|
|||||||
return types.FileStoreEngine
|
return types.FileStoreEngine
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetCustomDomainsCounts returns test custom domain counts.
|
||||||
|
func (mockDatasource) GetCustomDomainsCounts(_ context.Context) (int64, int64, error) {
|
||||||
|
return 3, 2, nil
|
||||||
|
}
|
||||||
|
|
||||||
// TestGenerateProperties tests and validate the properties generation by using the mockDatasource for the Worker.generateProperties
|
// TestGenerateProperties tests and validate the properties generation by using the mockDatasource for the Worker.generateProperties
|
||||||
func TestGenerateProperties(t *testing.T) {
|
func TestGenerateProperties(t *testing.T) {
|
||||||
ds := mockDatasource{}
|
ds := mockDatasource{}
|
||||||
@@ -247,14 +286,14 @@ func TestGenerateProperties(t *testing.T) {
|
|||||||
if properties["rules"] != 4 {
|
if properties["rules"] != 4 {
|
||||||
t.Errorf("expected 4 rules, got %d", properties["rules"])
|
t.Errorf("expected 4 rules, got %d", properties["rules"])
|
||||||
}
|
}
|
||||||
if properties["users"] != 2 {
|
if properties["users"] != 3 {
|
||||||
t.Errorf("expected 1 users, got %d", properties["users"])
|
t.Errorf("expected 3 users, got %d", properties["users"])
|
||||||
}
|
}
|
||||||
if properties["setup_keys_usage"] != 2 {
|
if properties["setup_keys_usage"] != 2 {
|
||||||
t.Errorf("expected 1 setup_keys_usage, got %d", properties["setup_keys_usage"])
|
t.Errorf("expected 1 setup_keys_usage, got %d", properties["setup_keys_usage"])
|
||||||
}
|
}
|
||||||
if properties["pats"] != 4 {
|
if properties["pats"] != 5 {
|
||||||
t.Errorf("expected 4 personal_access_tokens, got %d", properties["pats"])
|
t.Errorf("expected 5 personal_access_tokens, got %d", properties["pats"])
|
||||||
}
|
}
|
||||||
if properties["peers_ssh_enabled"] != 2 {
|
if properties["peers_ssh_enabled"] != 2 {
|
||||||
t.Errorf("expected 2 peers_ssh_enabled, got %d", properties["peers_ssh_enabled"])
|
t.Errorf("expected 2 peers_ssh_enabled, got %d", properties["peers_ssh_enabled"])
|
||||||
@@ -338,7 +377,90 @@ func TestGenerateProperties(t *testing.T) {
|
|||||||
if properties["local_users_count"] != 1 {
|
if properties["local_users_count"] != 1 {
|
||||||
t.Errorf("expected 1 local_users_count, got %d", properties["local_users_count"])
|
t.Errorf("expected 1 local_users_count, got %d", properties["local_users_count"])
|
||||||
}
|
}
|
||||||
if properties["idp_users_count"] != 1 {
|
if properties["idp_users_count"] != 2 {
|
||||||
t.Errorf("expected 1 idp_users_count, got %d", properties["idp_users_count"])
|
t.Errorf("expected 2 idp_users_count, got %d", properties["idp_users_count"])
|
||||||
|
}
|
||||||
|
if properties["embedded_idp_users_local"] != 1 {
|
||||||
|
t.Errorf("expected 1 embedded_idp_users_local, got %v", properties["embedded_idp_users_local"])
|
||||||
|
}
|
||||||
|
if properties["embedded_idp_users_zitadel"] != 1 {
|
||||||
|
t.Errorf("expected 1 embedded_idp_users_zitadel, got %v", properties["embedded_idp_users_zitadel"])
|
||||||
|
}
|
||||||
|
if properties["embedded_idp_users_oidc"] != 1 {
|
||||||
|
t.Errorf("expected 1 embedded_idp_users_oidc, got %v", properties["embedded_idp_users_oidc"])
|
||||||
|
}
|
||||||
|
if properties["embedded_idp_count"] != 3 {
|
||||||
|
t.Errorf("expected 3 embedded_idp_count, got %v", properties["embedded_idp_count"])
|
||||||
|
}
|
||||||
|
|
||||||
|
if properties["services"] != 2 {
|
||||||
|
t.Errorf("expected 2 services, got %v", properties["services"])
|
||||||
|
}
|
||||||
|
if properties["services_enabled"] != 1 {
|
||||||
|
t.Errorf("expected 1 services_enabled, got %v", properties["services_enabled"])
|
||||||
|
}
|
||||||
|
if properties["services_targets"] != 3 {
|
||||||
|
t.Errorf("expected 3 services_targets, got %v", properties["services_targets"])
|
||||||
|
}
|
||||||
|
if properties["services_status_active"] != 1 {
|
||||||
|
t.Errorf("expected 1 services_status_active, got %v", properties["services_status_active"])
|
||||||
|
}
|
||||||
|
if properties["services_status_pending"] != 1 {
|
||||||
|
t.Errorf("expected 1 services_status_pending, got %v", properties["services_status_pending"])
|
||||||
|
}
|
||||||
|
if properties["services_status_error"] != 0 {
|
||||||
|
t.Errorf("expected 0 services_status_error, got %v", properties["services_status_error"])
|
||||||
|
}
|
||||||
|
if properties["services_target_type_peer"] != 1 {
|
||||||
|
t.Errorf("expected 1 services_target_type_peer, got %v", properties["services_target_type_peer"])
|
||||||
|
}
|
||||||
|
if properties["services_target_type_host"] != 1 {
|
||||||
|
t.Errorf("expected 1 services_target_type_host, got %v", properties["services_target_type_host"])
|
||||||
|
}
|
||||||
|
if properties["services_target_type_domain"] != 1 {
|
||||||
|
t.Errorf("expected 1 services_target_type_domain, got %v", properties["services_target_type_domain"])
|
||||||
|
}
|
||||||
|
if properties["services_auth_password"] != 1 {
|
||||||
|
t.Errorf("expected 1 services_auth_password, got %v", properties["services_auth_password"])
|
||||||
|
}
|
||||||
|
if properties["services_auth_oidc"] != 1 {
|
||||||
|
t.Errorf("expected 1 services_auth_oidc, got %v", properties["services_auth_oidc"])
|
||||||
|
}
|
||||||
|
if properties["services_auth_pin"] != 0 {
|
||||||
|
t.Errorf("expected 0 services_auth_pin, got %v", properties["services_auth_pin"])
|
||||||
|
}
|
||||||
|
if properties["custom_domains"] != int64(3) {
|
||||||
|
t.Errorf("expected 3 custom_domains, got %v", properties["custom_domains"])
|
||||||
|
}
|
||||||
|
if properties["custom_domains_validated"] != int64(2) {
|
||||||
|
t.Errorf("expected 2 custom_domains_validated, got %v", properties["custom_domains_validated"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExtractIdpType(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
connectorID string
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{"okta-abc123def", "okta"},
|
||||||
|
{"zitadel-d5uv82dra0haedlf6kv0", "zitadel"},
|
||||||
|
{"entra-xyz789", "entra"},
|
||||||
|
{"google-abc123", "google"},
|
||||||
|
{"pocketid-abc123", "pocketid"},
|
||||||
|
{"microsoft-abc123", "microsoft"},
|
||||||
|
{"authentik-abc123", "authentik"},
|
||||||
|
{"keycloak-d5uv82dra0haedlf6kv0", "keycloak"},
|
||||||
|
{"local", "local"},
|
||||||
|
{"d6jvvp69kmnc73c9pl40", "oidc"},
|
||||||
|
{"", "oidc"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.connectorID, func(t *testing.T) {
|
||||||
|
result := extractIdpType(tt.connectorID)
|
||||||
|
if result != tt.expected {
|
||||||
|
t.Errorf("extractIdpType(%q) = %q, want %q", tt.connectorID, result, tt.expected)
|
||||||
|
}
|
||||||
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ import (
|
|||||||
"google.golang.org/grpc/status"
|
"google.golang.org/grpc/status"
|
||||||
|
|
||||||
nbdns "github.com/netbirdio/netbird/dns"
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
|
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||||
"github.com/netbirdio/netbird/management/server/account"
|
"github.com/netbirdio/netbird/management/server/account"
|
||||||
"github.com/netbirdio/netbird/management/server/activity"
|
"github.com/netbirdio/netbird/management/server/activity"
|
||||||
"github.com/netbirdio/netbird/management/server/idp"
|
"github.com/netbirdio/netbird/management/server/idp"
|
||||||
@@ -148,7 +148,7 @@ type MockAccountManager struct {
|
|||||||
DeleteUserInviteFunc func(ctx context.Context, accountID, initiatorUserID, inviteID string) error
|
DeleteUserInviteFunc func(ctx context.Context, accountID, initiatorUserID, inviteID string) error
|
||||||
}
|
}
|
||||||
|
|
||||||
func (am *MockAccountManager) SetServiceManager(serviceManager reverseproxy.Manager) {
|
func (am *MockAccountManager) SetServiceManager(serviceManager service.Manager) {
|
||||||
// Mock implementation - no-op
|
// Mock implementation - no-op
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -407,7 +407,7 @@ func (am *MockAccountManager) AddPeer(
|
|||||||
|
|
||||||
// GetGroupByName mock implementation of GetGroupByName from server.AccountManager interface
|
// GetGroupByName mock implementation of GetGroupByName from server.AccountManager interface
|
||||||
func (am *MockAccountManager) GetGroupByName(ctx context.Context, accountID, groupName string) (*types.Group, error) {
|
func (am *MockAccountManager) GetGroupByName(ctx context.Context, accountID, groupName string) (*types.Group, error) {
|
||||||
if am.GetGroupFunc != nil {
|
if am.GetGroupByNameFunc != nil {
|
||||||
return am.GetGroupByNameFunc(ctx, accountID, groupName)
|
return am.GetGroupByNameFunc(ctx, accountID, groupName)
|
||||||
}
|
}
|
||||||
return nil, status.Errorf(codes.Unimplemented, "method GetGroupByName is not implemented")
|
return nil, status.Errorf(codes.Unimplemented, "method GetGroupByName is not implemented")
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ import (
|
|||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
|
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||||
"github.com/netbirdio/netbird/management/server/account"
|
"github.com/netbirdio/netbird/management/server/account"
|
||||||
"github.com/netbirdio/netbird/management/server/activity"
|
"github.com/netbirdio/netbird/management/server/activity"
|
||||||
"github.com/netbirdio/netbird/management/server/groups"
|
"github.com/netbirdio/netbird/management/server/groups"
|
||||||
@@ -37,19 +37,19 @@ type managerImpl struct {
|
|||||||
permissionsManager permissions.Manager
|
permissionsManager permissions.Manager
|
||||||
groupsManager groups.Manager
|
groupsManager groups.Manager
|
||||||
accountManager account.Manager
|
accountManager account.Manager
|
||||||
reverseProxyManager reverseproxy.Manager
|
serviceManager service.Manager
|
||||||
}
|
}
|
||||||
|
|
||||||
type mockManager struct {
|
type mockManager struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewManager(store store.Store, permissionsManager permissions.Manager, groupsManager groups.Manager, accountManager account.Manager, reverseproxyManager reverseproxy.Manager) Manager {
|
func NewManager(store store.Store, permissionsManager permissions.Manager, groupsManager groups.Manager, accountManager account.Manager, reverseproxyManager service.Manager) Manager {
|
||||||
return &managerImpl{
|
return &managerImpl{
|
||||||
store: store,
|
store: store,
|
||||||
permissionsManager: permissionsManager,
|
permissionsManager: permissionsManager,
|
||||||
groupsManager: groupsManager,
|
groupsManager: groupsManager,
|
||||||
accountManager: accountManager,
|
accountManager: accountManager,
|
||||||
reverseProxyManager: reverseproxyManager,
|
serviceManager: reverseproxyManager,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -264,7 +264,7 @@ func (m *managerImpl) UpdateResource(ctx context.Context, userID string, resourc
|
|||||||
|
|
||||||
// TODO: optimize to only reload reverse proxies that are affected by the resource update instead of all of them
|
// TODO: optimize to only reload reverse proxies that are affected by the resource update instead of all of them
|
||||||
go func() {
|
go func() {
|
||||||
err := m.reverseProxyManager.ReloadAllServicesForAccount(ctx, resource.AccountID)
|
err := m.serviceManager.ReloadAllServicesForAccount(ctx, resource.AccountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithContext(ctx).Warnf("failed to reload all proxies for account: %v", err)
|
log.WithContext(ctx).Warnf("failed to reload all proxies for account: %v", err)
|
||||||
}
|
}
|
||||||
@@ -322,7 +322,7 @@ func (m *managerImpl) DeleteResource(ctx context.Context, accountID, userID, net
|
|||||||
return status.NewPermissionDeniedError()
|
return status.NewPermissionDeniedError()
|
||||||
}
|
}
|
||||||
|
|
||||||
serviceID, err := m.reverseProxyManager.GetServiceIDByTargetID(ctx, accountID, resourceID)
|
serviceID, err := m.serviceManager.GetServiceIDByTargetID(ctx, accountID, resourceID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to check if resource is used by service: %w", err)
|
return fmt.Errorf("failed to check if resource is used by service: %w", err)
|
||||||
}
|
}
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user