refactor layout and structure

This commit is contained in:
Alisdair MacLeod
2026-01-21 13:52:22 +00:00
parent 2851e38a1f
commit 1d8390b935
51 changed files with 2298 additions and 4430 deletions

1
go.mod
View File

@@ -40,6 +40,7 @@ require (
github.com/c-robinson/iplib v1.0.3
github.com/caddyserver/certmagic v0.21.3
github.com/cilium/ebpf v0.15.0
github.com/cloudflare/backoff v0.0.0-20240920015135-e46b80a3a7d0
github.com/coder/websocket v1.8.13
github.com/coreos/go-iptables v0.7.0
github.com/creack/pty v1.1.24

2
go.sum
View File

@@ -103,6 +103,8 @@ github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UF
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/cilium/ebpf v0.15.0 h1:7NxJhNiBT3NG8pZJ3c+yfrVdHY8ScgKD27sScgjLMMk=
github.com/cilium/ebpf v0.15.0/go.mod h1:DHp1WyrLeiBh19Cf/tfiSMhqheEiK8fXFZ4No0P1Hso=
github.com/cloudflare/backoff v0.0.0-20240920015135-e46b80a3a7d0 h1:pRcxfaAlK0vR6nOeQs7eAEvjJzdGXl8+KaBlcvpQTyQ=
github.com/cloudflare/backoff v0.0.0-20240920015135-e46b80a3a7d0/go.mod h1:rzgs2ZOiguV6/NpiDgADjRLPNyZlApIWxKpkT+X8SdY=
github.com/coder/websocket v1.8.13 h1:f3QZdXy7uGVz+4uCJy2nTZyM0yTBj8yANEHhqlXZ9FE=
github.com/coder/websocket v1.8.13/go.mod h1:LNVeNrXQZfe5qhS9ALED3uA+l5pPqvwXg3CKoDBB2gs=
github.com/containerd/containerd v1.7.29 h1:90fWABQsaN9mJhGkoVnuzEY+o1XDPbg9BTC9QTAHnuE=

27
proxy/.gitignore vendored
View File

@@ -1,27 +0,0 @@
# Binaries
bin/
*.exe
*.exe~
*.dll
*.so
*.dylib
# Test binary
*.test
# Output of go coverage tool
*.out
# Configuration files (keep example)
config.json
# IDE files
.vscode/
.idea/
*.swp
*.swo
*~
# OS files
.DS_Store
Thumbs.db

View File

@@ -1,83 +0,0 @@
.PHONY: build clean run test help version proto
# Build variables
BINARY_NAME=proxy
BUILD_DIR=bin
# Version variables (can be overridden)
VERSION ?= $(shell git describe --tags --always --dirty 2>/dev/null || echo "dev")
COMMIT ?= $(shell git rev-parse --short HEAD 2>/dev/null || echo "unknown")
BUILD_DATE ?= $(shell date -u '+%Y-%m-%d_%H:%M:%S')
# Go linker flags for version injection
LDFLAGS=-ldflags "-X github.com/netbirdio/netbird/proxy/pkg/version.Version=$(VERSION) \
-X github.com/netbirdio/netbird/proxy/pkg/version.Commit=$(COMMIT) \
-X github.com/netbirdio/netbird/proxy/pkg/version.BuildDate=$(BUILD_DATE)"
# Build the binary
build:
@echo "Building $(BINARY_NAME)..."
@echo "Version: $(VERSION)"
@echo "Commit: $(COMMIT)"
@echo "BuildDate: $(BUILD_DATE)"
@mkdir -p $(BUILD_DIR)
GOWORK=off go build $(LDFLAGS) -o $(BUILD_DIR)/$(BINARY_NAME) .
@echo "Build complete: $(BUILD_DIR)/$(BINARY_NAME)"
# Show version information
version:
@echo "Version: $(VERSION)"
@echo "Commit: $(COMMIT)"
@echo "BuildDate: $(BUILD_DATE)"
# Clean build artifacts
clean:
@echo "Cleaning..."
@rm -rf $(BUILD_DIR)
@go clean
@echo "Clean complete"
# Run the application (requires NB_PROXY_TARGET_URL to be set)
run: build
@./$(BUILD_DIR)/$(BINARY_NAME)
# Run tests
test:
GOWORK=off go test -v ./...
# Install dependencies
deps:
@echo "Installing dependencies..."
GOWORK=off go mod download
GOWORK=off go mod tidy
@echo "Dependencies installed"
# Format code
fmt:
@echo "Formatting code..."
@go fmt ./...
@echo "Format complete"
# Lint code
lint:
@echo "Linting code..."
@golangci-lint run
@echo "Lint complete"
# Generate protobuf files
proto:
@echo "Generating protobuf files..."
@./scripts/generate-proto.sh
# Show help
help:
@echo "Available targets:"
@echo " build - Build the binary"
@echo " clean - Remove build artifacts"
@echo " run - Build and run the application"
@echo " test - Run tests"
@echo " proto - Generate protobuf files"
@echo " deps - Install dependencies"
@echo " fmt - Format code"
@echo " lint - Lint code"
@echo " help - Show this help message"

View File

@@ -1,177 +1,50 @@
# Netbird Reverse Proxy
A lightweight, configurable reverse proxy server with graceful shutdown support.
The NetBird Reverse Proxy is a separate service that can act as a public entrypoint to certain resources within a NetBird network.
At a high level, the way that it operates is:
- Configured routes are communicated from the Management server to the proxy.
- For each route the proxy creates a NetBird connection to the NetBird Peer that hosts the resource.
- When traffic hits the proxy at the address and path configured for the proxied resource, the NetBird Proxy brings up a relevant authentication method for that resource.
- On successful authentication the proxy will forward traffic onwards to the NetBird Peer.
## Features
Proxy Authentication methods supported are:
- No authentication
- Oauth2/OIDC
- Emailed Magic Link
- Simple PIN
- HTTP Basic Auth Username and Password
- Simple reverse proxy with customizable headers
- Configuration via environment variables or JSON file
- Graceful shutdown with configurable timeout
- Structured logging with logrus
- Configurable timeouts (read, write, idle)
- Health monitoring support
## Management Connection
## Building
The Proxy communicates with the Management server over a gRPC connection.
Proxies act as clients to the Management server, the following RPCs are used:
- Server-side streaming for proxied service updates.
- Client-side streaming for proxy logs.
```bash
# Build the binary
GOWORK=off go build -o bin/proxy ./cmd/proxy
## Authentication
# Or use make if available
make build
```
When a request hits the Proxy, it looks up the permitted authentication methods for the Host domain.
If no authentication methods are registered for the Host domain, then no authentication will be applied (for fully public resources).
If any authentication methods are registered for the Host domain, then the Proxy will first serve an authentication page allowing the user to select an authentication method (from the permitted methods) and enter the required information for that authentication method.
If the user is successfully authenticated, their request will be forwarded through to the Proxy to be proxied to the relevant Peer.
Successful authentication does not guarantee a successful forwarding of the request as there may be failures behind the Proxy, such as with Peer connectivity or the underlying resource.
## TLS
Due to the authentication provided, the Proxy uses HTTPS for its endpoint, even if the underlying service is HTTP.
Certificate generation can either be via ACME (by default, using Let's Encrypt, but alternative ACME providers can be used) or through certificate files.
When not using ACME, the proxy server attempts to load a certificate and key from the files `tls.crt` and `tls.key` in a specified certificate directory.
When using ACME, the proxy server will store generated certificates in the specified certificate directory.
## Configuration
The proxy can be configured using either environment variables or a JSON configuration file. Environment variables take precedence over file-based configuration.
### Environment Variables
| Variable | Description | Default |
|----------|-------------|---------|
| `NB_PROXY_LISTEN_ADDRESS` | Address to listen on | `:8080` |
| `NB_PROXY_TARGET_URL` | Target URL to proxy requests to | **(required)** |
| `NB_PROXY_READ_TIMEOUT` | Read timeout duration | `30s` |
| `NB_PROXY_WRITE_TIMEOUT` | Write timeout duration | `30s` |
| `NB_PROXY_IDLE_TIMEOUT` | Idle timeout duration | `60s` |
| `NB_PROXY_SHUTDOWN_TIMEOUT` | Graceful shutdown timeout | `10s` |
| `NB_PROXY_LOG_LEVEL` | Log level (debug, info, warn, error) | `info` |
### Configuration File
Create a JSON configuration file:
```json
{
"listen_address": ":8080",
"target_url": "http://localhost:3000",
"read_timeout": "30s",
"write_timeout": "30s",
"idle_timeout": "60s",
"shutdown_timeout": "10s",
"log_level": "info"
}
```
## Usage
### Using Environment Variables
```bash
export NB_PROXY_TARGET_URL=http://localhost:3000
export NB_PROXY_LOG_LEVEL=debug
./bin/proxy
```
### Using Configuration File
```bash
./bin/proxy -config config.json
```
### Combining Both
Environment variables override file configuration:
```bash
export NB_PROXY_LOG_LEVEL=debug
./bin/proxy -config config.json
```
### Docker Example
```bash
docker run -e NB_PROXY_TARGET_URL=http://backend:3000 \
-e NB_PROXY_LISTEN_ADDRESS=:8080 \
-p 8080:8080 \
netbird-proxy
```
## Architecture
The application follows a clean architecture with clear separation of concerns:
```
proxy/
├── cmd/
│ └── proxy/
│ └── main.go # Entry point, CLI handling, signal management
├── config.go # Configuration loading and validation
├── server.go # Server lifecycle (Start/Stop)
├── go.mod # Module dependencies
└── README.md
```
### Key Components
- **config.go**: Handles configuration loading from environment variables and files using the `github.com/caarlos0/env/v11` library
- **server.go**: Encapsulates the HTTP server and reverse proxy logic with proper lifecycle management
- **cmd/proxy/main.go**: Entry point that orchestrates startup, graceful shutdown, and signal handling
## Graceful Shutdown
The server handles SIGINT and SIGTERM signals for graceful shutdown:
1. Signal received (Ctrl+C or kill command)
2. Server stops accepting new connections
3. Existing connections are allowed to complete within the shutdown timeout
4. Server exits cleanly
Press `Ctrl+C` to trigger graceful shutdown:
```bash
^C2026-01-13 22:40:00 INFO Received signal: interrupt
2026-01-13 22:40:00 INFO Shutting down server gracefully...
2026-01-13 22:40:00 INFO Server stopped successfully
2026-01-13 22:40:00 INFO Server exited successfully
```
## Headers
The proxy automatically sets the following headers on proxied requests:
- `X-Forwarded-Host`: Original request host
- `X-Origin-Host`: Target backend host
- `X-Real-IP`: Client's remote address
## Error Handling
- Invalid backend connections return `502 Bad Gateway`
- All proxy errors are logged with details
- Configuration errors are reported at startup
## Development
### Prerequisites
- Go 1.25 or higher
- Access to `github.com/sirupsen/logrus`
- Access to `github.com/caarlos0/env/v11`
### Testing Locally
Start a test backend:
```bash
# Terminal 1: Start a simple backend
python3 -m http.server 3000
```
Start the proxy:
```bash
# Terminal 2: Start the proxy
export NB_PROXY_TARGET_URL=http://localhost:3000
./bin/proxy
```
Test the proxy:
```bash
# Terminal 3: Make requests
curl http://localhost:8080
```
## License
Part of the Netbird project.
NetBird Proxy deployment configuration is via flags or environment variables, with flags taking precedence over the environment.
The following deployment configuration is available:
| Flag | Env | Purpose | Default |
+------+-----+---------+---------+
| `-mgmt` | `NB_PROXY_MANAGEMENT_ADDRESS` | The address of the management server for the proxy to get configuration from. | `"https://api.netbird.io:443"` |
| `-addr` | `NB_PROXY_ADDRESS` | The address that the reverse proxy will listen on. | `":443` |
| `-cert-dir` | `NB_PROXY_CERTIFICATE_DIRECTORY` | The location that certficates are stored in. | `"./certs"` |
| `-acme-certs` | `NB_PROXY_ACME_CERTIFICATES` | Whether to use ACME to generate certificates. | `false` |
| `-acme-addr` | `NB_PROXY_ACME_ADDRESS` | The HTTP address the proxy will listen on to respond to HTTP-01 ACME challenges | `":80"` |
| `-acme-dir` | `NB_PROXY_ACME_DIRECTORY` | The directory URL of the ACME server to be used | `"https://acme-v02.api.letsencrypt.org/directory"` |

84
proxy/cmd/proxy/main.go Normal file
View File

@@ -0,0 +1,84 @@
package main
import (
"context"
"flag"
"fmt"
"log"
"log/slog"
"os"
"runtime"
"strings"
"github.com/netbirdio/netbird/proxy"
"golang.org/x/crypto/acme"
)
const DefaultManagementURL = "https://api.netbird.io:443"
var (
// Version is the application version (set via ldflags during build)
Version = "dev"
// Commit is the git commit hash (set via ldflags during build)
Commit = "unknown"
// BuildDate is the build date (set via ldflags during build)
BuildDate = "unknown"
// GoVersion is the Go version used to build the binary
GoVersion = runtime.Version()
)
func envBoolOrDefault(key string, def bool) bool {
v, exists := os.LookupEnv(key)
if !exists {
return def
}
return v == strings.ToLower("true")
}
func envStringOrDefault(key string, def string) string {
v, exists := os.LookupEnv(key)
if !exists {
return def
}
return v
}
func main() {
var (
version, acmeCerts bool
mgmtAddr, addr, certDir, acmeAddr, acmeDir string
)
flag.BoolVar(&version, "v", false, "Print version and exit")
flag.StringVar(&mgmtAddr, "mgmt", envStringOrDefault("NB_PROXY_MANAGEMENT_ADDRESS", DefaultManagementURL), "Management address to connect to.")
flag.StringVar(&addr, "addr", envStringOrDefault("NB_PROXY_ADDRESS", ":443"), "Reverse proxy address to listen on.")
flag.StringVar(&certDir, "cert-dir", envStringOrDefault("NB_PROXY_CERTIFICATE_DIRECTORY", "./certs"), "Directory to store ")
flag.BoolVar(&acmeCerts, "acme-certs", envBoolOrDefault("NB_PROXY_ACME_CERTIFICATES", false), "Generate ACME certificates using HTTP-01 challenges.")
flag.StringVar(&acmeAddr, "acme-addr", envStringOrDefault("NB_PROXY_ACME_ADDRESS", ":80"), "HTTP address to listen on, used for ACME HTTP-01 certificate generation.")
flag.StringVar(&acmeDir, "acme-dir", envStringOrDefault("NB_PROXY_ACME_DIRECTORY", acme.LetsEncryptURL), "URL of ACME challenge directory.")
flag.Parse()
if version {
fmt.Printf("Version: %s, Commit: %s, BuildDate: %s, Go: %s", Version, Commit, BuildDate, GoVersion)
os.Exit(0)
}
// Write error logs to stderr.
errorLog := slog.New(slog.NewTextHandler(os.Stderr, nil))
srv := proxy.Server{
ErrorLog: errorLog,
ManagementAddress: mgmtAddr,
CertificateDirectory: certDir,
GenerateACMECertificates: acmeCerts,
ACMEChallengeAddress: acmeAddr,
ACMEDirectory: acmeDir,
}
if err := srv.ListenAndServe(context.TODO(), addr); err != nil {
log.Fatal(err)
}
}

View File

@@ -1,109 +0,0 @@
package cmd
import (
"context"
"os"
"os/signal"
"syscall"
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
"github.com/netbirdio/netbird/proxy/pkg/proxy"
"github.com/netbirdio/netbird/proxy/pkg/version"
)
var (
configFile string
rootCmd = &cobra.Command{
Use: "proxy",
Short: "Netbird Reverse Proxy Server",
Long: "A lightweight, configurable reverse proxy server.",
SilenceUsage: true,
SilenceErrors: true,
RunE: run,
}
)
func init() {
rootCmd.PersistentFlags().StringVarP(&configFile, "config", "c", "", "path to JSON configuration file (optional, can use env vars instead)")
rootCmd.Version = version.Short()
rootCmd.SetVersionTemplate("{{.Version}}\n")
}
// Execute runs the root command
func Execute() error {
return rootCmd.Execute()
}
func run(cmd *cobra.Command, args []string) error {
config, err := proxy.LoadFromFileOrEnv(configFile)
if err != nil {
log.Fatalf("Failed to load configuration: %v", err)
return err
}
setupLogging(config.LogLevel)
log.Infof("Starting Netbird Proxy - %s", version.Short())
log.Debugf("Full version info: %s", version.String())
log.Info("Configuration loaded successfully")
log.Infof("Listen Address: %s", config.ReverseProxy.ListenAddress)
log.Infof("Log Level: %s", config.LogLevel)
server, err := proxy.NewServer(config)
if err != nil {
log.Fatalf("Failed to create server: %v", err)
return err
}
serverErrors := make(chan error, 1)
go func() {
if err := server.Start(); err != nil {
serverErrors <- err
}
}()
quit := make(chan os.Signal, 1)
signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
select {
case err := <-serverErrors:
log.Fatalf("Server error: %v", err)
return err
case sig := <-quit:
log.Infof("Received signal: %v", sig)
}
ctx, cancel := context.WithTimeout(context.Background(), config.ShutdownTimeout)
defer cancel()
if err := server.Stop(ctx); err != nil {
log.Fatalf("Failed to stop server gracefully: %v", err)
return err
}
log.Info("Server exited successfully")
return nil
}
func setupLogging(level string) {
log.SetFormatter(&log.TextFormatter{
FullTimestamp: true,
TimestampFormat: "2006-01-02 15:04:05",
})
switch level {
case "debug":
log.SetLevel(log.DebugLevel)
case "info":
log.SetLevel(log.InfoLevel)
case "warn":
log.SetLevel(log.WarnLevel)
case "error":
log.SetLevel(log.ErrorLevel)
default:
log.SetLevel(log.InfoLevel)
}
}

View File

@@ -1,30 +0,0 @@
{
"read_timeout": "30s",
"write_timeout": "30s",
"idle_timeout": "60s",
"shutdown_timeout": "10s",
"log_level": "info",
"grpc_listen_address": ":50051",
"proxy_id": "proxy-1",
"enable_grpc": true,
"reverse_proxy": {
"listen_address": ":443",
"management_url": "https://api.netbird.io",
"http_listen_address": ":80",
"cert_mode": "letsencrypt",
"tls_email": "your-email@example.com",
"cert_cache_dir": "./certs",
"oidc_config": {
"provider_url": "https://your-oidc-provider.com",
"client_id": "your-client-id",
"client_secret": "your-client-secret-if-needed",
"redirect_url": "http://localhost:80/auth/callback",
"scopes": ["openid", "profile", "email"],
"jwt_keys_location": "https://your-oidc-provider.com/.well-known/jwks.json",
"jwt_issuer": "https://your-oidc-provider.com/",
"jwt_audience": ["your-api-identifier-or-client-id"],
"jwt_idp_signkey_refresh_enabled": false,
"session_cookie_name": "auth_session"
}
}
}

View File

@@ -1,22 +0,0 @@
module github.com/netbirdio/netbird/proxy
go 1.25
require (
github.com/caarlos0/env/v11 v11.3.1
github.com/sirupsen/logrus v1.9.3
github.com/spf13/cobra v1.10.2
golang.org/x/crypto v0.44.0
google.golang.org/grpc v1.78.0
google.golang.org/protobuf v1.36.11
)
require (
github.com/inconshreveable/mousetrap v1.1.0 // indirect
github.com/spf13/pflag v1.0.9 // indirect
github.com/stretchr/testify v1.11.1 // indirect
golang.org/x/net v0.47.0 // indirect
golang.org/x/sys v0.38.0 // indirect
golang.org/x/text v0.31.0 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20251029180050-ab9386a59fda // indirect
)

View File

@@ -1,65 +0,0 @@
github.com/caarlos0/env/v11 v11.3.1 h1:cArPWC15hWmEt+gWk7YBi7lEXTXCvpaSdCiZE2X5mCA=
github.com/caarlos0/env/v11 v11.3.1/go.mod h1:qupehSf/Y0TUTsxKywqRt/vJjN5nz6vauiYEUUr8P4U=
github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI=
github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag=
github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE=
github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek=
github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps=
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ=
github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
github.com/spf13/cobra v1.10.2 h1:DMTTonx5m65Ic0GOoRY2c16WCbHxOOw6xxezuLaBpcU=
github.com/spf13/cobra v1.10.2/go.mod h1:7C1pvHqHw5A4vrJfjNwvOdzYu0Gml16OCs2GRiTUUS4=
github.com/spf13/pflag v1.0.9 h1:9exaQaMOCwffKiiiYk6/BndUBv+iRViNW+4lEMi0PvY=
github.com/spf13/pflag v1.0.9/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
go.opentelemetry.io/auto/sdk v1.2.1 h1:jXsnJ4Lmnqd11kwkBV2LgLoFMZKizbCi5fNZ/ipaZ64=
go.opentelemetry.io/auto/sdk v1.2.1/go.mod h1:KRTj+aOaElaLi+wW1kO/DZRXwkF4C5xPbEe3ZiIhN7Y=
go.opentelemetry.io/otel v1.38.0 h1:RkfdswUDRimDg0m2Az18RKOsnI8UDzppJAtj01/Ymk8=
go.opentelemetry.io/otel v1.38.0/go.mod h1:zcmtmQ1+YmQM9wrNsTGV/q/uyusom3P8RxwExxkZhjM=
go.opentelemetry.io/otel/metric v1.38.0 h1:Kl6lzIYGAh5M159u9NgiRkmoMKjvbsKtYRwgfrA6WpA=
go.opentelemetry.io/otel/metric v1.38.0/go.mod h1:kB5n/QoRM8YwmUahxvI3bO34eVtQf2i4utNVLr9gEmI=
go.opentelemetry.io/otel/sdk v1.38.0 h1:l48sr5YbNf2hpCUj/FoGhW9yDkl+Ma+LrVl8qaM5b+E=
go.opentelemetry.io/otel/sdk v1.38.0/go.mod h1:ghmNdGlVemJI3+ZB5iDEuk4bWA3GkTpW+DOoZMYBVVg=
go.opentelemetry.io/otel/sdk/metric v1.38.0 h1:aSH66iL0aZqo//xXzQLYozmWrXxyFkBJ6qT5wthqPoM=
go.opentelemetry.io/otel/sdk/metric v1.38.0/go.mod h1:dg9PBnW9XdQ1Hd6ZnRz689CbtrUp0wMMs9iPcgT9EZA=
go.opentelemetry.io/otel/trace v1.38.0 h1:Fxk5bKrDZJUH+AMyyIXGcFAPah0oRcT+LuNtJrmcNLE=
go.opentelemetry.io/otel/trace v1.38.0/go.mod h1:j1P9ivuFsTceSWe1oY+EeW3sc+Pp42sO++GHkg4wwhs=
go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg=
golang.org/x/crypto v0.44.0 h1:A97SsFvM3AIwEEmTBiaxPPTYpDC47w720rdiiUvgoAU=
golang.org/x/crypto v0.44.0/go.mod h1:013i+Nw79BMiQiMsOPcVCB5ZIJbYkerPrGnOa00tvmc=
golang.org/x/net v0.47.0 h1:Mx+4dIFzqraBXUugkia1OOvlD6LemFo1ALMHjrXDOhY=
golang.org/x/net v0.47.0/go.mod h1:/jNxtkgq5yWUGYkaZGqo27cfGZ1c5Nen03aYrrKpVRU=
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc=
golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
golang.org/x/text v0.31.0 h1:aC8ghyu4JhP8VojJ2lEHBnochRno1sgL6nEi9WGFGMM=
golang.org/x/text v0.31.0/go.mod h1:tKRAlv61yKIjGGHX/4tP1LTbc13YSec1pxVEWXzfoeM=
gonum.org/v1/gonum v0.16.0 h1:5+ul4Swaf3ESvrOnidPp4GZbzf0mxVQpDCYUQE7OJfk=
gonum.org/v1/gonum v0.16.0/go.mod h1:fef3am4MQ93R2HHpKnLk4/Tbh/s0+wqD5nfa6Pnwy4E=
google.golang.org/genproto/googleapis/rpc v0.0.0-20251029180050-ab9386a59fda h1:i/Q+bfisr7gq6feoJnS/DlpdwEL4ihp41fvRiM3Ork0=
google.golang.org/genproto/googleapis/rpc v0.0.0-20251029180050-ab9386a59fda/go.mod h1:7i2o+ce6H/6BluujYR+kqX3GKH+dChPTQU19wjRPiGk=
google.golang.org/grpc v1.78.0 h1:K1XZG/yGDJnzMdd/uZHAkVqJE+xIDOcmdSFZkBUicNc=
google.golang.org/grpc v1.78.0/go.mod h1:I47qjTo4OKbMkjA/aOOwxDIiPSBofUtQUI5EfpWvW7U=
google.golang.org/protobuf v1.36.11 h1:fV6ZwhNocDyBLK0dj+fg8ektcVegBBuEolpbTQyBNVE=
google.golang.org/protobuf v1.36.11/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

View File

@@ -0,0 +1,89 @@
package accesslog
import (
"context"
"log/slog"
"github.com/netbirdio/netbird/shared/management/proto"
"google.golang.org/grpc"
"google.golang.org/protobuf/types/known/timestamppb"
)
type gRPCClient interface {
SendAccessLog(ctx context.Context, in *proto.SendAccessLogRequest, opts ...grpc.CallOption) (*proto.SendAccessLogResponse, error)
}
type errorLogger interface {
ErrorContext(ctx context.Context, msg string, args ...any)
}
type Logger struct {
client gRPCClient
errorLog errorLogger
}
func NewLogger(client gRPCClient, errorLog errorLogger) *Logger {
if errorLog == nil {
errorLog = slog.New(slog.DiscardHandler)
}
return &Logger{
client: client,
errorLog: errorLog,
}
}
type logEntry struct {
ServiceId string
Host string
Path string
DurationMs int64
Method string
ResponseCode int32
SourceIp string
AuthMechanism string
UserId string
AuthSuccess bool
}
func (l *Logger) log(ctx context.Context, log logEntry) {
// Fire off the log request in a separate routine.
// This increases the possibility of losing a log message
// (although it should still get logged in the event of an error),
// but it will reduce latency returning the request in the
// middleware.
// There is also a chance that log messages will arrive at
// the server out of order; however, the timestamp should
// allow for resolving that on the server.
now := timestamppb.Now() // Grab the timestamp before launching the goroutine to try to prevent weird timing issues. This is probably unnecessary.
go func() {
if _, err := l.client.SendAccessLog(ctx, &proto.SendAccessLogRequest{
Log: &proto.AccessLog{
Timestamp: now,
ServiceId: log.ServiceId,
Host: log.Host,
Path: log.Path,
DurationMs: log.DurationMs,
Method: log.Method,
ResponseCode: log.ResponseCode,
SourceIp: log.SourceIp,
AuthMechanism: log.AuthMechanism,
UserId: log.UserId,
AuthSuccess: log.AuthSuccess,
},
}); err != nil {
// If it fails to send on the gRPC connection, then at least log it to the error log.
l.errorLog.ErrorContext(ctx, "Error sending access log on gRPC connection",
"service_id", log.ServiceId,
"host", log.Host,
"path", log.Path,
"duration", log.DurationMs,
"method", log.Method,
"response_code", log.ResponseCode,
"source_ip", log.SourceIp,
"auth_mechanism", log.AuthMechanism,
"user_id", log.UserId,
"auth_success", log.AuthSuccess,
"error", err)
}
}()
}

View File

@@ -0,0 +1,47 @@
package accesslog
import (
"net"
"net/http"
"time"
"github.com/netbirdio/netbird/proxy/internal/auth"
"github.com/netbirdio/netbird/proxy/internal/proxy"
)
func (l *Logger) Middleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Use a response writer wrapper so we can access the status code later.
sw := &statusWriter{
w: w,
status: http.StatusOK, // Default status is OK unless otherwise modified.
}
// Get the source IP before passing the request on as the proxy will modify
// headers that we wish to use to gather that information on the request.
sourceIp := extractSourceIP(r)
start := time.Now()
next.ServeHTTP(sw, r)
duration := time.Since(start)
host, _, err := net.SplitHostPort(r.Host)
if err != nil {
// Fallback to just using the full host value.
host = r.Host
}
l.log(r.Context(), logEntry{
ServiceId: proxy.ServiceIdFromContext(r.Context()),
Host: host,
Path: r.URL.Path,
DurationMs: duration.Milliseconds(),
Method: r.Method,
ResponseCode: int32(sw.status),
SourceIp: sourceIp,
AuthMechanism: auth.MethodFromContext(r.Context()).String(),
UserId: auth.UserFromContext(r.Context()),
AuthSuccess: sw.status != http.StatusUnauthorized && sw.status != http.StatusForbidden,
})
})
}

View File

@@ -0,0 +1,43 @@
package accesslog
import (
"net"
"net/http"
"slices"
"strings"
)
// requestIP attempts to extract the source IP from a request.
// Adapted from https://husobee.github.io/golang/ip-address/2015/12/17/remote-ip-go.html
// with the addition of some newer stdlib functions that are now
// available.
// The concept here is to look backwards through IP headers until
// the first public IP address is found. The hypothesis is that
// even if there are multiple IP addresses specified in these headers,
// the last public IP should be the hop immediately before reaching
// the server and therefore represents the "true" source IP regardless
// of the number of intermediate proxies or network hops.
func extractSourceIP(r *http.Request) string {
for _, h := range []string{"X-Forwarded-For", "X-Real-IP"} {
addresses := strings.Split(r.Header.Get(h), ",")
// Iterate from right to left until we get a public address
// that should be the address right before our proxy.
for _, address := range slices.Backward(addresses) {
// Trim the address because sometimes clients put whitespace in there.
ip := strings.TrimSpace(address)
// Parse the IP so that we can easily check whether it is a valid public address.
realIP := net.ParseIP(ip)
if !realIP.IsGlobalUnicast() || realIP.IsPrivate() || realIP.IsLoopback() {
continue
}
return ip
}
}
// Fallback to the requests RemoteAddr, this is least likely to be correct but
// should at least yield something in the event that the above has failed.
ip, _, err := net.SplitHostPort(r.RemoteAddr)
if err != nil {
ip = r.RemoteAddr
}
return ip
}

View File

@@ -0,0 +1,26 @@
package accesslog
import (
"net/http"
)
// statusWriter is a simple wrapper around an http.ResponseWriter
// that captures the setting of the status code via the WriteHeader
// function and stores it so that it can be retrieved later.
type statusWriter struct {
w http.ResponseWriter
status int
}
func (w *statusWriter) Header() http.Header {
return w.w.Header()
}
func (w *statusWriter) Write(data []byte) (int, error) {
return w.w.Write(data)
}
func (w *statusWriter) WriteHeader(status int) {
w.status = status
w.w.WriteHeader(status)
}

View File

@@ -0,0 +1,53 @@
package acme
import (
"context"
"fmt"
"sync"
"golang.org/x/crypto/acme"
"golang.org/x/crypto/acme/autocert"
)
type Manager struct {
*autocert.Manager
domainsMux sync.RWMutex
domains map[string]struct{}
}
func NewManager(certDir, acmeURL string) *Manager {
mgr := &Manager{
domains: make(map[string]struct{}),
}
mgr.Manager = &autocert.Manager{
Prompt: autocert.AcceptTOS,
HostPolicy: mgr.hostPolicy,
Cache: autocert.DirCache(certDir),
Client: &acme.Client{
DirectoryURL: acmeURL,
},
}
return mgr
}
func (mgr *Manager) hostPolicy(_ context.Context, domain string) error {
mgr.domainsMux.RLock()
defer mgr.domainsMux.RUnlock()
if _, exists := mgr.domains[domain]; exists {
return nil
}
return fmt.Errorf("unknown domain %q", domain)
}
func (mgr *Manager) AddDomain(domain string) {
mgr.domainsMux.Lock()
defer mgr.domainsMux.Unlock()
mgr.domains[domain] = struct{}{}
}
func (mgr *Manager) RemoveDomain(domain string) {
mgr.domainsMux.Lock()
defer mgr.domainsMux.Unlock()
delete(mgr.domains, domain)
}

View File

@@ -0,0 +1,4 @@
<!doctype html>
{{ range . }}
<p>{{ . }}</p>
{{ end }}

View File

@@ -0,0 +1,42 @@
package auth
import (
"crypto/subtle"
"net/http"
)
type BasicAuth struct {
username, password string
}
func NewBasicAuth(username string, password string) BasicAuth {
return BasicAuth{
username: username,
password: password,
}
}
func (BasicAuth) Type() Method {
return MethodBasicAuth
}
func (b BasicAuth) Authenticate(r *http.Request) (string, bool, any) {
username, password, ok := r.BasicAuth()
if !ok {
return "", false, nil
}
usernameMatch := subtle.ConstantTimeCompare([]byte(username), []byte(b.username)) == 1
passwordMatch := subtle.ConstantTimeCompare([]byte(password), []byte(b.password)) == 1
// If authenticated, then return the username.
if usernameMatch && passwordMatch {
return username, false, nil
}
return "", false, nil
}
func (b BasicAuth) Middleware(next http.Handler) http.Handler {
return next
}

View File

@@ -1,25 +0,0 @@
package auth
import "github.com/netbirdio/netbird/proxy/internal/auth/methods"
// Config holds the authentication configuration for a route
// Only ONE auth method should be configured per route
type Config struct {
// HTTP Basic authentication (username/password)
BasicAuth *methods.BasicAuthConfig
// PIN authentication
PIN *methods.PINConfig
// Bearer token with JWT validation and OAuth/OIDC flow
// When enabled, uses the global OIDCConfig from proxy Config
Bearer *methods.BearerConfig
}
// IsEmpty returns true if no auth methods are configured
func (c *Config) IsEmpty() bool {
if c == nil {
return true
}
return c.BasicAuth == nil && c.PIN == nil && c.Bearer == nil
}

View File

@@ -0,0 +1,38 @@
package auth
import (
"context"
)
type requestContextKey string
const (
authMethodKey requestContextKey = "authMethod"
authUserKey requestContextKey = "authUser"
)
func withAuthMethod(ctx context.Context, method Method) context.Context {
return context.WithValue(ctx, authMethodKey, method)
}
func MethodFromContext(ctx context.Context) Method {
v := ctx.Value(authMethodKey)
method, ok := v.(Method)
if !ok {
return ""
}
return method
}
func withAuthUser(ctx context.Context, userId string) context.Context {
return context.WithValue(ctx, authUserKey, userId)
}
func UserFromContext(ctx context.Context) string {
v := ctx.Value(authUserKey)
userId, ok := v.(string)
if !ok {
return ""
}
return userId
}

View File

@@ -1,25 +0,0 @@
package methods
import (
"crypto/subtle"
"net/http"
)
// BasicAuthConfig holds HTTP Basic authentication settings
type BasicAuthConfig struct {
Username string
Password string
}
// Validate checks Basic Auth credentials from the request
func (c *BasicAuthConfig) Validate(r *http.Request) bool {
username, password, ok := r.BasicAuth()
if !ok {
return false
}
usernameMatch := subtle.ConstantTimeCompare([]byte(username), []byte(c.Username)) == 1
passwordMatch := subtle.ConstantTimeCompare([]byte(password), []byte(c.Password)) == 1
return usernameMatch && passwordMatch
}

View File

@@ -1,8 +0,0 @@
package methods
// BearerConfig holds JWT/OAuth/OIDC bearer token authentication settings
// The actual OIDC/JWT configuration comes from the global proxy Config.OIDCConfig
// This just enables Bearer auth for a specific route
type BearerConfig struct {
Enabled bool
}

View File

@@ -1,32 +0,0 @@
package methods
import (
"crypto/subtle"
"net/http"
)
const (
// DefaultPINHeader is the default header name for PIN authentication
DefaultPINHeader = "X-PIN"
)
// PINConfig holds PIN authentication settings
type PINConfig struct {
PIN string
Header string
}
// Validate checks PIN from the request header
func (c *PINConfig) Validate(r *http.Request) bool {
header := c.Header
if header == "" {
header = DefaultPINHeader
}
providedPIN := r.Header.Get(header)
if providedPIN == "" {
return false
}
return subtle.ConstantTimeCompare([]byte(providedPIN), []byte(c.PIN)) == 1
}

View File

@@ -1,298 +1,198 @@
package auth
import (
"fmt"
"crypto/rand"
_ "embed"
"encoding/base64"
"html/template"
"net/http"
"strings"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/proxy/internal/auth/oidc"
"sync"
"time"
)
// Middleware wraps an HTTP handler with authentication middleware
//go:embed auth.gohtml
var authTemplate string
type Method string
var (
MethodBasicAuth Method = "basic"
MethodPIN Method = "pin"
MethodBearer Method = "bearer"
)
func (m Method) String() string {
return string(m)
}
const (
sessionCookieName = "nb_session"
sessionExpiration = 24 * time.Hour
)
type session struct {
UserID string
Method Method
CreatedAt time.Time
}
type Scheme interface {
Type() Method
// Authenticate should check the passed request and determine whether
// it represents an authenticated user request. If it does not, then
// an empty string should indicate an unauthenticated request which
// will be rejected; optionally, it can also return any data that should
// be included in a UI template when prompting the user to authenticate.
// If the request is authenticated, then a user id should be returned
// along with a boolean indicating whether a redirect is needed to clean
// up authentication artifacts from the URLs query.
Authenticate(*http.Request) (userid string, needsRedirect bool, promptData any)
// Middleware is applied within the outer auth middleware, but they will
// be applied after authentication if no scheme has authenticated a
// request.
// If no scheme Middleware blocks the request processing, then the auth
// middleware will then present the user with the auth UI.
Middleware(http.Handler) http.Handler
}
type Middleware struct {
next http.Handler
config *Config
routeID string
rejectResponse func(w http.ResponseWriter, r *http.Request)
oidcHandler *oidc.Handler // OIDC handler for OAuth flow (contains config and JWT validator)
domainsMux sync.RWMutex
domains map[string][]Scheme
sessionsMux sync.RWMutex
sessions map[string]*session
}
// authResult holds the result of an authentication attempt
type authResult struct {
authenticated bool
method string
userID string
func NewMiddleware() *Middleware {
mw := &Middleware{
domains: make(map[string][]Scheme),
sessions: make(map[string]*session),
}
// TODO: goroutine is leaked here.
go mw.cleanupSessions()
return mw
}
// ServeHTTP implements the http.Handler interface
func (m *Middleware) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if m.config.IsEmpty() {
m.allowWithoutAuth(w, r)
return
}
// Protect applies authentication middleware to the passed handler.
// For each incoming request it will be checked against the middleware's
// internal list of protected domains.
// If the Host domain in the inbound request is not present, then it will
// simply be passed through.
// However, if the Host domain is present, then the specified authentication
// schemes for that domain will be applied to the request.
// In the event that no authentication schemes are defined for the domain,
// then the request will also be simply passed through.
func (mw *Middleware) Protect(next http.Handler) http.Handler {
tmpl := template.Must(template.New("auth").Parse(authTemplate))
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
mw.domainsMux.RLock()
schemes, exists := mw.domains[r.Host]
mw.domainsMux.RUnlock()
result := m.authenticate(w, r)
if result == nil {
// Authentication triggered a redirect (e.g., OIDC flow)
return
}
// Domains that are not configured here or have no authentication schemes applied should simply pass through.
if !exists || len(schemes) == 0 {
next.ServeHTTP(w, r)
return
}
if !result.authenticated {
m.rejectRequest(w, r)
return
}
// Check for an existing session to avoid users having to authenticate for every request.
// TODO: This does not work if you are load balancing across multiple proxy servers.
if cookie, err := r.Cookie(sessionCookieName); err == nil {
mw.sessionsMux.RLock()
sess, ok := mw.sessions[cookie.Value]
mw.sessionsMux.RUnlock()
if ok {
ctx := withAuthMethod(r.Context(), sess.Method)
ctx = withAuthUser(ctx, sess.UserID)
next.ServeHTTP(w, r.WithContext(ctx))
return
}
}
m.continueWithAuth(w, r, result)
// Try to authenticate with each scheme.
methods := make(map[Method]any)
for _, s := range schemes {
userid, needsRedirect, promptData := s.Authenticate(r)
if userid != "" {
mw.createSession(w, r, userid, s.Type())
if needsRedirect {
// Clean the path and redirect to the naked URL.
// This is intended to prevent leaking potentially
// sensitive query parameters for some authentication
// methods such as OIDC.
http.Redirect(w, r, r.URL.Path, http.StatusFound)
return
}
ctx := withAuthMethod(r.Context(), s.Type())
ctx = withAuthUser(ctx, userid)
next.ServeHTTP(w, r.WithContext(ctx))
return
}
methods[s.Type()] = promptData
}
// The handler is passed through the scheme middlewares,
// if none of them intercept the request, then this handler will
// be called and present the user with the authentication page.
handler := http.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if err := tmpl.Execute(w, methods); err != nil {
http.Error(w, err.Error(), http.StatusBadGateway)
}
}))
// No authentication succeeded. Apply the scheme handlers.
for _, s := range schemes {
handler = s.Middleware(handler)
}
// Run the unauthenticated request against the scheme handlers and the final UI handler.
handler.ServeHTTP(w, r)
})
}
// allowWithoutAuth allows requests when no authentication is configured
func (m *Middleware) allowWithoutAuth(w http.ResponseWriter, r *http.Request) {
log.WithFields(log.Fields{
"route_id": m.routeID,
"auth_method": "none",
"path": r.URL.Path,
}).Debug("No authentication configured, allowing request")
r.Header.Set("X-Auth-Method", "none")
m.next.ServeHTTP(w, r)
func (mw *Middleware) AddDomain(domain string, schemes []Scheme) {
mw.domainsMux.Lock()
defer mw.domainsMux.Unlock()
mw.domains[domain] = schemes
}
// authenticate attempts to authenticate the request using configured methods
// Returns nil if a redirect occurred (e.g., OIDC flow initiated)
func (m *Middleware) authenticate(w http.ResponseWriter, r *http.Request) *authResult {
if result := m.tryBasicAuth(r); result.authenticated {
return result
}
if result := m.tryPINAuth(r); result.authenticated {
return result
}
return m.tryBearerAuth(w, r)
func (mw *Middleware) RemoveDomain(domain string) {
mw.domainsMux.Lock()
defer mw.domainsMux.Unlock()
delete(mw.domains, domain)
}
// tryBasicAuth attempts Basic authentication
func (m *Middleware) tryBasicAuth(r *http.Request) *authResult {
if m.config.BasicAuth == nil {
return &authResult{}
}
func (mw *Middleware) createSession(w http.ResponseWriter, r *http.Request, userID string, method Method) {
// Generate a random sessionID
b := make([]byte, 32)
_, _ = rand.Read(b)
sessionID := base64.URLEncoding.EncodeToString(b)
if !m.config.BasicAuth.Validate(r) {
return &authResult{}
mw.sessionsMux.Lock()
mw.sessions[sessionID] = &session{
UserID: userID,
Method: method,
CreatedAt: time.Now(),
}
mw.sessionsMux.Unlock()
result := &authResult{
authenticated: true,
method: "basic",
}
if username, _, ok := r.BasicAuth(); ok {
result.userID = username
}
return result
http.SetCookie(w, &http.Cookie{
Name: sessionCookieName,
Value: sessionID,
HttpOnly: true, // This cookie is only for proxy access, so no scripts should touch it.
Secure: true, // The proxy only accepts TLS traffic regardless of the service proxied behind.
SameSite: http.SameSiteLaxMode, // TODO: might this actually be strict mode?
})
}
// tryPINAuth attempts PIN authentication
func (m *Middleware) tryPINAuth(r *http.Request) *authResult {
if m.config.PIN == nil {
return &authResult{}
}
if !m.config.PIN.Validate(r) {
return &authResult{}
}
return &authResult{
authenticated: true,
method: "pin",
userID: "pin_user",
}
}
// tryBearerAuth attempts Bearer token authentication with JWT validation
// Returns nil if OIDC redirect occurred
func (m *Middleware) tryBearerAuth(w http.ResponseWriter, r *http.Request) *authResult {
if m.config.Bearer == nil || m.oidcHandler == nil {
return &authResult{}
}
cookieName := m.oidcHandler.SessionCookieName()
if m.handleAuthTokenParameter(w, r, cookieName) {
return nil
}
if result := m.trySessionCookie(r, cookieName); result.authenticated {
return result
}
if result := m.tryAuthorizationHeader(r); result.authenticated {
return result
}
m.oidcHandler.RedirectToProvider(w, r, m.routeID)
return nil
}
// handleAuthTokenParameter processes the _auth_token query parameter from OIDC callback
// Returns true if a redirect occurred
func (m *Middleware) handleAuthTokenParameter(w http.ResponseWriter, r *http.Request, cookieName string) bool {
authToken := r.URL.Query().Get("_auth_token")
if authToken == "" {
return false
}
log.WithFields(log.Fields{
"route_id": m.routeID,
"host": r.Host,
}).Info("Found auth token in query parameter, setting cookie and redirecting")
if !m.oidcHandler.ValidateJWT(authToken) {
log.WithFields(log.Fields{
"route_id": m.routeID,
}).Warn("Invalid token in query parameter")
return false
}
cookie := &http.Cookie{
Name: cookieName,
Value: authToken,
Path: "/",
MaxAge: 3600, // 1 hour
HttpOnly: true,
Secure: false, // Set to false for HTTP testing, true for HTTPS in production
SameSite: http.SameSiteLaxMode,
}
http.SetCookie(w, cookie)
// Redirect to same URL without the token parameter
redirectURL := m.buildCleanRedirectURL(r)
log.WithFields(log.Fields{
"route_id": m.routeID,
"redirect_url": redirectURL,
}).Debug("Redirecting to clean URL after setting cookie")
http.Redirect(w, r, redirectURL, http.StatusFound)
return true
}
// buildCleanRedirectURL builds a redirect URL without the _auth_token parameter
func (m *Middleware) buildCleanRedirectURL(r *http.Request) string {
cleanURL := *r.URL
q := cleanURL.Query()
q.Del("_auth_token")
cleanURL.RawQuery = q.Encode()
scheme := "http"
if r.TLS != nil {
scheme = "https"
}
return fmt.Sprintf("%s://%s%s", scheme, r.Host, cleanURL.String())
}
// trySessionCookie attempts authentication using a session cookie
func (m *Middleware) trySessionCookie(r *http.Request, cookieName string) *authResult {
log.WithFields(log.Fields{
"route_id": m.routeID,
"cookie_name": cookieName,
"host": r.Host,
"path": r.URL.Path,
}).Debug("Checking for session cookie")
cookie, err := r.Cookie(cookieName)
if err != nil || cookie.Value == "" {
log.WithFields(log.Fields{
"route_id": m.routeID,
"error": err,
}).Debug("No session cookie found")
return &authResult{}
}
log.WithFields(log.Fields{
"route_id": m.routeID,
"cookie_name": cookieName,
}).Debug("Session cookie found, validating JWT")
if !m.oidcHandler.ValidateJWT(cookie.Value) {
log.WithFields(log.Fields{
"route_id": m.routeID,
}).Debug("JWT validation failed for session cookie")
return &authResult{}
}
return &authResult{
authenticated: true,
method: "bearer_session",
userID: m.oidcHandler.ExtractUserID(cookie.Value),
}
}
// tryAuthorizationHeader attempts authentication using the Authorization header
func (m *Middleware) tryAuthorizationHeader(r *http.Request) *authResult {
authHeader := r.Header.Get("Authorization")
if authHeader == "" || !strings.HasPrefix(authHeader, "Bearer ") {
return &authResult{}
}
token := strings.TrimPrefix(authHeader, "Bearer ")
if !m.oidcHandler.ValidateJWT(token) {
return &authResult{}
}
return &authResult{
authenticated: true,
method: "bearer",
userID: m.oidcHandler.ExtractUserID(token),
}
}
// rejectRequest rejects an unauthenticated request
func (m *Middleware) rejectRequest(w http.ResponseWriter, r *http.Request) {
log.WithFields(log.Fields{
"route_id": m.routeID,
"path": r.URL.Path,
}).Warn("Authentication failed")
if m.rejectResponse != nil {
m.rejectResponse(w, r)
} else {
w.Header().Set("WWW-Authenticate", `Bearer realm="Restricted"`)
http.Error(w, "Unauthorized", http.StatusUnauthorized)
}
}
// continueWithAuth continues the request with authenticated user info
func (m *Middleware) continueWithAuth(w http.ResponseWriter, r *http.Request, result *authResult) {
log.WithFields(log.Fields{
"route_id": m.routeID,
"auth_method": result.method,
"user_id": result.userID,
"path": r.URL.Path,
}).Debug("Authentication successful")
// TODO: Find other means of auth logging than headers
r.Header.Set("X-Auth-Method", result.method)
r.Header.Set("X-Auth-User-ID", result.userID)
// Continue to next handler
m.next.ServeHTTP(w, r)
}
// Wrap wraps an HTTP handler with authentication middleware
func Wrap(next http.Handler, authConfig *Config, routeID string, rejectResponse func(w http.ResponseWriter, r *http.Request), oidcHandler *oidc.Handler) http.Handler {
if authConfig == nil {
authConfig = &Config{}
}
return &Middleware{
next: next,
config: authConfig,
routeID: routeID,
rejectResponse: rejectResponse,
oidcHandler: oidcHandler,
func (mw *Middleware) cleanupSessions() {
for range time.Tick(time.Minute) {
cutoff := time.Now().Add(-sessionExpiration)
mw.sessionsMux.Lock()
for id, sess := range mw.sessions {
if sess.CreatedAt.Before(cutoff) {
delete(mw.sessions, id)
}
}
mw.sessionsMux.Unlock()
}
}

202
proxy/internal/auth/oidc.go Normal file
View File

@@ -0,0 +1,202 @@
package auth
import (
"context"
"crypto/rand"
"encoding/base64"
"fmt"
"log/slog"
"net/http"
"net/url"
"strings"
"sync"
"time"
"github.com/coreos/go-oidc/v3/oidc"
"golang.org/x/oauth2"
)
const stateExpiration = 10 * time.Minute
// OIDCConfig holds configuration for OIDC authentication
type OIDCConfig struct {
OIDCProviderURL string
OIDCClientID string
OIDCClientSecret string
OIDCRedirectURL string
OIDCScopes []string
}
// oidcState stores CSRF state with expiration
type oidcState struct {
OriginalURL string
CreatedAt time.Time
}
// OIDC implements the Scheme interface for JWT/OIDC authentication
type OIDC struct {
verifier *oidc.IDTokenVerifier
oauthConfig *oauth2.Config
states map[string]*oidcState
statesMux sync.RWMutex
}
// NewOIDC creates a new OIDC authentication scheme
func NewOIDC(ctx context.Context, cfg OIDCConfig) (*OIDC, error) {
if cfg.OIDCProviderURL == "" || cfg.OIDCClientID == "" {
return nil, fmt.Errorf("OIDC provider URL and client ID are required")
}
scopes := cfg.OIDCScopes
if len(scopes) == 0 {
scopes = []string{oidc.ScopeOpenID, "profile", "email"}
}
provider, err := oidc.NewProvider(ctx, cfg.OIDCProviderURL)
if err != nil {
return nil, fmt.Errorf("failed to create OIDC provider: %w", err)
}
o := &OIDC{
verifier: provider.Verifier(&oidc.Config{
ClientID: cfg.OIDCClientID,
}),
oauthConfig: &oauth2.Config{
ClientID: cfg.OIDCClientID,
ClientSecret: cfg.OIDCClientSecret,
RedirectURL: cfg.OIDCRedirectURL,
Scopes: scopes,
Endpoint: provider.Endpoint(),
},
states: make(map[string]*oidcState),
}
go o.cleanupStates()
return o, nil
}
func (*OIDC) Type() Method {
return MethodBearer
}
func (o *OIDC) Authenticate(r *http.Request) (string, bool, any) {
// Try Authorization: Bearer <token> header
if auth := r.Header.Get("Authorization"); strings.HasPrefix(auth, "Bearer ") {
if userID := o.validateToken(r.Context(), strings.TrimPrefix(auth, "Bearer ")); userID != "" {
return userID, false, nil
}
}
// Try _auth_token query parameter (from OIDC callback redirect)
if token := r.URL.Query().Get("_auth_token"); token != "" {
if userID := o.validateToken(r.Context(), token); userID != "" {
return userID, true, nil // Redirect needed to clean up URL
}
}
// If the request is not authenticated, return a redirect URL for the UI to
// route the user through if they select OIDC login.
b := make([]byte, 32)
_, _ = rand.Read(b)
state := base64.URLEncoding.EncodeToString(b)
// TODO: this does not work if you are load balancing across multiple proxy servers.
o.statesMux.Lock()
o.states[state] = &oidcState{OriginalURL: fmt.Sprintf("https://%s%s", r.Host, r.URL), CreatedAt: time.Now()}
o.statesMux.Unlock()
return "", false, o.oauthConfig.AuthCodeURL(state)
}
// Middleware returns an http.Handler that handles OIDC callback and flow initiation.
func (o *OIDC) Middleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Handle OIDC callback
if r.URL.Path == "/oauth/callback" {
o.handleCallback(w, r)
return
}
next.ServeHTTP(w, r)
})
}
// validateToken validates a JWT ID token and returns the user ID (subject)
func (o *OIDC) validateToken(ctx context.Context, token string) string {
if o.verifier == nil {
return ""
}
idToken, err := o.verifier.Verify(ctx, token)
if err != nil {
return ""
}
return idToken.Subject
}
// handleCallback processes the OIDC callback
func (o *OIDC) handleCallback(w http.ResponseWriter, r *http.Request) {
code := r.URL.Query().Get("code")
state := r.URL.Query().Get("state")
if code == "" || state == "" {
http.Error(w, "Invalid callback parameters", http.StatusBadRequest)
return
}
// Verify and consume state
o.statesMux.Lock()
st, ok := o.states[state]
if ok {
delete(o.states, state)
}
o.statesMux.Unlock()
if !ok {
http.Error(w, "Invalid or expired state", http.StatusBadRequest)
return
}
// Exchange code for token
token, err := o.oauthConfig.Exchange(r.Context(), code)
if err != nil {
slog.Error("Token exchange failed", "error", err)
http.Error(w, "Authentication failed", http.StatusUnauthorized)
return
}
// Prefer ID token if available
idToken := token.AccessToken
if id, ok := token.Extra("id_token").(string); ok && id != "" {
idToken = id
}
// Redirect back to original URL with token
origURL, err := url.Parse(st.OriginalURL)
if err != nil {
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
return
}
q := origURL.Query()
q.Set("_auth_token", idToken)
origURL.RawQuery = q.Encode()
http.Redirect(w, r, origURL.String(), http.StatusFound)
}
// cleanupStates periodically removes expired states
func (o *OIDC) cleanupStates() {
for range time.Tick(time.Minute) {
cutoff := time.Now().Add(-stateExpiration)
o.statesMux.Lock()
for k, v := range o.states {
if v.CreatedAt.Before(cutoff) {
delete(o.states, k)
}
}
o.statesMux.Unlock()
}
}

View File

@@ -1,17 +0,0 @@
package oidc
// Config holds the global OIDC/OAuth configuration
type Config struct {
ProviderURL string `env:"NB_OIDC_PROVIDER_URL" json:"provider_url"`
ClientID string `env:"NB_OIDC_CLIENT_ID" json:"client_id"`
ClientSecret string `env:"NB_OIDC_CLIENT_SECRET" json:"client_secret"`
RedirectURL string `env:"NB_OIDC_REDIRECT_URL" json:"redirect_url"`
Scopes []string `env:"NB_OIDC_SCOPES" json:"scopes"`
JWTKeysLocation string `env:"NB_OIDC_JWT_KEYS_LOCATION" json:"jwt_keys_location"`
JWTIssuer string `env:"NB_OIDC_JWT_ISSUER" json:"jwt_issuer"`
JWTAudience []string `env:"NB_OIDC_JWT_AUDIENCE" json:"jwt_audience"`
JWTIdpSignkeyRefreshEnabled bool `env:"NB_OIDC_JWT_IDP_SIGNKEY_REFRESH_ENABLED" json:"jwt_idp_signkey_refresh_enabled"`
SessionCookieName string `env:"NB_OIDC_SESSION_COOKIE_NAME" json:"session_cookie_name"`
}

View File

@@ -1,285 +0,0 @@
package oidc
import (
"context"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"strings"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/shared/auth/jwt"
)
// Handler manages OIDC authentication flow
type Handler struct {
config *Config
stateStore *StateStore
jwtValidator *jwt.Validator
}
// NewHandler creates a new OIDC handler
func NewHandler(config *Config, stateStore *StateStore) *Handler {
var jwtValidator *jwt.Validator
if config.JWTKeysLocation != "" {
jwtValidator = jwt.NewValidator(
config.JWTIssuer,
config.JWTAudience,
config.JWTKeysLocation,
config.JWTIdpSignkeyRefreshEnabled,
)
}
return &Handler{
config: config,
stateStore: stateStore,
jwtValidator: jwtValidator,
}
}
// RedirectToProvider initiates the OAuth/OIDC authentication flow by redirecting to the provider
func (h *Handler) RedirectToProvider(w http.ResponseWriter, r *http.Request, routeID string) {
// Generate random state for CSRF protection
state, err := generateRandomString(32)
if err != nil {
log.WithError(err).Error("Failed to generate OIDC state")
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
return
}
// Store state with original URL for redirect after auth
// Include the full URL with scheme and host
scheme := "http"
if r.TLS != nil {
scheme = "https"
}
originalURL := fmt.Sprintf("%s://%s%s", scheme, r.Host, r.URL.String())
h.stateStore.Store(state, originalURL, routeID)
// Default scopes if not configured
scopes := h.config.Scopes
if len(scopes) == 0 {
scopes = []string{"openid", "profile", "email"}
}
authURL, err := url.Parse(h.config.ProviderURL)
if err != nil {
log.WithError(err).Error("Invalid OIDC provider URL")
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
return
}
// Append /authorize if it doesn't exist (common OIDC endpoint)
if !strings.HasSuffix(authURL.Path, "/authorize") && !strings.HasSuffix(authURL.Path, "/auth") {
authURL.Path = strings.TrimSuffix(authURL.Path, "/") + "/authorize"
}
params := url.Values{}
params.Set("client_id", h.config.ClientID)
params.Set("redirect_uri", h.config.RedirectURL)
params.Set("response_type", "code")
params.Set("scope", strings.Join(scopes, " "))
params.Set("state", state)
if len(h.config.JWTAudience) > 0 && h.config.JWTAudience[0] != h.config.ClientID {
params.Set("audience", h.config.JWTAudience[0])
}
authURL.RawQuery = params.Encode()
log.WithFields(log.Fields{
"route_id": routeID,
"provider_url": authURL.String(),
"redirect_url": h.config.RedirectURL,
"state": state,
}).Info("Redirecting to OIDC provider for authentication")
http.Redirect(w, r, authURL.String(), http.StatusFound)
}
// HandleCallback creates an HTTP handler for the OIDC callback endpoint
func (h *Handler) HandleCallback() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
// Get authorization code and state from query parameters
code := r.URL.Query().Get("code")
state := r.URL.Query().Get("state")
if code == "" || state == "" {
log.Error("Missing code or state in OIDC callback")
http.Error(w, "Invalid callback parameters", http.StatusBadRequest)
return
}
// Verify state to prevent CSRF
oidcSt, ok := h.stateStore.Get(state)
if !ok {
log.Error("Invalid or expired OIDC state")
http.Error(w, "Invalid or expired state parameter", http.StatusBadRequest)
return
}
// Delete state to prevent reuse
h.stateStore.Delete(state)
// Exchange authorization code for token
token, err := h.exchangeCodeForToken(code)
if err != nil {
log.WithError(err).Error("Failed to exchange code for token")
http.Error(w, "Authentication failed", http.StatusUnauthorized)
return
}
// Parse the original URL to add the token as a query parameter
origURL, err := url.Parse(oidcSt.OriginalURL)
if err != nil {
log.WithError(err).Error("Failed to parse original URL")
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
return
}
// Add token as query parameter so the original domain can set its own cookie
// We use a special parameter name that the auth middleware will look for
q := origURL.Query()
q.Set("_auth_token", token)
origURL.RawQuery = q.Encode()
log.WithFields(log.Fields{
"route_id": oidcSt.RouteID,
"original_url": oidcSt.OriginalURL,
"redirect_url": origURL.String(),
"callback_host": r.Host,
}).Info("OIDC authentication successful, redirecting with token parameter")
// Redirect back to original URL with token parameter
http.Redirect(w, r, origURL.String(), http.StatusFound)
}
}
// exchangeCodeForToken exchanges an authorization code for an access token
func (h *Handler) exchangeCodeForToken(code string) (string, error) {
// Build token endpoint URL
tokenURL, err := url.Parse(h.config.ProviderURL)
if err != nil {
return "", fmt.Errorf("invalid OIDC provider URL: %w", err)
}
// Auth0 uses /oauth/token, standard OIDC uses /token
// Check if path already contains token endpoint
if !strings.Contains(tokenURL.Path, "/token") {
tokenURL.Path = strings.TrimSuffix(tokenURL.Path, "/") + "/oauth/token"
}
// Build request body
data := url.Values{}
data.Set("grant_type", "authorization_code")
data.Set("code", code)
data.Set("redirect_uri", h.config.RedirectURL)
data.Set("client_id", h.config.ClientID)
// Only include client_secret if it's provided (not needed for public/SPA clients)
if h.config.ClientSecret != "" {
data.Set("client_secret", h.config.ClientSecret)
}
// Make token exchange request
resp, err := http.PostForm(tokenURL.String(), data)
if err != nil {
return "", fmt.Errorf("token exchange request failed: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
return "", fmt.Errorf("token exchange failed with status %d: %s", resp.StatusCode, string(body))
}
// Parse response
var tokenResp struct {
AccessToken string `json:"access_token"`
TokenType string `json:"token_type"`
ExpiresIn int `json:"expires_in"`
IDToken string `json:"id_token"`
}
if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil {
return "", fmt.Errorf("failed to decode token response: %w", err)
}
if tokenResp.AccessToken == "" {
return "", fmt.Errorf("no access token in response")
}
// Return the ID token if available (contains user claims), otherwise access token
if tokenResp.IDToken != "" {
return tokenResp.IDToken, nil
}
return tokenResp.AccessToken, nil
}
// ValidateJWT validates a JWT token
func (h *Handler) ValidateJWT(tokenString string) bool {
if h.jwtValidator == nil {
log.Error("JWT validation failed: JWT validator not initialized")
return false
}
// Validate the token
ctx := context.Background()
parsedToken, err := h.jwtValidator.ValidateAndParse(ctx, tokenString)
if err != nil {
log.WithError(err).Error("JWT validation failed")
// Try to parse token without validation to see what's in it
parts := strings.Split(tokenString, ".")
if len(parts) == 3 {
// Decode payload (middle part)
payload, err := base64.RawURLEncoding.DecodeString(parts[1])
if err == nil {
log.WithFields(log.Fields{
"payload": string(payload),
}).Debug("Token payload for debugging")
}
}
return false
}
// Token is valid if parsedToken is not nil and Valid is true
return parsedToken != nil && parsedToken.Valid
}
// ExtractUserID extracts the user ID from a JWT token
func (h *Handler) ExtractUserID(tokenString string) string {
if h.jwtValidator == nil {
return ""
}
// Parse the token
ctx := context.Background()
parsedToken, err := h.jwtValidator.ValidateAndParse(ctx, tokenString)
if err != nil {
return ""
}
// parsedToken is already *jwtgo.Token from ValidateAndParse
// Create extractor to get user auth info
extractor := jwt.NewClaimsExtractor()
userAuth, err := extractor.ToUserAuth(parsedToken)
if err != nil {
log.WithError(err).Debug("Failed to extract user ID from JWT")
return ""
}
return userAuth.UserId
}
// SessionCookieName returns the configured session cookie name or default
func (h *Handler) SessionCookieName() string {
if h.config.SessionCookieName != "" {
return h.config.SessionCookieName
}
return "auth_session"
}

View File

@@ -1,10 +0,0 @@
package oidc
import "time"
// State represents stored OIDC state information for CSRF protection
type State struct {
OriginalURL string
CreatedAt time.Time
RouteID string
}

View File

@@ -1,65 +0,0 @@
package oidc
import (
"sync"
"time"
)
const (
// StateExpiration is how long OIDC state tokens are valid
StateExpiration = 10 * time.Minute
)
// StateStore manages OIDC state tokens for CSRF protection
type StateStore struct {
mu sync.RWMutex
states map[string]*State
}
// NewStateStore creates a new OIDC state store
func NewStateStore() *StateStore {
return &StateStore{
states: make(map[string]*State),
}
}
// Store saves a state token with associated metadata
func (s *StateStore) Store(stateToken, originalURL, routeID string) {
s.mu.Lock()
defer s.mu.Unlock()
s.states[stateToken] = &State{
OriginalURL: originalURL,
CreatedAt: time.Now(),
RouteID: routeID,
}
s.cleanup()
}
// Get retrieves a state by token
func (s *StateStore) Get(stateToken string) (*State, bool) {
s.mu.RLock()
defer s.mu.RUnlock()
st, ok := s.states[stateToken]
return st, ok
}
// Delete removes a state token
func (s *StateStore) Delete(stateToken string) {
s.mu.Lock()
defer s.mu.Unlock()
delete(s.states, stateToken)
}
// cleanup removes expired state tokens (must be called with lock held)
func (s *StateStore) cleanup() {
cutoff := time.Now().Add(-StateExpiration)
for k, v := range s.states {
if v.CreatedAt.Before(cutoff) {
delete(s.states, k)
}
}
}

View File

@@ -1,15 +0,0 @@
package oidc
import (
"crypto/rand"
"encoding/base64"
)
// generateRandomString generates a cryptographically secure random string of the specified length
func generateRandomString(length int) (string, error) {
bytes := make([]byte, length)
if _, err := rand.Read(bytes); err != nil {
return "", err
}
return base64.URLEncoding.EncodeToString(bytes)[:length], nil
}

View File

@@ -0,0 +1,45 @@
package auth
import (
"crypto/subtle"
"net/http"
)
const (
userId = "pin-user"
formId = "pin"
)
type Pin struct {
pin string
}
func NewPin(pin string) Pin {
return Pin{
pin: pin,
}
}
func (Pin) Type() Method {
return MethodPIN
}
// Authenticate attempts to authenticate the request using a form
// value passed in the request.
// If authentication fails, the required HTTP form ID is returned
// so that it can be injected into a request from the UI so that
// authentication may be successful.
func (p Pin) Authenticate(r *http.Request) (string, bool, any) {
pin := r.FormValue(formId)
// Compare the passed pin with the expected pin.
if subtle.ConstantTimeCompare([]byte(pin), []byte(p.pin)) == 1 {
return userId, false, nil
}
return "", false, formId
}
func (p Pin) Middleware(next http.Handler) http.Handler {
return next
}

View File

@@ -0,0 +1,24 @@
package proxy
import (
"context"
)
type requestContextKey string
const (
serviceIdKey requestContextKey = "serviceId"
)
func withServiceId(ctx context.Context, serviceId string) context.Context {
return context.WithValue(ctx, serviceIdKey, serviceId)
}
func ServiceIdFromContext(ctx context.Context) string {
v := ctx.Value(serviceIdKey)
serviceId, ok := v.(string)
if !ok {
return ""
}
return serviceId
}

View File

@@ -0,0 +1,44 @@
package proxy
import (
"net/http"
"net/http/httputil"
"sync"
)
type ReverseProxy struct {
transport http.RoundTripper
mappingsMux sync.RWMutex
mappings map[string]Mapping
}
// NewReverseProxy configures a new NetBird ReverseProxy.
// This is a wrapper around an httputil.ReverseProxy set
// to dynamically route requests based on internal mapping
// between requested URLs and targets.
// The internal mappings can be modified using the AddMapping
// and RemoveMapping functions.
func NewReverseProxy(transport http.RoundTripper) *ReverseProxy {
return &ReverseProxy{
transport: transport,
mappings: make(map[string]Mapping),
}
}
func (p *ReverseProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
target, serviceId, exists := p.findTargetForRequest(r)
if !exists {
// No mapping found so return an error here.
// TODO: prettier error page.
http.Error(w, http.StatusText(http.StatusNotFound), http.StatusNotFound)
return
}
// Set the serviceId in the context for later retrieval.
ctx := withServiceId(r.Context(), serviceId)
// Set up a reverse proxy using the transport and then use it to serve the request.
proxy := httputil.NewSingleHostReverseProxy(target)
proxy.Transport = p.transport
proxy.ServeHTTP(w, r.WithContext(ctx))
}

View File

@@ -0,0 +1,62 @@
package proxy
import (
"net/http"
"net/url"
"sort"
"strings"
)
type Mapping struct {
ID string
Host string
Paths map[string]*url.URL
}
func (p *ReverseProxy) findTargetForRequest(req *http.Request) (*url.URL, string, bool) {
p.mappingsMux.RLock()
if p.mappings == nil {
p.mappingsMux.RUnlock()
p.mappingsMux.Lock()
defer p.mappingsMux.Unlock()
p.mappings = make(map[string]Mapping)
// There cannot be any loaded Mappings as we have only just initialized.
return nil, "", false
}
defer p.mappingsMux.RUnlock()
m, exists := p.mappings[req.Host]
if !exists {
return nil, "", false
}
// Sort paths by length (longest first) in a naive attempt to match the most specific route first.
paths := make([]string, 0, len(m.Paths))
for path := range m.Paths {
paths = append(paths, path)
}
sort.Slice(paths, func(i, j int) bool {
return len(paths[i]) > len(paths[j])
})
for _, path := range paths {
if strings.HasPrefix(req.URL.Path, path) {
return m.Paths[path], m.ID, true
}
}
return nil, "", false
}
func (p *ReverseProxy) AddMapping(m Mapping) {
p.mappingsMux.Lock()
defer p.mappingsMux.Unlock()
if p.mappings == nil {
p.mappings = make(map[string]Mapping)
}
p.mappings[m.Host] = m
}
func (p *ReverseProxy) RemoveMapping(m Mapping) {
p.mappingsMux.Lock()
defer p.mappingsMux.Unlock()
delete(p.mappings, m.Host)
}

View File

@@ -1,28 +0,0 @@
package certmanager
import (
"context"
"crypto/tls"
"net/http"
)
// Manager defines the interface for certificate management
type Manager interface {
// IsEnabled returns whether certificate management is enabled
IsEnabled() bool
// AddDomain adds a domain to the allowed hosts list
AddDomain(domain string)
// RemoveDomain removes a domain from the allowed hosts list
RemoveDomain(domain string)
// IssueCertificate eagerly issues a certificate for a domain
IssueCertificate(ctx context.Context, domain string) error
// TLSConfig returns the TLS configuration for the HTTPS server
TLSConfig() *tls.Config
// HTTPHandler returns the HTTP handler for ACME challenges (or fallback)
HTTPHandler(fallback http.Handler) http.Handler
}

View File

@@ -1,111 +0,0 @@
package certmanager
import (
"context"
"crypto/tls"
"fmt"
"net/http"
"sync"
"time"
log "github.com/sirupsen/logrus"
"golang.org/x/crypto/acme/autocert"
)
// LetsEncryptManager handles TLS certificate issuance via Let's Encrypt
type LetsEncryptManager struct {
autocertManager *autocert.Manager
allowedHosts map[string]bool
mu sync.RWMutex
}
// LetsEncryptConfig holds Let's Encrypt certificate manager configuration
type LetsEncryptConfig struct {
// Email for Let's Encrypt registration (required)
Email string
// CertCacheDir is the directory to cache certificates
CertCacheDir string
}
// NewLetsEncrypt creates a new Let's Encrypt certificate manager
func NewLetsEncrypt(config LetsEncryptConfig) *LetsEncryptManager {
m := &LetsEncryptManager{
allowedHosts: make(map[string]bool),
}
m.autocertManager = &autocert.Manager{
Prompt: autocert.AcceptTOS,
HostPolicy: m.hostPolicy,
Cache: autocert.DirCache(config.CertCacheDir),
Email: config.Email,
RenewBefore: 0, // Use default 30 days prior to expiration
}
log.Info("Let's Encrypt certificate manager initialized")
return m
}
// IsEnabled returns whether certificate management is enabled
func (m *LetsEncryptManager) IsEnabled() bool {
return true
}
// AddDomain adds a domain to the allowed hosts list
func (m *LetsEncryptManager) AddDomain(domain string) {
m.mu.Lock()
defer m.mu.Unlock()
m.allowedHosts[domain] = true
log.Infof("Added domain to Let's Encrypt manager: %s", domain)
}
// RemoveDomain removes a domain from the allowed hosts list
func (m *LetsEncryptManager) RemoveDomain(domain string) {
m.mu.Lock()
defer m.mu.Unlock()
delete(m.allowedHosts, domain)
log.Infof("Removed domain from Let's Encrypt manager: %s", domain)
}
// IssueCertificate eagerly issues a Let's Encrypt certificate for a domain
func (m *LetsEncryptManager) IssueCertificate(ctx context.Context, domain string) error {
log.Infof("Issuing Let's Encrypt certificate for domain: %s", domain)
hello := &tls.ClientHelloInfo{
ServerName: domain,
}
cert, err := m.autocertManager.GetCertificate(hello)
if err != nil {
return fmt.Errorf("failed to issue certificate for domain %s: %w", domain, err)
}
log.Infof("Successfully issued Let's Encrypt certificate for domain: %s (expires: %s)",
domain, cert.Leaf.NotAfter.Format(time.RFC3339))
return nil
}
// TLSConfig returns the TLS configuration for the HTTPS server
func (m *LetsEncryptManager) TLSConfig() *tls.Config {
return m.autocertManager.TLSConfig()
}
// HTTPHandler returns the HTTP handler for ACME challenges
func (m *LetsEncryptManager) HTTPHandler(fallback http.Handler) http.Handler {
return m.autocertManager.HTTPHandler(fallback)
}
// hostPolicy validates that a requested host is in the allowed hosts list
func (m *LetsEncryptManager) hostPolicy(ctx context.Context, host string) error {
m.mu.RLock()
defer m.mu.RUnlock()
if m.allowedHosts[host] {
log.Debugf("ACME challenge accepted for domain: %s", host)
return nil
}
log.Warnf("ACME challenge rejected for unconfigured domain: %s", host)
return fmt.Errorf("host %s not configured", host)
}

View File

@@ -1,157 +0,0 @@
package certmanager
import (
"context"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"fmt"
"math/big"
"net/http"
"sync"
"time"
log "github.com/sirupsen/logrus"
)
// SelfSignedManager handles self-signed certificate generation for local testing
type SelfSignedManager struct {
certificates map[string]*tls.Certificate // domain -> certificate cache
mu sync.RWMutex
}
// NewSelfSigned creates a new self-signed certificate manager
func NewSelfSigned() *SelfSignedManager {
log.Info("Self-signed certificate manager initialized")
return &SelfSignedManager{
certificates: make(map[string]*tls.Certificate),
}
}
// IsEnabled returns whether certificate management is enabled
func (m *SelfSignedManager) IsEnabled() bool {
return true
}
// AddDomain adds a domain to the manager (no-op for self-signed, but maintains interface)
func (m *SelfSignedManager) AddDomain(domain string) {
log.Infof("Added domain to self-signed manager: %s", domain)
}
// RemoveDomain removes a domain from the manager
func (m *SelfSignedManager) RemoveDomain(domain string) {
m.mu.Lock()
defer m.mu.Unlock()
delete(m.certificates, domain)
log.Infof("Removed domain from self-signed manager: %s", domain)
}
// IssueCertificate generates and caches a self-signed certificate for a domain
func (m *SelfSignedManager) IssueCertificate(ctx context.Context, domain string) error {
m.mu.Lock()
defer m.mu.Unlock()
if _, exists := m.certificates[domain]; exists {
log.Debugf("Self-signed certificate already exists for domain: %s", domain)
return nil
}
cert, err := m.generateCertificate(domain)
if err != nil {
return err
}
m.certificates[domain] = cert
return nil
}
// TLSConfig returns the TLS configuration for the HTTPS server
func (m *SelfSignedManager) TLSConfig() *tls.Config {
return &tls.Config{
GetCertificate: m.getCertificate,
}
}
// HTTPHandler returns the fallback handler (no ACME challenges for self-signed)
func (m *SelfSignedManager) HTTPHandler(fallback http.Handler) http.Handler {
return fallback
}
// getCertificate returns a self-signed certificate for the requested domain
func (m *SelfSignedManager) getCertificate(hello *tls.ClientHelloInfo) (*tls.Certificate, error) {
m.mu.RLock()
cert, exists := m.certificates[hello.ServerName]
m.mu.RUnlock()
if exists {
return cert, nil
}
log.Infof("Generating self-signed certificate on-demand for: %s", hello.ServerName)
newCert, err := m.generateCertificate(hello.ServerName)
if err != nil {
return nil, err
}
m.mu.Lock()
m.certificates[hello.ServerName] = newCert
m.mu.Unlock()
return newCert, nil
}
// generateCertificate generates a self-signed certificate for a domain
func (m *SelfSignedManager) generateCertificate(domain string) (*tls.Certificate, error) {
priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil {
return nil, fmt.Errorf("failed to generate private key: %w", err)
}
notBefore := time.Now()
notAfter := notBefore.Add(365 * 24 * time.Hour) // Valid for 1 year
serialNumber, err := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128))
if err != nil {
return nil, fmt.Errorf("failed to generate serial number: %w", err)
}
template := x509.Certificate{
SerialNumber: serialNumber,
Subject: pkix.Name{
Organization: []string{"NetBird Local Development"},
CommonName: domain,
},
NotBefore: notBefore,
NotAfter: notAfter,
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
BasicConstraintsValid: true,
DNSNames: []string{domain},
}
certDER, err := x509.CreateCertificate(rand.Reader, &template, &template, &priv.PublicKey, priv)
if err != nil {
return nil, fmt.Errorf("failed to create certificate: %w", err)
}
cert, err := x509.ParseCertificate(certDER)
if err != nil {
return nil, fmt.Errorf("failed to parse certificate: %w", err)
}
tlsCert := &tls.Certificate{
Certificate: [][]byte{certDER},
PrivateKey: priv,
Leaf: cert,
}
log.Infof("Generated self-signed certificate for domain: %s (expires: %s)",
domain, cert.NotAfter.Format(time.RFC3339))
return tlsCert, nil
}

View File

@@ -1,103 +0,0 @@
package reverseproxy
import (
"net/http"
"net/http/httputil"
"github.com/netbirdio/netbird/client/embed"
"github.com/netbirdio/netbird/proxy/internal/auth"
"github.com/netbirdio/netbird/proxy/internal/auth/oidc"
)
// Config holds the reverse proxy configuration
type Config struct {
// ListenAddress is the address to listen on for HTTPS (default ":443")
ListenAddress string `env:"NB_REVERSE_PROXY_LISTEN_ADDRESS" envDefault:":443" json:"listen_address"`
// ManagementURL is the URL of the management server
ManagementURL string `env:"NB_REVERSE_PROXY_MANAGEMENT_URL" json:"management_url"`
// HTTPListenAddress is the address for HTTP (default ":80")
// Used for ACME challenges (Let's Encrypt HTTP-01 challenge)
HTTPListenAddress string `env:"NB_REVERSE_PROXY_HTTP_LISTEN_ADDRESS" envDefault:":80" json:"http_listen_address"`
// CertMode specifies certificate mode: "letsencrypt" or "selfsigned" (default: "letsencrypt")
// "letsencrypt" - Uses Let's Encrypt for production certificates (requires public domain)
// "selfsigned" - Generates self-signed certificates for local testing
CertMode string `env:"NB_REVERSE_PROXY_CERT_MODE" envDefault:"letsencrypt" json:"cert_mode"`
// TLSEmail is the email for Let's Encrypt registration (required for letsencrypt mode)
TLSEmail string `env:"NB_REVERSE_PROXY_TLS_EMAIL" json:"tls_email"`
// CertCacheDir is the directory to cache certificates (for letsencrypt mode, default "./certs")
CertCacheDir string `env:"NB_REVERSE_PROXY_CERT_CACHE_DIR" envDefault:"./certs" json:"cert_cache_dir"`
// OIDCConfig is the global OIDC/OAuth configuration for authentication
// This is shared across all routes that use Bearer authentication
// If nil, routes with Bearer auth will fail to initialize
OIDCConfig *oidc.Config `json:"oidc_config"`
}
// RouteConfig defines a routing configuration
type RouteConfig struct {
// ID is a unique identifier for this route
ID string
// Domain is the domain to listen on (e.g., "example.com" or "*" for all)
Domain string
// PathMappings defines paths that should be forwarded to specific ports
// Key is the path prefix (e.g., "/", "/api", "/admin")
// Value is the target IP:port (e.g., "192.168.1.100:3000")
// Must have at least one entry. Use "/" or "" for the default/catch-all route.
PathMappings map[string]string
SetupKey string
nbClient *embed.Client
// AuthConfig is optional authentication configuration for this route
// Configure ONE of: BasicAuth, PIN, or Bearer (JWT/OIDC)
// If nil, requests pass through without authentication
AuthConfig *auth.Config
// AuthRejectResponse is an optional custom response for authentication failures
// If nil, returns 401 Unauthorized with WWW-Authenticate header
AuthRejectResponse func(w http.ResponseWriter, r *http.Request)
}
// routeEntry represents a compiled route with its proxy
type routeEntry struct {
routeConfig *RouteConfig
path string
target string
proxy *httputil.ReverseProxy
handler http.Handler // handler wraps proxy with middleware (auth, logging, etc.)
}
// RequestDataCallback is called for each proxied request with metrics
type RequestDataCallback func(data RequestData)
// RequestData contains metrics for a proxied request
type RequestData struct {
ServiceID string
Host string
Path string
DurationMs int64
Method string
ResponseCode int32
SourceIP string
AuthMechanism string
UserID string
AuthSuccess bool
}
// responseWriter wraps http.ResponseWriter to capture status code
type responseWriter struct {
http.ResponseWriter
statusCode int
}
func (rw *responseWriter) WriteHeader(code int) {
rw.statusCode = code
rw.ResponseWriter.WriteHeader(code)
}

View File

@@ -1,224 +0,0 @@
package reverseproxy
import (
"net/http"
"net/http/httputil"
"net/url"
"sort"
"strings"
"time"
"github.com/gorilla/mux"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/proxy/internal/auth"
)
// buildHandler creates the main HTTP handler with router for static endpoints
func (p *Proxy) buildHandler() http.Handler {
router := mux.NewRouter()
// Register static endpoints
router.HandleFunc("/auth/callback", p.handleOIDCCallback).Methods("GET")
// Catch-all handler for dynamic proxy routing
router.PathPrefix("/").HandlerFunc(p.handleProxyRequest)
return router
}
// handleProxyRequest handles all dynamic proxy requests
func (p *Proxy) handleProxyRequest(w http.ResponseWriter, r *http.Request) {
startTime := time.Now()
routeEntry := p.findRoute(r.Host, r.URL.Path)
if routeEntry == nil {
log.Warnf("No route found for host=%s path=%s", r.Host, r.URL.Path)
http.NotFound(w, r)
return
}
rw := &responseWriter{
ResponseWriter: w,
statusCode: http.StatusOK,
}
routeEntry.handler.ServeHTTP(rw, r)
if p.requestCallback != nil {
duration := time.Since(startTime)
host := r.Host
if idx := strings.LastIndex(host, ":"); idx != -1 {
host = host[:idx]
}
// TODO: extract logging data
authMechanism := r.Header.Get("X-Auth-Method")
if authMechanism == "" {
authMechanism = "none"
}
userID := r.Header.Get("X-Auth-User-ID")
authSuccess := rw.statusCode != http.StatusUnauthorized && rw.statusCode != http.StatusForbidden
sourceIP := extractSourceIP(r)
data := RequestData{
ServiceID: routeEntry.routeConfig.ID,
Host: host,
Path: r.URL.Path,
DurationMs: duration.Milliseconds(),
Method: r.Method,
ResponseCode: int32(rw.statusCode),
SourceIP: sourceIP,
AuthMechanism: authMechanism,
UserID: userID,
AuthSuccess: authSuccess,
}
p.requestCallback(data)
}
}
// findRoute finds the matching route for a given host and path
func (p *Proxy) findRoute(host, path string) *routeEntry {
p.mu.RLock()
defer p.mu.RUnlock()
if idx := strings.LastIndex(host, ":"); idx != -1 {
host = host[:idx]
}
routeConfig, exists := p.routes[host]
if !exists {
return nil
}
var entries []*routeEntry
for routePath, target := range routeConfig.PathMappings {
proxy := p.createProxy(routeConfig, target)
handler := auth.Wrap(proxy, routeConfig.AuthConfig, routeConfig.ID, routeConfig.AuthRejectResponse, p.oidcHandler)
if routeConfig.AuthConfig != nil && !routeConfig.AuthConfig.IsEmpty() {
var authType string
if routeConfig.AuthConfig.BasicAuth != nil {
authType = "basic_auth"
} else if routeConfig.AuthConfig.PIN != nil {
authType = "pin"
} else if routeConfig.AuthConfig.Bearer != nil {
authType = "bearer_jwt"
}
log.WithFields(log.Fields{
"route_id": routeConfig.ID,
"auth_type": authType,
}).Debug("Auth middleware enabled for route")
} else {
log.WithFields(log.Fields{
"route_id": routeConfig.ID,
}).Debug("No authentication configured for route")
}
entries = append(entries, &routeEntry{
routeConfig: routeConfig,
path: routePath,
target: target,
proxy: proxy,
handler: handler,
})
}
// Sort by path specificity (longest first)
sort.Slice(entries, func(i, j int) bool {
pi, pj := entries[i].path, entries[j].path
// Empty string or "/" goes last (catch-all)
if pi == "" || pi == "/" {
return false
}
if pj == "" || pj == "/" {
return true
}
return len(pi) > len(pj)
})
// Find first matching entry
for _, entry := range entries {
if entry.path == "" || entry.path == "/" {
// Catch-all route
return entry
}
if strings.HasPrefix(path, entry.path) {
return entry
}
}
return nil
}
// createProxy creates a reverse proxy for a target with the route's connection
func (p *Proxy) createProxy(routeConfig *RouteConfig, target string) *httputil.ReverseProxy {
targetURL, err := url.Parse("http://" + target)
if err != nil {
log.Errorf("Failed to parse target URL %s: %v", target, err)
return &httputil.ReverseProxy{
Director: func(req *http.Request) {},
ErrorHandler: func(w http.ResponseWriter, r *http.Request, err error) {
http.Error(w, "Bad Gateway", http.StatusBadGateway)
},
}
}
proxy := httputil.NewSingleHostReverseProxy(targetURL)
proxy.Transport = &http.Transport{
DialContext: routeConfig.nbClient.DialContext,
MaxIdleConns: 1,
MaxIdleConnsPerHost: 1,
IdleConnTimeout: 0,
DisableKeepAlives: false,
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
}
proxy.ErrorHandler = func(w http.ResponseWriter, r *http.Request, err error) {
log.Errorf("Proxy error for %s%s: %v", r.Host, r.URL.Path, err)
http.Error(w, "Bad Gateway", http.StatusBadGateway)
}
return proxy
}
// handleOIDCCallback handles the global /auth/callback endpoint for all routes
func (p *Proxy) handleOIDCCallback(w http.ResponseWriter, r *http.Request) {
if p.oidcHandler == nil {
log.Error("OIDC callback received but no OIDC handler configured")
http.Error(w, "Authentication not configured", http.StatusInternalServerError)
return
}
handler := p.oidcHandler.HandleCallback()
handler(w, r)
}
// extractSourceIP extracts the source IP from the request
func extractSourceIP(r *http.Request) string {
// Try X-Forwarded-For header first
if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
// Take the first IP in the list
if idx := strings.Index(xff, ","); idx != -1 {
return strings.TrimSpace(xff[:idx])
}
return strings.TrimSpace(xff)
}
// Try X-Real-IP header
if xri := r.Header.Get("X-Real-IP"); xri != "" {
return strings.TrimSpace(xri)
}
// Fall back to RemoteAddr
if idx := strings.LastIndex(r.RemoteAddr, ":"); idx != -1 {
return r.RemoteAddr[:idx]
}
return r.RemoteAddr
}

View File

@@ -1,169 +0,0 @@
package reverseproxy
import (
"context"
"fmt"
"net/http"
"sync"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/proxy/internal/auth/oidc"
"github.com/netbirdio/netbird/proxy/internal/reverseproxy/certmanager"
)
// Proxy wraps a reverse proxy with dynamic routing
type Proxy struct {
config Config
mu sync.RWMutex
routes map[string]*RouteConfig // key is host/domain (for fast O(1) lookup)
server *http.Server
httpServer *http.Server
certManager certmanager.Manager
isRunning bool
requestCallback RequestDataCallback
oidcHandler *oidc.Handler
}
// New creates a new reverse proxy
func New(config Config) (*Proxy, error) {
if config.ListenAddress == "" {
config.ListenAddress = ":443"
}
if config.HTTPListenAddress == "" {
config.HTTPListenAddress = ":80"
}
if config.CertCacheDir == "" {
config.CertCacheDir = "./certs"
}
if config.CertMode == "" {
config.CertMode = "letsencrypt"
}
if config.CertMode == "letsencrypt" && config.TLSEmail == "" {
return nil, fmt.Errorf("TLSEmail is required for letsencrypt mode")
}
if config.OIDCConfig != nil && config.OIDCConfig.SessionCookieName == "" {
config.OIDCConfig.SessionCookieName = "auth_session"
}
var certMgr certmanager.Manager
if config.CertMode == "selfsigned" {
// HTTPS with self-signed certificates (for local testing)
certMgr = certmanager.NewSelfSigned()
} else {
// HTTPS with Let's Encrypt (for production)
certMgr = certmanager.NewLetsEncrypt(certmanager.LetsEncryptConfig{
Email: config.TLSEmail,
CertCacheDir: config.CertCacheDir,
})
}
p := &Proxy{
config: config,
routes: make(map[string]*RouteConfig),
certManager: certMgr,
isRunning: false,
}
if config.OIDCConfig != nil {
stateStore := oidc.NewStateStore()
p.oidcHandler = oidc.NewHandler(config.OIDCConfig, stateStore)
}
return p, nil
}
// Start starts the reverse proxy server (non-blocking)
func (p *Proxy) Start() error {
p.mu.Lock()
if p.isRunning {
p.mu.Unlock()
return fmt.Errorf("reverse proxy already running")
}
p.isRunning = true
p.mu.Unlock()
handler := p.buildHandler()
return p.startHTTPS(handler)
}
// startHTTPS starts the proxy with HTTPS
func (p *Proxy) startHTTPS(handler http.Handler) error {
p.httpServer = &http.Server{
Addr: p.config.HTTPListenAddress,
Handler: p.certManager.HTTPHandler(nil),
}
go func() {
log.Infof("Starting HTTP server on %s", p.config.HTTPListenAddress)
if err := p.httpServer.ListenAndServe(); err != nil && err != http.ErrServerClosed {
log.Errorf("HTTP server error: %v", err)
}
}()
p.server = &http.Server{
Addr: p.config.ListenAddress,
Handler: handler,
TLSConfig: p.certManager.TLSConfig(),
}
go func() {
log.Infof("Starting HTTPS reverse proxy server on %s", p.config.ListenAddress)
if err := p.server.ListenAndServeTLS("", ""); err != nil && err != http.ErrServerClosed {
log.Errorf("HTTPS server failed: %v", err)
}
}()
return nil
}
// Stop gracefully stops the reverse proxy server
func (p *Proxy) Stop(ctx context.Context) error {
p.mu.Lock()
if !p.isRunning {
p.mu.Unlock()
return fmt.Errorf("reverse proxy not running")
}
p.isRunning = false
p.mu.Unlock()
log.Info("Stopping reverse proxy server...")
if p.httpServer != nil {
if err := p.httpServer.Shutdown(ctx); err != nil {
log.Errorf("Error shutting down HTTP server: %v", err)
}
}
if p.server != nil {
if err := p.server.Shutdown(ctx); err != nil {
return fmt.Errorf("error shutting down server: %w", err)
}
}
log.Info("Reverse proxy server stopped")
return nil
}
// IsRunning returns whether the proxy is running
func (p *Proxy) IsRunning() bool {
p.mu.RLock()
defer p.mu.RUnlock()
return p.isRunning
}
// SetRequestCallback sets the callback for request metrics
func (p *Proxy) SetRequestCallback(callback RequestDataCallback) {
p.mu.Lock()
defer p.mu.Unlock()
p.requestCallback = callback
}
// GetConfig returns the proxy configuration
func (p *Proxy) GetConfig() Config {
return p.config
}

View File

@@ -1,149 +0,0 @@
package reverseproxy
import (
"context"
"fmt"
"io"
"time"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/embed"
)
const (
clientStartupTimeout = 30 * time.Second
)
// AddRoute adds a new route to the proxy
func (p *Proxy) AddRoute(route *RouteConfig) error {
if route == nil {
return fmt.Errorf("route cannot be nil")
}
if route.ID == "" {
return fmt.Errorf("route ID is required")
}
if route.Domain == "" {
return fmt.Errorf("route Domain is required")
}
if len(route.PathMappings) == 0 {
return fmt.Errorf("route must have at least one path mapping")
}
if route.SetupKey == "" {
return fmt.Errorf("route setup key is required")
}
p.mu.Lock()
defer p.mu.Unlock()
if _, exists := p.routes[route.Domain]; exists {
return fmt.Errorf("route for domain %s already exists", route.Domain)
}
client, err := embed.New(embed.Options{DeviceName: fmt.Sprintf("ingress-%s", route.ID), ManagementURL: p.config.ManagementURL, SetupKey: route.SetupKey, LogOutput: io.Discard})
if err != nil {
return fmt.Errorf("failed to create embedded client for route %s: %v", route.ID, err)
}
ctx, _ := context.WithTimeout(context.Background(), clientStartupTimeout)
err = client.Start(ctx)
if err != nil {
return fmt.Errorf("failed to start embedded client for route %s: %v", route.ID, err)
}
route.nbClient = client
p.routes[route.Domain] = route
p.certManager.AddDomain(route.Domain)
log.WithFields(log.Fields{
"route_id": route.ID,
"domain": route.Domain,
"paths": len(route.PathMappings),
}).Info("Added route")
go func(domain string) {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
defer cancel()
if err := p.certManager.IssueCertificate(ctx, domain); err != nil {
log.Errorf("Failed to issue certificate: %v", err)
// TODO: Better error feedback mechanism
}
}(route.Domain)
return nil
}
// RemoveRoute removes a route by domain
func (p *Proxy) RemoveRoute(domain string) error {
p.mu.Lock()
defer p.mu.Unlock()
if _, exists := p.routes[domain]; !exists {
return fmt.Errorf("route for domain %s not found", domain)
}
delete(p.routes, domain)
p.certManager.RemoveDomain(domain)
log.Infof("Removed route for domain: %s", domain)
return nil
}
// UpdateRoute updates an existing route
func (p *Proxy) UpdateRoute(route *RouteConfig) error {
if route == nil {
return fmt.Errorf("route cannot be nil")
}
if route.ID == "" {
return fmt.Errorf("route ID is required")
}
if route.Domain == "" {
return fmt.Errorf("route Domain is required")
}
p.mu.Lock()
defer p.mu.Unlock()
if _, exists := p.routes[route.Domain]; !exists {
return fmt.Errorf("route for domain %s not found", route.Domain)
}
p.routes[route.Domain] = route
log.WithFields(log.Fields{
"route_id": route.ID,
"domain": route.Domain,
"paths": len(route.PathMappings),
}).Info("Updated route")
return nil
}
// ListRoutes returns a list of all configured domains
func (p *Proxy) ListRoutes() []string {
p.mu.RLock()
defer p.mu.RUnlock()
domains := make([]string, 0, len(p.routes))
for domain := range p.routes {
domains = append(domains, domain)
}
return domains
}
// GetRoute returns a route configuration by domain
func (p *Proxy) GetRoute(domain string) (*RouteConfig, error) {
p.mu.RLock()
defer p.mu.RUnlock()
route, exists := p.routes[domain]
if !exists {
return nil, fmt.Errorf("route for domain %s not found", domain)
}
return route, nil
}

View File

@@ -0,0 +1,61 @@
package roundtrip
import (
"fmt"
"net/http"
"sync"
"github.com/netbirdio/netbird/client/embed"
)
const deviceNamePrefix = "ingress-"
// NetBird provides an http.RoundTripper implementation
// backed by underlying NetBird connections.
type NetBird struct {
mgmtAddr string
clientsMux sync.RWMutex
clients map[string]*http.Client
}
func NewNetBird(mgmtAddr string) *NetBird {
return &NetBird{
mgmtAddr: mgmtAddr,
clients: make(map[string]*http.Client),
}
}
func (n *NetBird) AddPeer(domain, key string) error {
client, err := embed.New(embed.Options{
DeviceName: deviceNamePrefix + domain,
ManagementURL: n.mgmtAddr,
SetupKey: key,
})
if err != nil {
return fmt.Errorf("create netbird client: %w", err)
}
n.clientsMux.Lock()
defer n.clientsMux.Unlock()
n.clients[domain] = client.NewHTTPClient()
return nil
}
func (n *NetBird) RemovePeer(domain string) {
n.clientsMux.Lock()
defer n.clientsMux.Unlock()
delete(n.clients, domain)
}
func (n *NetBird) RoundTrip(req *http.Request) (*http.Response, error) {
n.clientsMux.RLock()
client, exists := n.clients[req.Host]
// Immediately unlock after retrieval here rather than defer to avoid
// the call to client.Do blocking other clients being used whilst one
// is in use.
n.clientsMux.RUnlock()
if !exists {
return nil, fmt.Errorf("no peer connection found for host: %s", req.Host)
}
return client.Do(req)
}

View File

@@ -1,13 +0,0 @@
package main
import (
"os"
"github.com/netbirdio/netbird/proxy/cmd"
)
func main() {
if err := cmd.Execute(); err != nil {
os.Exit(1)
}
}

View File

@@ -1,325 +0,0 @@
package grpc
import (
"context"
"fmt"
"io"
"strings"
"sync"
"time"
log "github.com/sirupsen/logrus"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/keepalive"
"google.golang.org/protobuf/types/known/timestamppb"
"github.com/netbirdio/netbird/shared/management/proto"
)
const (
reconnectInterval = 5 * time.Second
proxyVersion = "0.1.0"
)
// ServiceUpdateHandler is called when services are added/updated/removed
type ServiceUpdateHandler func(update *proto.ServiceUpdate) error
// Client manages the gRPC connection to management server
type Client struct {
proxyID string
managementURL string
conn *grpc.ClientConn
stream proto.ProxyService_StreamClient
serviceUpdateHandler ServiceUpdateHandler
accessLogChan chan *proto.ProxyRequestData
ctx context.Context
cancel context.CancelFunc
mu sync.RWMutex
connected bool
}
// ClientConfig holds client configuration
type ClientConfig struct {
ProxyID string
ManagementURL string
ServiceUpdateHandler ServiceUpdateHandler
}
// NewClient creates a new gRPC client for proxy-management communication
func NewClient(config ClientConfig) *Client {
ctx, cancel := context.WithCancel(context.Background())
return &Client{
proxyID: config.ProxyID,
managementURL: config.ManagementURL,
serviceUpdateHandler: config.ServiceUpdateHandler,
accessLogChan: make(chan *proto.ProxyRequestData, 1000),
ctx: ctx,
cancel: cancel,
}
}
// Start connects to management server and maintains connection
func (c *Client) Start() error {
go c.connectionLoop()
return nil
}
// Stop closes the connection
func (c *Client) Stop() error {
c.cancel()
c.mu.Lock()
defer c.mu.Unlock()
if c.stream != nil {
// Try to close stream gracefully
_ = c.stream.CloseSend()
}
if c.conn != nil {
return c.conn.Close()
}
return nil
}
// SendAccessLog queues an access log to be sent to management
func (c *Client) SendAccessLog(log *proto.ProxyRequestData) {
select {
case c.accessLogChan <- log:
default:
// Channel full, drop log
}
}
// IsConnected returns whether client is connected to management
func (c *Client) IsConnected() bool {
c.mu.RLock()
defer c.mu.RUnlock()
return c.connected
}
// connectionLoop maintains connection to management server
func (c *Client) connectionLoop() {
for {
select {
case <-c.ctx.Done():
return
default:
}
log.Infof("Connecting to management server at %s", c.managementURL)
if err := c.connect(); err != nil {
log.Errorf("Failed to connect to management: %v", err)
c.setConnected(false)
select {
case <-c.ctx.Done():
return
case <-time.After(reconnectInterval):
continue
}
}
// Handle connection
if err := c.handleConnection(); err != nil {
log.Errorf("Connection error: %v", err)
c.setConnected(false)
}
// Reconnect after delay
select {
case <-c.ctx.Done():
return
case <-time.After(reconnectInterval):
}
}
}
// connect establishes connection to management server
func (c *Client) connect() error {
// Strip scheme from URL if present (gRPC doesn't use http:// or https://)
target := c.managementURL
target = strings.TrimPrefix(target, "http://")
target = strings.TrimPrefix(target, "https://")
// Create gRPC connection
opts := []grpc.DialOption{
grpc.WithTransportCredentials(insecure.NewCredentials()), // TODO: Add TLS
grpc.WithKeepaliveParams(keepalive.ClientParameters{
Time: 20 * time.Second,
Timeout: 10 * time.Second,
PermitWithoutStream: true,
}),
}
conn, err := grpc.Dial(target, opts...)
if err != nil {
return fmt.Errorf("failed to dial: %w", err)
}
c.mu.Lock()
c.conn = conn
c.mu.Unlock()
// Create stream
client := proto.NewProxyServiceClient(conn)
stream, err := client.Stream(c.ctx)
if err != nil {
conn.Close()
return fmt.Errorf("failed to create stream: %w", err)
}
c.mu.Lock()
c.stream = stream
c.mu.Unlock()
// Send ProxyHello
hello := &proto.ProxyMessage{
Payload: &proto.ProxyMessage_Hello{
Hello: &proto.ProxyHello{
ProxyId: c.proxyID,
Version: proxyVersion,
StartedAt: timestamppb.Now(),
},
},
}
if err := stream.Send(hello); err != nil {
conn.Close()
return fmt.Errorf("failed to send hello: %w", err)
}
c.setConnected(true)
log.Info("Successfully connected to management server")
return nil
}
// handleConnection manages the active connection
func (c *Client) handleConnection() error {
errChan := make(chan error, 2)
// Start sender goroutine
go c.sender(errChan)
// Start receiver goroutine
go c.receiver(errChan)
// Wait for error
return <-errChan
}
// sender sends messages to management
func (c *Client) sender(errChan chan<- error) {
for {
select {
case <-c.ctx.Done():
errChan <- c.ctx.Err()
return
case accessLog := <-c.accessLogChan:
msg := &proto.ProxyMessage{
Payload: &proto.ProxyMessage_RequestData{
RequestData: accessLog,
},
}
c.mu.RLock()
stream := c.stream
c.mu.RUnlock()
if stream == nil {
continue
}
if err := stream.Send(msg); err != nil {
log.Errorf("Failed to send access log: %v", err)
errChan <- err
return
}
}
}
}
// receiver receives messages from management
func (c *Client) receiver(errChan chan<- error) {
for {
c.mu.RLock()
stream := c.stream
c.mu.RUnlock()
if stream == nil {
errChan <- fmt.Errorf("stream is nil")
return
}
msg, err := stream.Recv()
if err == io.EOF {
log.Info("Management server closed connection")
errChan <- io.EOF
return
}
if err != nil {
log.Errorf("Failed to receive: %v", err)
errChan <- err
return
}
// Handle message
switch payload := msg.GetPayload().(type) {
case *proto.ManagementMessage_Snapshot:
c.handleSnapshot(payload.Snapshot)
case *proto.ManagementMessage_Update:
c.handleServiceUpdate(payload.Update)
default:
log.Warnf("Received unknown message type")
}
}
}
// handleSnapshot processes initial services snapshot
func (c *Client) handleSnapshot(snapshot *proto.ServicesSnapshot) {
log.Infof("Received services snapshot with %d services", len(snapshot.Services))
if c.serviceUpdateHandler == nil {
log.Warn("No service update handler configured")
return
}
// Process each service as a CREATED update
for _, service := range snapshot.Services {
update := &proto.ServiceUpdate{
Type: proto.ServiceUpdate_CREATED,
Service: service,
ServiceId: service.Id,
}
if err := c.serviceUpdateHandler(update); err != nil {
log.Errorf("Failed to handle service %s: %v", service.Id, err)
}
}
}
// handleServiceUpdate processes incremental service update
func (c *Client) handleServiceUpdate(update *proto.ServiceUpdate) {
log.Infof("Received service update: %s %s", update.Type, update.ServiceId)
if c.serviceUpdateHandler == nil {
log.Warn("No service update handler configured")
return
}
if err := c.serviceUpdateHandler(update); err != nil {
log.Errorf("Failed to handle service update: %v", err)
}
}
// setConnected updates connected status
func (c *Client) setConnected(connected bool) {
c.mu.Lock()
defer c.mu.Unlock()
c.connected = connected
}

View File

@@ -1,214 +0,0 @@
package proxy
import (
"encoding/json"
"errors"
"fmt"
"os"
"reflect"
"time"
"github.com/caarlos0/env/v11"
"github.com/netbirdio/netbird/proxy/internal/reverseproxy"
)
var (
ErrFailedToParseConfig = errors.New("failed to parse config from env")
)
// Duration is a time.Duration that can be unmarshaled from JSON as a string
type Duration time.Duration
// UnmarshalJSON implements json.Unmarshaler for Duration
func (d *Duration) UnmarshalJSON(b []byte) error {
var s string
if err := json.Unmarshal(b, &s); err != nil {
return err
}
parsed, err := time.ParseDuration(s)
if err != nil {
return err
}
*d = Duration(parsed)
return nil
}
// MarshalJSON implements json.Marshaler for Duration
func (d Duration) MarshalJSON() ([]byte, error) {
return json.Marshal(time.Duration(d).String())
}
// ToDuration converts Duration to time.Duration
func (d Duration) ToDuration() time.Duration {
return time.Duration(d)
}
// Config holds the configuration for the reverse proxy server
type Config struct {
// ReadTimeout is the maximum duration for reading the entire request, including the body
ReadTimeout time.Duration `env:"NB_PROXY_READ_TIMEOUT" envDefault:"30s" json:"read_timeout"`
// WriteTimeout is the maximum duration before timing out writes of the response
WriteTimeout time.Duration `env:"NB_PROXY_WRITE_TIMEOUT" envDefault:"30s" json:"write_timeout"`
// IdleTimeout is the maximum amount of time to wait for the next request when keep-alives are enabled
IdleTimeout time.Duration `env:"NB_PROXY_IDLE_TIMEOUT" envDefault:"60s" json:"idle_timeout"`
// ShutdownTimeout is the maximum duration to wait for graceful shutdown
ShutdownTimeout time.Duration `env:"NB_PROXY_SHUTDOWN_TIMEOUT" envDefault:"10s" json:"shutdown_timeout"`
// LogLevel sets the logging verbosity (debug, info, warn, error)
LogLevel string `env:"NB_PROXY_LOG_LEVEL" envDefault:"info" json:"log_level"`
// GRPCListenAddress is the address for the gRPC control server
GRPCListenAddress string `env:"NB_PROXY_GRPC_LISTEN_ADDRESS" envDefault:":50051" json:"grpc_listen_address"`
// ProxyID is a unique identifier for this proxy instance
ProxyID string `env:"NB_PROXY_ID" envDefault:"" json:"proxy_id"`
// EnableGRPC enables the gRPC control server
EnableGRPC bool `env:"NB_PROXY_ENABLE_GRPC" envDefault:"false" json:"enable_grpc"`
// Reverse Proxy Configuration
ReverseProxy reverseproxy.Config `json:"reverse_proxy"`
}
// ParseAndLoad parses configuration from environment variables
func ParseAndLoad() (Config, error) {
var cfg Config
if err := env.Parse(&cfg); err != nil {
return cfg, fmt.Errorf("%w: %s", ErrFailedToParseConfig, err)
}
if err := cfg.Validate(); err != nil {
return cfg, fmt.Errorf("invalid config: %w", err)
}
return cfg, nil
}
// LoadFromFile reads configuration from a JSON file
func LoadFromFile(path string) (Config, error) {
data, err := os.ReadFile(path)
if err != nil {
return Config{}, fmt.Errorf("failed to read config file: %w", err)
}
var cfg Config
if err := json.Unmarshal(data, &cfg); err != nil {
return Config{}, fmt.Errorf("failed to parse config file: %w", err)
}
if err := cfg.Validate(); err != nil {
return Config{}, fmt.Errorf("invalid config: %w", err)
}
return cfg, nil
}
// LoadFromFileOrEnv loads configuration from a file if path is provided, otherwise from environment variables
func LoadFromFileOrEnv(configPath string) (Config, error) {
var cfg Config
// If config file is provided, load it first
if configPath != "" {
fileCfg, err := LoadFromFile(configPath)
if err != nil {
return Config{}, fmt.Errorf("failed to load config from file: %w", err)
}
cfg = fileCfg
} else {
if err := env.Parse(&cfg); err != nil {
return Config{}, fmt.Errorf("%w: %s", ErrFailedToParseConfig, err)
}
}
if err := cfg.Validate(); err != nil {
return Config{}, fmt.Errorf("invalid config: %w", err)
}
return cfg, nil
}
// UnmarshalJSON implements custom JSON unmarshaling with automatic duration parsing
func (c *Config) UnmarshalJSON(data []byte) error {
var raw map[string]interface{}
if err := json.Unmarshal(data, &raw); err != nil {
return err
}
val := reflect.ValueOf(c).Elem()
typ := val.Type()
for i := 0; i < val.NumField(); i++ {
field := val.Field(i)
fieldType := typ.Field(i)
jsonTag := fieldType.Tag.Get("json")
if jsonTag == "" || jsonTag == "-" {
continue
}
jsonFieldName := jsonTag
if idx := len(jsonTag); idx > 0 {
for j, c := range jsonTag {
if c == ',' {
jsonFieldName = jsonTag[:j]
break
}
}
}
rawValue, exists := raw[jsonFieldName]
if !exists {
continue
}
if field.Type() == reflect.TypeOf(time.Duration(0)) {
if strValue, ok := rawValue.(string); ok {
duration, err := time.ParseDuration(strValue)
if err != nil {
return fmt.Errorf("invalid duration for field %s: %w", jsonFieldName, err)
}
field.Set(reflect.ValueOf(duration))
} else {
return fmt.Errorf("field %s must be a duration string", jsonFieldName)
}
} else {
fieldData, err := json.Marshal(rawValue)
if err != nil {
return fmt.Errorf("failed to marshal field %s: %w", jsonFieldName, err)
}
if field.CanSet() {
newVal := reflect.New(field.Type())
if err := json.Unmarshal(fieldData, newVal.Interface()); err != nil {
return fmt.Errorf("failed to unmarshal field %s: %w", jsonFieldName, err)
}
field.Set(newVal.Elem())
}
}
}
return nil
}
// Validate checks if the configuration is valid
func (c *Config) Validate() error {
validLogLevels := map[string]bool{
"debug": true,
"info": true,
"warn": true,
"error": true,
}
if !validLogLevels[c.LogLevel] {
return fmt.Errorf("invalid log_level: %s (must be debug, info, warn, or error)", c.LogLevel)
}
return nil
}

View File

@@ -1,592 +0,0 @@
package proxy
import (
"context"
"fmt"
"sync"
log "github.com/sirupsen/logrus"
"google.golang.org/protobuf/types/known/timestamppb"
"github.com/netbirdio/netbird/proxy/internal/auth"
"github.com/netbirdio/netbird/proxy/internal/auth/methods"
"github.com/netbirdio/netbird/proxy/internal/reverseproxy"
grpcpkg "github.com/netbirdio/netbird/proxy/pkg/grpc"
pb "github.com/netbirdio/netbird/shared/management/proto"
)
// Server represents the reverse proxy server with integrated gRPC client
type Server struct {
config Config
grpcClient *grpcpkg.Client
proxy *reverseproxy.Proxy
mu sync.RWMutex
isRunning bool
grpcRunning bool
shutdownCtx context.Context
cancelFunc context.CancelFunc
// Statistics for gRPC reporting
stats *Stats
// Track exposed services and their peer configs
exposedServices map[string]*ExposedServiceConfig
}
// Stats holds proxy statistics
type Stats struct {
mu sync.RWMutex
totalRequests uint64
activeConns uint64
bytesSent uint64
bytesReceived uint64
}
// ExposedServiceConfig holds the configuration for an exposed service
type ExposedServiceConfig struct {
ServiceID string
PeerConfig *PeerConfig
UpstreamConfig *UpstreamConfig
}
// PeerConfig holds WireGuard peer configuration
type PeerConfig struct {
PeerID string
PublicKey string
AllowedIPs []string
Endpoint string
TunnelIP string // The WireGuard tunnel IP to route traffic to
}
// UpstreamConfig holds reverse proxy upstream configuration
type UpstreamConfig struct {
Domain string
PathMappings map[string]string // path -> port mapping (relative to tunnel IP)
}
// NewServer creates a new reverse proxy server instance
func NewServer(config Config) (*Server, error) {
if err := config.Validate(); err != nil {
return nil, fmt.Errorf("invalid configuration: %w", err)
}
shutdownCtx, cancelFunc := context.WithCancel(context.Background())
server := &Server{
config: config,
isRunning: false,
grpcRunning: false,
shutdownCtx: shutdownCtx,
cancelFunc: cancelFunc,
stats: &Stats{},
exposedServices: make(map[string]*ExposedServiceConfig),
}
proxy, err := reverseproxy.New(config.ReverseProxy)
if err != nil {
return nil, fmt.Errorf("failed to create reverse proxy: %w", err)
}
server.proxy = proxy
if config.ReverseProxy.ManagementURL == "" {
return nil, fmt.Errorf("management URL is required")
}
grpcClient := grpcpkg.NewClient(grpcpkg.ClientConfig{
ProxyID: config.ProxyID,
ManagementURL: config.ReverseProxy.ManagementURL,
ServiceUpdateHandler: server.handleServiceUpdate,
})
server.grpcClient = grpcClient
// Set request data callback to send access logs to management
proxy.SetRequestCallback(func(data reverseproxy.RequestData) {
accessLog := &pb.ProxyRequestData{
Timestamp: timestamppb.Now(),
ServiceId: data.ServiceID,
Host: data.Host,
Path: data.Path,
DurationMs: data.DurationMs,
Method: data.Method,
ResponseCode: data.ResponseCode,
SourceIp: data.SourceIP,
AuthMechanism: data.AuthMechanism,
UserId: data.UserID,
AuthSuccess: data.AuthSuccess,
}
server.grpcClient.SendAccessLog(accessLog)
})
return server, nil
}
// Start starts the reverse proxy server and optionally the gRPC control server
func (s *Server) Start() error {
s.mu.Lock()
if s.isRunning {
s.mu.Unlock()
return fmt.Errorf("server is already running")
}
s.isRunning = true
s.mu.Unlock()
log.Infof("Starting proxy reverse proxy server on %s", s.config.ReverseProxy.ListenAddress)
if err := s.proxy.Start(); err != nil {
s.mu.Lock()
s.isRunning = false
s.mu.Unlock()
return fmt.Errorf("failed to start reverse proxy: %w", err)
}
s.mu.Lock()
s.grpcRunning = true
s.mu.Unlock()
if err := s.grpcClient.Start(); err != nil {
s.mu.Lock()
s.isRunning = false
s.grpcRunning = false
s.mu.Unlock()
return fmt.Errorf("failed to start gRPC client: %w", err)
}
log.Info("Proxy started and connected to management")
log.Info("Waiting for service configurations from management...")
<-s.shutdownCtx.Done()
return nil
}
// Stop gracefully shuts down both proxy and gRPC servers
func (s *Server) Stop(ctx context.Context) error {
s.mu.Lock()
if !s.isRunning {
s.mu.Unlock()
return fmt.Errorf("server is not running")
}
s.mu.Unlock()
log.Info("Shutting down servers gracefully...")
// If no context provided, use the server's shutdown timeout
if ctx == nil {
var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(context.Background(), s.config.ShutdownTimeout)
defer cancel()
}
var proxyErr, grpcErr error
// Stop gRPC client first
if s.grpcRunning {
if err := s.grpcClient.Stop(); err != nil {
grpcErr = fmt.Errorf("gRPC client shutdown failed: %w", err)
log.Error(grpcErr)
}
s.mu.Lock()
s.grpcRunning = false
s.mu.Unlock()
}
// Shutdown reverse proxy
if err := s.proxy.Stop(ctx); err != nil {
proxyErr = fmt.Errorf("reverse proxy shutdown failed: %w", err)
log.Error(proxyErr)
}
s.mu.Lock()
s.isRunning = false
s.mu.Unlock()
if proxyErr != nil {
return proxyErr
}
if grpcErr != nil {
return grpcErr
}
log.Info("All servers stopped successfully")
return nil
}
// IsRunning returns whether the server is currently running
func (s *Server) IsRunning() bool {
s.mu.RLock()
defer s.mu.RUnlock()
return s.isRunning
}
// GetConfig returns a copy of the server configuration
func (s *Server) GetConfig() Config {
return s.config
}
// handleServiceUpdate processes service updates from management
func (s *Server) handleServiceUpdate(update *pb.ServiceUpdate) error {
log.WithFields(log.Fields{
"service_id": update.ServiceId,
"type": update.Type.String(),
}).Info("Received service update from management")
switch update.Type {
case pb.ServiceUpdate_CREATED:
if update.Service == nil {
return fmt.Errorf("service config is nil for CREATED update")
}
return s.addServiceFromProto(update.Service)
case pb.ServiceUpdate_UPDATED:
if update.Service == nil {
return fmt.Errorf("service config is nil for UPDATED update")
}
return s.updateServiceFromProto(update.Service)
case pb.ServiceUpdate_REMOVED:
return s.removeService(update.ServiceId)
default:
return fmt.Errorf("unknown service update type: %v", update.Type)
}
}
// addServiceFromProto adds a service from proto config
func (s *Server) addServiceFromProto(serviceConfig *pb.ExposedServiceConfig) error {
s.mu.Lock()
defer s.mu.Unlock()
// Check if service already exists
if _, exists := s.exposedServices[serviceConfig.Id]; exists {
log.Warnf("Service %s already exists, updating instead", serviceConfig.Id)
return s.updateServiceFromProtoLocked(serviceConfig)
}
log.WithFields(log.Fields{
"service_id": serviceConfig.Id,
"domain": serviceConfig.Domain,
}).Info("Adding service from management")
// Convert proto auth config to internal auth config
var authConfig *auth.Config
if serviceConfig.Auth != nil {
authConfig = convertProtoAuthConfig(serviceConfig.Auth)
}
// Add route to proxy
route := &reverseproxy.RouteConfig{
ID: serviceConfig.Id,
Domain: serviceConfig.Domain,
PathMappings: serviceConfig.PathMappings,
AuthConfig: authConfig,
SetupKey: serviceConfig.SetupKey,
}
if err := s.proxy.AddRoute(route); err != nil {
return fmt.Errorf("failed to add route: %w", err)
}
// Store service config (simplified, no peer config for now)
s.exposedServices[serviceConfig.Id] = &ExposedServiceConfig{
ServiceID: serviceConfig.Id,
UpstreamConfig: &UpstreamConfig{
Domain: serviceConfig.Domain,
PathMappings: serviceConfig.PathMappings,
},
}
log.Infof("Service %s added successfully", serviceConfig.Id)
return nil
}
// updateServiceFromProto updates an existing service from proto config
func (s *Server) updateServiceFromProto(serviceConfig *pb.ExposedServiceConfig) error {
s.mu.Lock()
defer s.mu.Unlock()
return s.updateServiceFromProtoLocked(serviceConfig)
}
func (s *Server) updateServiceFromProtoLocked(serviceConfig *pb.ExposedServiceConfig) error {
log.WithFields(log.Fields{
"service_id": serviceConfig.Id,
"domain": serviceConfig.Domain,
}).Info("Updating service from management")
// Convert proto auth config to internal auth config
var authConfig *auth.Config
if serviceConfig.Auth != nil {
authConfig = convertProtoAuthConfig(serviceConfig.Auth)
}
// Update route in proxy
route := &reverseproxy.RouteConfig{
ID: serviceConfig.Id,
Domain: serviceConfig.Domain,
PathMappings: serviceConfig.PathMappings,
AuthConfig: authConfig,
SetupKey: serviceConfig.SetupKey,
}
if err := s.proxy.UpdateRoute(route); err != nil {
return fmt.Errorf("failed to update route: %w", err)
}
// Update service config
s.exposedServices[serviceConfig.Id] = &ExposedServiceConfig{
ServiceID: serviceConfig.Id,
UpstreamConfig: &UpstreamConfig{
Domain: serviceConfig.Domain,
PathMappings: serviceConfig.PathMappings,
},
}
log.Infof("Service %s updated successfully", serviceConfig.Id)
return nil
}
// removeService removes a service
func (s *Server) removeService(serviceID string) error {
s.mu.Lock()
defer s.mu.Unlock()
log.WithFields(log.Fields{
"service_id": serviceID,
}).Info("Removing service from management")
// Remove route from proxy
if err := s.proxy.RemoveRoute(serviceID); err != nil {
return fmt.Errorf("failed to remove route: %w", err)
}
// Remove service config
delete(s.exposedServices, serviceID)
log.Infof("Service %s removed successfully", serviceID)
return nil
}
// convertProtoAuthConfig converts proto auth config to internal auth config
func convertProtoAuthConfig(protoAuth *pb.AuthConfig) *auth.Config {
authConfig := &auth.Config{}
switch authType := protoAuth.AuthType.(type) {
case *pb.AuthConfig_BasicAuth:
authConfig.BasicAuth = &methods.BasicAuthConfig{
Username: authType.BasicAuth.Username,
Password: authType.BasicAuth.Password,
}
case *pb.AuthConfig_PinAuth:
authConfig.PIN = &methods.PINConfig{
PIN: authType.PinAuth.Pin,
Header: authType.PinAuth.Header,
}
case *pb.AuthConfig_BearerAuth:
authConfig.Bearer = &methods.BearerConfig{
Enabled: authType.BearerAuth.Enabled,
}
}
return authConfig
}
// Exposed Service Handlers (deprecated - keeping for backwards compatibility)
// handleExposedServiceCreated handles the creation of a new exposed service
func (s *Server) handleExposedServiceCreated(serviceID string, peerConfig *PeerConfig, upstreamConfig *UpstreamConfig) error {
s.mu.Lock()
defer s.mu.Unlock()
// Check if service already exists
if _, exists := s.exposedServices[serviceID]; exists {
return fmt.Errorf("exposed service %s already exists", serviceID)
}
log.WithFields(log.Fields{
"service_id": serviceID,
"peer_id": peerConfig.PeerID,
"tunnel_ip": peerConfig.TunnelIP,
"domain": upstreamConfig.Domain,
}).Info("Creating exposed service")
// TODO: Create WireGuard tunnel for peer
// 1. Initialize WireGuard interface if not already done
// 2. Add peer configuration:
// - Public key: peerConfig.PublicKey
// - Endpoint: peerConfig.Endpoint
// - Allowed IPs: peerConfig.AllowedIPs
// - Persistent keepalive: 25 seconds
// 3. Bring up the WireGuard interface
// 4. Verify tunnel connectivity to peerConfig.TunnelIP
// Example pseudo-code:
// wgClient.AddPeer(&wireguard.PeerConfig{
// PublicKey: peerConfig.PublicKey,
// Endpoint: peerConfig.Endpoint,
// AllowedIPs: peerConfig.AllowedIPs,
// PersistentKeepalive: 25,
// })
// Build path mappings with tunnel IP
pathMappings := make(map[string]string)
for path, port := range upstreamConfig.PathMappings {
// Combine tunnel IP with port
target := fmt.Sprintf("%s:%s", peerConfig.TunnelIP, port)
pathMappings[path] = target
}
// Add route to proxy
route := &reverseproxy.RouteConfig{
ID: serviceID,
Domain: upstreamConfig.Domain,
PathMappings: pathMappings,
}
if err := s.proxy.AddRoute(route); err != nil {
return fmt.Errorf("failed to add route: %w", err)
}
// Store service config
s.exposedServices[serviceID] = &ExposedServiceConfig{
ServiceID: serviceID,
PeerConfig: peerConfig,
UpstreamConfig: upstreamConfig,
}
log.Infof("Exposed service %s created successfully", serviceID)
return nil
}
// handleExposedServiceUpdated handles updates to an existing exposed service
func (s *Server) handleExposedServiceUpdated(serviceID string, peerConfig *PeerConfig, upstreamConfig *UpstreamConfig) error {
s.mu.Lock()
defer s.mu.Unlock()
// Check if service exists
if _, exists := s.exposedServices[serviceID]; !exists {
return fmt.Errorf("exposed service %s not found", serviceID)
}
log.WithFields(log.Fields{
"service_id": serviceID,
"peer_id": peerConfig.PeerID,
"tunnel_ip": peerConfig.TunnelIP,
"domain": upstreamConfig.Domain,
}).Info("Updating exposed service")
// TODO: Update WireGuard tunnel if peer config changed
// Build path mappings with tunnel IP
pathMappings := make(map[string]string)
for path, port := range upstreamConfig.PathMappings {
target := fmt.Sprintf("%s:%s", peerConfig.TunnelIP, port)
pathMappings[path] = target
}
// Update route in proxy
route := &reverseproxy.RouteConfig{
ID: serviceID,
Domain: upstreamConfig.Domain,
PathMappings: pathMappings,
}
if err := s.proxy.UpdateRoute(route); err != nil {
return fmt.Errorf("failed to update route: %w", err)
}
// Update service config
s.exposedServices[serviceID] = &ExposedServiceConfig{
ServiceID: serviceID,
PeerConfig: peerConfig,
UpstreamConfig: upstreamConfig,
}
log.Infof("Exposed service %s updated successfully", serviceID)
return nil
}
// handleExposedServiceRemoved handles the removal of an exposed service
func (s *Server) handleExposedServiceRemoved(serviceID string) error {
s.mu.Lock()
defer s.mu.Unlock()
// Check if service exists
if _, exists := s.exposedServices[serviceID]; !exists {
return fmt.Errorf("exposed service %s not found", serviceID)
}
log.WithFields(log.Fields{
"service_id": serviceID,
}).Info("Removing exposed service")
// Remove route from proxy
if err := s.proxy.RemoveRoute(serviceID); err != nil {
return fmt.Errorf("failed to remove route: %w", err)
}
// TODO: Remove WireGuard tunnel for peer
// Remove service config
delete(s.exposedServices, serviceID)
log.Infof("Exposed service %s removed successfully", serviceID)
return nil
}
// ListExposedServices returns a list of all exposed service IDs
func (s *Server) ListExposedServices() []string {
s.mu.RLock()
defer s.mu.RUnlock()
services := make([]string, 0, len(s.exposedServices))
for id := range s.exposedServices {
services = append(services, id)
}
return services
}
// GetExposedService returns the configuration for a specific exposed service
func (s *Server) GetExposedService(serviceID string) (*ExposedServiceConfig, error) {
s.mu.RLock()
defer s.mu.RUnlock()
service, exists := s.exposedServices[serviceID]
if !exists {
return nil, fmt.Errorf("exposed service %s not found", serviceID)
}
return service, nil
}
// Stats methods
func (s *Stats) IncrementRequests() {
s.mu.Lock()
defer s.mu.Unlock()
s.totalRequests++
}
func (s *Stats) IncrementActiveConns() {
s.mu.Lock()
defer s.mu.Unlock()
s.activeConns++
}
func (s *Stats) DecrementActiveConns() {
s.mu.Lock()
defer s.mu.Unlock()
if s.activeConns > 0 {
s.activeConns--
}
}
func (s *Stats) AddBytesSent(bytes uint64) {
s.mu.Lock()
defer s.mu.Unlock()
s.bytesSent += bytes
}
func (s *Stats) AddBytesReceived(bytes uint64) {
s.mu.Lock()
defer s.mu.Unlock()
s.bytesReceived += bytes
}

View File

@@ -1,56 +0,0 @@
package version
import (
"fmt"
"runtime"
)
var (
// Version is the application version (set via ldflags during build)
Version = "dev"
// Commit is the git commit hash (set via ldflags during build)
Commit = "unknown"
// BuildDate is the build date (set via ldflags during build)
BuildDate = "unknown"
// GoVersion is the Go version used to build the binary
GoVersion = runtime.Version()
)
// Info contains version information
type Info struct {
Version string `json:"version"`
Commit string `json:"commit"`
BuildDate string `json:"build_date"`
GoVersion string `json:"NewSingleHostReverseProxygo_version"`
OS string `json:"os"`
Arch string `json:"arch"`
}
// Get returns the version information
func Get() Info {
return Info{
Version: Version,
Commit: Commit,
BuildDate: BuildDate,
GoVersion: GoVersion,
OS: runtime.GOOS,
Arch: runtime.GOARCH,
}
}
// String returns a formatted version string
func String() string {
return fmt.Sprintf("Version: %s, Commit: %s, BuildDate: %s, Go: %s",
Version, Commit, BuildDate, GoVersion)
}
// Short returns a short version string
func Short() string {
if Version == "dev" {
return fmt.Sprintf("%s (%s)", Version, Commit[:7])
}
return Version
}

View File

@@ -1,31 +0,0 @@
#!/bin/bash
set -e
# Check if protoc is installed
if ! command -v protoc &> /dev/null; then
echo "Error: protoc is not installed"
echo "Install with: apt-get install -y protobuf-compiler"
exit 1
fi
# Check if protoc-gen-go is installed
if ! command -v protoc-gen-go &> /dev/null; then
echo "Installing protoc-gen-go..."
go install google.golang.org/protobuf/cmd/protoc-gen-go@latest
fi
# Check if protoc-gen-go-grpc is installed
if ! command -v protoc-gen-go-grpc &> /dev/null; then
echo "Installing protoc-gen-go-grpc..."
go install google.golang.org/grpc/cmd/protoc-gen-go-grpc@latest
fi
echo "Generating protobuf files..."
# Generate Go code from proto files
protoc --go_out=. --go_opt=paths=source_relative \
--go-grpc_out=. --go-grpc_opt=paths=source_relative \
pkg/grpc/proto/proxy.proto
echo "Proto generation complete!"

277
proxy/server.go Normal file
View File

@@ -0,0 +1,277 @@
// Package proxy runs a NetBird proxy server.
// It attempts to do everything it needs to do within the context
// of a single request to the server to try to reduce the amount
// of concurrency coordination that is required. However, it does
// run two additional routines in an error group for handling
// updates from the management server and running a separate
// HTTP server to handle ACME HTTP-01 challenges (if configured).
package proxy
import (
"context"
"crypto/tls"
"errors"
"fmt"
"io"
"log/slog"
"net/http"
"net/url"
"path/filepath"
"time"
"github.com/cloudflare/backoff"
"github.com/netbirdio/netbird/proxy/internal/accesslog"
"github.com/netbirdio/netbird/proxy/internal/acme"
"github.com/netbirdio/netbird/proxy/internal/auth"
"github.com/netbirdio/netbird/proxy/internal/proxy"
"github.com/netbirdio/netbird/proxy/internal/roundtrip"
"github.com/netbirdio/netbird/shared/management/proto"
"google.golang.org/grpc"
)
type errorLog interface {
Error(msg string, args ...any)
ErrorContext(ctx context.Context, msg string, args ...any)
}
type Server struct {
mgmtConn *grpc.ClientConn
proxy *proxy.ReverseProxy
netbird *roundtrip.NetBird
acme *acme.Manager
auth *auth.Middleware
http *http.Server
https *http.Server
ErrorLog errorLog
ManagementAddress string
CertificateDirectory string
GenerateACMECertificates bool
ACMEChallengeAddress string
ACMEDirectory string
}
func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) {
if s.ErrorLog == nil {
// If no ErrorLog is specified, then just discard the log output.
s.ErrorLog = slog.New(slog.DiscardHandler)
}
// The very first thing to do should be to connect to the Management server.
// Without this connection, the Proxy cannot do anything.
s.mgmtConn, err = grpc.NewClient(s.ManagementAddress)
if err != nil {
return fmt.Errorf("could not create management connection: %w", err)
}
mgmtClient := proto.NewProxyServiceClient(s.mgmtConn)
go s.newManagementMappingWorker(ctx, mgmtClient)
// Initialize the netbird client, this is required to build peer connections
// to proxy over.
s.netbird = roundtrip.NewNetBird(s.ManagementAddress)
// When generating ACME certificates, start a challenge server.
tlsConfig := &tls.Config{}
if s.GenerateACMECertificates {
s.acme = acme.NewManager(s.CertificateDirectory, s.ACMEDirectory)
s.http = &http.Server{
Addr: s.ACMEChallengeAddress,
Handler: s.acme.HTTPHandler(nil),
}
go func() {
if err := s.http.ListenAndServe(); err != nil {
// Rather than retry, log the issue periodically so that hopefully someone notices and fixes the issue.
for range time.Tick(10 * time.Second) {
s.ErrorLog.ErrorContext(ctx, "ACME HTTP-01 challenge server error", "error", err)
}
}
}()
tlsConfig = s.acme.TLSConfig()
} else {
// Otherwise pull some certificates from expected locations.
cert, err := tls.LoadX509KeyPair(
filepath.Join(s.CertificateDirectory, "tls.crt"),
filepath.Join(s.CertificateDirectory, "tls.key"),
)
if err != nil {
return fmt.Errorf("load provided certificate: %w", err)
}
tlsConfig.Certificates = append(tlsConfig.Certificates, cert)
}
// Configure the reverse proxy using NetBird's HTTP Client Transport for proxying.
s.proxy = proxy.NewReverseProxy(s.netbird)
// Configure the authentication middleware.
s.auth = auth.NewMiddleware()
// Configure Access logs to management server.
accessLog := accesslog.NewLogger(mgmtClient, s.ErrorLog)
// Finally, start the reverse proxy.
s.https = &http.Server{
Addr: addr,
Handler: s.auth.Protect(accessLog.Middleware(s.proxy)),
TLSConfig: tlsConfig,
}
return s.https.ListenAndServeTLS("", "")
}
func (s *Server) newManagementMappingWorker(ctx context.Context, client proto.ProxyServiceClient) func() {
b := backoff.New(0, 0)
return func() {
for {
mappingClient, err := client.GetMappingUpdate(ctx, &proto.GetMappingUpdateRequest{})
if err != nil {
backoffDuration := b.Duration()
s.ErrorLog.ErrorContext(ctx, "Unable to create mapping client to management server, retrying connection after backoff.",
"backoff", backoffDuration,
"error", err)
time.Sleep(backoffDuration)
continue
}
err = s.handleMappingStream(ctx, mappingClient)
backoffDuration := b.Duration()
switch {
case errors.Is(err, context.Canceled),
errors.Is(err, context.DeadlineExceeded):
// Context is telling us that it is time to quit so gracefully exit here.
// No need to log the error as it is a parent context causing this return.
return
case err != nil:
// Log the error and then retry the connection.
s.ErrorLog.ErrorContext(ctx, "Error processing mapping stream from management server, retrying connection after backoff.",
"backoff", backoffDuration,
"error", err)
default:
// TODO: should this really be at error level? Maybe, if you start getting lots of these this could be an indication of connectivity issues.
s.ErrorLog.ErrorContext(ctx, "Management mapping connection terminated by the server, retrying connection after backoff.",
"backoff", backoffDuration)
}
time.Sleep(backoffDuration)
}
}
}
func (s *Server) handleMappingStream(ctx context.Context, mappingClient proto.ProxyService_GetMappingUpdateClient) error {
for {
// Check for context completion to gracefully shutdown.
select {
case <-ctx.Done():
// Shutting down.
return ctx.Err()
default:
msg, err := mappingClient.Recv()
switch {
case errors.Is(err, io.EOF):
// Mapping connection gracefully terminated by server.
return nil
case err != nil:
// Something has gone horribly wrong, return and hope the parent retries the connection.
return fmt.Errorf("receive msg: %w", err)
}
// Process msg updates sequentially to avoid conflict, so block
// additional receiving until this processing is completed.
for _, mapping := range msg.GetMapping() {
switch mapping.GetType() {
case proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED:
if err := s.addMapping(ctx, mapping); err != nil {
// TODO: Retry this? Or maybe notify the management server that this mapping has failed?
s.ErrorLog.ErrorContext(ctx, "Error adding new mapping, ignoring this mapping and continuing processing.",
"service_id", mapping.GetId(),
"domain", mapping.GetDomain(),
"error", err)
}
case proto.ProxyMappingUpdateType_UPDATE_TYPE_MODIFIED:
s.updateMapping(ctx, mapping)
case proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED:
s.removeMapping(mapping)
}
}
}
}
}
func (s *Server) addMapping(ctx context.Context, mapping *proto.ProxyMapping) error {
if err := s.netbird.AddPeer(mapping.GetDomain(), mapping.GetSetupKey()); err != nil {
return fmt.Errorf("create peer for domain %q: %w", mapping.GetDomain(), err)
}
if s.acme != nil {
s.acme.AddDomain(mapping.GetDomain())
}
// Pass the mapping through to the update function to avoid duplicating the
// setup, currently update is simply a subset of this function, so this
// separation makes sense...to me at least.
s.updateMapping(ctx, mapping)
return nil
}
func (s *Server) updateMapping(ctx context.Context, mapping *proto.ProxyMapping) {
// Very simple implementation here, we don't touch the existing peer
// connection or any existing TLS configuration, we simply overwrite
// the auth and proxy mappings.
// Note: this does require the management server to always send a
// full mapping rather than deltas during a modification.
var schemes []auth.Scheme
if mapping.GetAuth().GetBasic().GetEnabled() {
schemes = append(schemes, auth.NewBasicAuth(
mapping.GetAuth().GetBasic().GetUsername(),
mapping.GetAuth().GetBasic().GetPassword(),
))
}
if mapping.GetAuth().GetPin().GetEnabled() {
schemes = append(schemes, auth.NewPin(
mapping.GetAuth().GetPin().GetPin(),
))
}
if mapping.GetAuth().GetOidc().GetEnabled() {
oidc := mapping.GetAuth().GetOidc()
scheme, err := auth.NewOIDC(ctx, auth.OIDCConfig{
OIDCProviderURL: oidc.GetOidcProviderUrl(),
OIDCClientID: oidc.GetOidcClientId(),
OIDCClientSecret: oidc.GetOidcClientSecret(),
OIDCRedirectURL: oidc.GetOidcRedirectUrl(),
OIDCScopes: oidc.GetOidcScopes(),
})
if err != nil {
s.ErrorLog.Error("Failed to create OIDC scheme", "error", err)
} else {
schemes = append(schemes, scheme)
}
}
s.auth.AddDomain(mapping.GetDomain(), schemes)
s.proxy.AddMapping(s.protoToMapping(mapping))
}
func (s *Server) removeMapping(mapping *proto.ProxyMapping) {
s.netbird.RemovePeer(mapping.GetDomain())
if s.acme != nil {
s.acme.RemoveDomain(mapping.GetDomain())
}
s.auth.RemoveDomain(mapping.GetDomain())
s.proxy.RemoveMapping(s.protoToMapping(mapping))
}
func (s *Server) protoToMapping(mapping *proto.ProxyMapping) proxy.Mapping {
paths := make(map[string]*url.URL)
for _, pathMapping := range mapping.GetPath() {
targetURL, err := url.Parse(pathMapping.GetTarget())
if err != nil {
// TODO: Should we warn management about this so it can be bubbled up to a user to reconfigure?
s.ErrorLog.Error("Error parsing target URL for path, this path will be ignored but other paths will still be configured.",
"service_id", mapping.GetId(),
"domain", mapping.GetDomain(),
"path", pathMapping.GetPath(),
"target", pathMapping.GetTarget(),
"error", err)
}
paths[pathMapping.GetPath()] = targetURL
}
return proxy.Mapping{
ID: mapping.GetId(),
Host: mapping.GetDomain(),
Paths: paths,
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -9,27 +9,81 @@ import "google/protobuf/timestamp.proto";
// ProxyService - Management is the SERVER, Proxy is the CLIENT
// Proxy initiates connection to management
service ProxyService {
// Bidirectional stream for proxy-management communication
rpc Stream(stream ProxyMessage) returns (stream ManagementMessage);
rpc GetMappingUpdate(GetMappingUpdateRequest) returns (stream GetMappingUpdateResponse);
rpc SendAccessLog(SendAccessLogRequest) returns (SendAccessLogResponse);
}
// Messages FROM Proxy TO Management
message ProxyMessage {
oneof payload {
ProxyHello hello = 1; // First message on connect
ProxyRequestData request_data = 2; // Real-time access logs
}
}
// Proxy identification on connect
message ProxyHello {
// GetMappingUpdateRequest is sent to initialise a mapping stream.
message GetMappingUpdateRequest {
string proxy_id = 1;
string version = 2;
google.protobuf.Timestamp started_at = 3;
}
// Access log from proxy to management
message ProxyRequestData {
// GetMappingUpdateResponse contains zero or more ProxyMappings.
// No mappings may be sent to test the liveness of the Proxy.
// Mappings that are sent should be interpreted by the Proxy appropriately.
message GetMappingUpdateResponse {
repeated ProxyMapping mapping = 1;
}
enum ProxyMappingUpdateType {
UPDATE_TYPE_CREATED = 0;
UPDATE_TYPE_MODIFIED = 1;
UPDATE_TYPE_REMOVED = 2;
}
message PathMapping {
string path = 1;
string target = 2;
}
message Authentication {
HTTPBasic basic = 1;
Pin pin = 2;
OIDC oidc = 3;
}
message HTTPBasic {
bool enabled = 1;
string username = 2;
string password = 3;
}
message Pin {
bool enabled = 1;
string pin = 2;
}
message OIDC {
bool enabled = 1;
string oidc_provider_url = 2;
string oidc_client_id = 3;
string oidc_client_secret = 4;
string oidc_redirect_url = 5;
repeated string oidc_scopes = 6;
string session_cookie_name = 7;
}
message ProxyMapping {
ProxyMappingUpdateType type = 1;
string id = 2;
string domain = 3;
repeated PathMapping path = 4;
string setup_key = 5;
Authentication auth = 6;
}
// SendAccessLogRequest consists of one or more AccessLogs from a Proxy.
message SendAccessLogRequest {
AccessLog log = 1;
}
// SendAccessLogResponse is intentionally empty to allow for future expansion.
message SendAccessLogResponse {}
message AccessLog {
google.protobuf.Timestamp timestamp = 1;
string service_id = 2;
string host = 3;
@@ -42,61 +96,3 @@ message ProxyRequestData {
string user_id = 10;
bool auth_success = 11;
}
// Messages FROM Management TO Proxy
message ManagementMessage {
oneof payload {
ServicesSnapshot snapshot = 1; // Full snapshot on initial connect
ServiceUpdate update = 2; // Incremental service update
}
}
// Full snapshot of all services for this proxy
message ServicesSnapshot {
repeated ExposedServiceConfig services = 1;
google.protobuf.Timestamp timestamp = 2;
}
// Incremental service update
message ServiceUpdate {
enum UpdateType {
CREATED = 0;
UPDATED = 1;
REMOVED = 2;
}
UpdateType type = 1;
ExposedServiceConfig service = 2; // Set for CREATED and UPDATED
string service_id = 3; // Service ID (always set)
}
// Exposed service configuration
message ExposedServiceConfig {
string id = 1;
string domain = 2;
map<string, string> path_mappings = 3; // path -> target
string setup_key = 4;
AuthConfig auth = 5;
}
// Authentication configuration
message AuthConfig {
oneof auth_type {
BasicAuthConfig basic_auth = 1;
PinAuthConfig pin_auth = 2;
BearerAuthConfig bearer_auth = 3;
}
}
message BasicAuthConfig {
string username = 1;
string password = 2;
}
message PinAuthConfig {
string pin = 1;
string header = 2;
}
message BearerAuthConfig {
bool enabled = 1;
}

View File

@@ -18,8 +18,8 @@ const _ = grpc.SupportPackageIsVersion7
//
// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream.
type ProxyServiceClient interface {
// Bidirectional stream for proxy-management communication
Stream(ctx context.Context, opts ...grpc.CallOption) (ProxyService_StreamClient, error)
GetMappingUpdate(ctx context.Context, in *GetMappingUpdateRequest, opts ...grpc.CallOption) (ProxyService_GetMappingUpdateClient, error)
SendAccessLog(ctx context.Context, in *SendAccessLogRequest, opts ...grpc.CallOption) (*SendAccessLogResponse, error)
}
type proxyServiceClient struct {
@@ -30,43 +30,53 @@ func NewProxyServiceClient(cc grpc.ClientConnInterface) ProxyServiceClient {
return &proxyServiceClient{cc}
}
func (c *proxyServiceClient) Stream(ctx context.Context, opts ...grpc.CallOption) (ProxyService_StreamClient, error) {
stream, err := c.cc.NewStream(ctx, &ProxyService_ServiceDesc.Streams[0], "/management.ProxyService/Stream", opts...)
func (c *proxyServiceClient) GetMappingUpdate(ctx context.Context, in *GetMappingUpdateRequest, opts ...grpc.CallOption) (ProxyService_GetMappingUpdateClient, error) {
stream, err := c.cc.NewStream(ctx, &ProxyService_ServiceDesc.Streams[0], "/management.ProxyService/GetMappingUpdate", opts...)
if err != nil {
return nil, err
}
x := &proxyServiceStreamClient{stream}
x := &proxyServiceGetMappingUpdateClient{stream}
if err := x.ClientStream.SendMsg(in); err != nil {
return nil, err
}
if err := x.ClientStream.CloseSend(); err != nil {
return nil, err
}
return x, nil
}
type ProxyService_StreamClient interface {
Send(*ProxyMessage) error
Recv() (*ManagementMessage, error)
type ProxyService_GetMappingUpdateClient interface {
Recv() (*GetMappingUpdateResponse, error)
grpc.ClientStream
}
type proxyServiceStreamClient struct {
type proxyServiceGetMappingUpdateClient struct {
grpc.ClientStream
}
func (x *proxyServiceStreamClient) Send(m *ProxyMessage) error {
return x.ClientStream.SendMsg(m)
}
func (x *proxyServiceStreamClient) Recv() (*ManagementMessage, error) {
m := new(ManagementMessage)
func (x *proxyServiceGetMappingUpdateClient) Recv() (*GetMappingUpdateResponse, error) {
m := new(GetMappingUpdateResponse)
if err := x.ClientStream.RecvMsg(m); err != nil {
return nil, err
}
return m, nil
}
func (c *proxyServiceClient) SendAccessLog(ctx context.Context, in *SendAccessLogRequest, opts ...grpc.CallOption) (*SendAccessLogResponse, error) {
out := new(SendAccessLogResponse)
err := c.cc.Invoke(ctx, "/management.ProxyService/SendAccessLog", in, out, opts...)
if err != nil {
return nil, err
}
return out, nil
}
// ProxyServiceServer is the server API for ProxyService service.
// All implementations must embed UnimplementedProxyServiceServer
// for forward compatibility
type ProxyServiceServer interface {
// Bidirectional stream for proxy-management communication
Stream(ProxyService_StreamServer) error
GetMappingUpdate(*GetMappingUpdateRequest, ProxyService_GetMappingUpdateServer) error
SendAccessLog(context.Context, *SendAccessLogRequest) (*SendAccessLogResponse, error)
mustEmbedUnimplementedProxyServiceServer()
}
@@ -74,8 +84,11 @@ type ProxyServiceServer interface {
type UnimplementedProxyServiceServer struct {
}
func (UnimplementedProxyServiceServer) Stream(ProxyService_StreamServer) error {
return status.Errorf(codes.Unimplemented, "method Stream not implemented")
func (UnimplementedProxyServiceServer) GetMappingUpdate(*GetMappingUpdateRequest, ProxyService_GetMappingUpdateServer) error {
return status.Errorf(codes.Unimplemented, "method GetMappingUpdate not implemented")
}
func (UnimplementedProxyServiceServer) SendAccessLog(context.Context, *SendAccessLogRequest) (*SendAccessLogResponse, error) {
return nil, status.Errorf(codes.Unimplemented, "method SendAccessLog not implemented")
}
func (UnimplementedProxyServiceServer) mustEmbedUnimplementedProxyServiceServer() {}
@@ -90,30 +103,43 @@ func RegisterProxyServiceServer(s grpc.ServiceRegistrar, srv ProxyServiceServer)
s.RegisterService(&ProxyService_ServiceDesc, srv)
}
func _ProxyService_Stream_Handler(srv interface{}, stream grpc.ServerStream) error {
return srv.(ProxyServiceServer).Stream(&proxyServiceStreamServer{stream})
func _ProxyService_GetMappingUpdate_Handler(srv interface{}, stream grpc.ServerStream) error {
m := new(GetMappingUpdateRequest)
if err := stream.RecvMsg(m); err != nil {
return err
}
return srv.(ProxyServiceServer).GetMappingUpdate(m, &proxyServiceGetMappingUpdateServer{stream})
}
type ProxyService_StreamServer interface {
Send(*ManagementMessage) error
Recv() (*ProxyMessage, error)
type ProxyService_GetMappingUpdateServer interface {
Send(*GetMappingUpdateResponse) error
grpc.ServerStream
}
type proxyServiceStreamServer struct {
type proxyServiceGetMappingUpdateServer struct {
grpc.ServerStream
}
func (x *proxyServiceStreamServer) Send(m *ManagementMessage) error {
func (x *proxyServiceGetMappingUpdateServer) Send(m *GetMappingUpdateResponse) error {
return x.ServerStream.SendMsg(m)
}
func (x *proxyServiceStreamServer) Recv() (*ProxyMessage, error) {
m := new(ProxyMessage)
if err := x.ServerStream.RecvMsg(m); err != nil {
func _ProxyService_SendAccessLog_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(SendAccessLogRequest)
if err := dec(in); err != nil {
return nil, err
}
return m, nil
if interceptor == nil {
return srv.(ProxyServiceServer).SendAccessLog(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/management.ProxyService/SendAccessLog",
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(ProxyServiceServer).SendAccessLog(ctx, req.(*SendAccessLogRequest))
}
return interceptor(ctx, in, info, handler)
}
// ProxyService_ServiceDesc is the grpc.ServiceDesc for ProxyService service.
@@ -122,13 +148,17 @@ func (x *proxyServiceStreamServer) Recv() (*ProxyMessage, error) {
var ProxyService_ServiceDesc = grpc.ServiceDesc{
ServiceName: "management.ProxyService",
HandlerType: (*ProxyServiceServer)(nil),
Methods: []grpc.MethodDesc{},
Methods: []grpc.MethodDesc{
{
MethodName: "SendAccessLog",
Handler: _ProxyService_SendAccessLog_Handler,
},
},
Streams: []grpc.StreamDesc{
{
StreamName: "Stream",
Handler: _ProxyService_Stream_Handler,
StreamName: "GetMappingUpdate",
Handler: _ProxyService_GetMappingUpdate_Handler,
ServerStreams: true,
ClientStreams: true,
},
},
Metadata: "proxy_service.proto",