mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-16 07:16:38 +00:00
refactor layout and structure
This commit is contained in:
1
go.mod
1
go.mod
@@ -40,6 +40,7 @@ require (
|
||||
github.com/c-robinson/iplib v1.0.3
|
||||
github.com/caddyserver/certmagic v0.21.3
|
||||
github.com/cilium/ebpf v0.15.0
|
||||
github.com/cloudflare/backoff v0.0.0-20240920015135-e46b80a3a7d0
|
||||
github.com/coder/websocket v1.8.13
|
||||
github.com/coreos/go-iptables v0.7.0
|
||||
github.com/creack/pty v1.1.24
|
||||
|
||||
2
go.sum
2
go.sum
@@ -103,6 +103,8 @@ github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UF
|
||||
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
|
||||
github.com/cilium/ebpf v0.15.0 h1:7NxJhNiBT3NG8pZJ3c+yfrVdHY8ScgKD27sScgjLMMk=
|
||||
github.com/cilium/ebpf v0.15.0/go.mod h1:DHp1WyrLeiBh19Cf/tfiSMhqheEiK8fXFZ4No0P1Hso=
|
||||
github.com/cloudflare/backoff v0.0.0-20240920015135-e46b80a3a7d0 h1:pRcxfaAlK0vR6nOeQs7eAEvjJzdGXl8+KaBlcvpQTyQ=
|
||||
github.com/cloudflare/backoff v0.0.0-20240920015135-e46b80a3a7d0/go.mod h1:rzgs2ZOiguV6/NpiDgADjRLPNyZlApIWxKpkT+X8SdY=
|
||||
github.com/coder/websocket v1.8.13 h1:f3QZdXy7uGVz+4uCJy2nTZyM0yTBj8yANEHhqlXZ9FE=
|
||||
github.com/coder/websocket v1.8.13/go.mod h1:LNVeNrXQZfe5qhS9ALED3uA+l5pPqvwXg3CKoDBB2gs=
|
||||
github.com/containerd/containerd v1.7.29 h1:90fWABQsaN9mJhGkoVnuzEY+o1XDPbg9BTC9QTAHnuE=
|
||||
|
||||
27
proxy/.gitignore
vendored
27
proxy/.gitignore
vendored
@@ -1,27 +0,0 @@
|
||||
# Binaries
|
||||
bin/
|
||||
*.exe
|
||||
*.exe~
|
||||
*.dll
|
||||
*.so
|
||||
*.dylib
|
||||
|
||||
# Test binary
|
||||
*.test
|
||||
|
||||
# Output of go coverage tool
|
||||
*.out
|
||||
|
||||
# Configuration files (keep example)
|
||||
config.json
|
||||
|
||||
# IDE files
|
||||
.vscode/
|
||||
.idea/
|
||||
*.swp
|
||||
*.swo
|
||||
*~
|
||||
|
||||
# OS files
|
||||
.DS_Store
|
||||
Thumbs.db
|
||||
@@ -1,83 +0,0 @@
|
||||
.PHONY: build clean run test help version proto
|
||||
|
||||
# Build variables
|
||||
BINARY_NAME=proxy
|
||||
BUILD_DIR=bin
|
||||
|
||||
# Version variables (can be overridden)
|
||||
VERSION ?= $(shell git describe --tags --always --dirty 2>/dev/null || echo "dev")
|
||||
COMMIT ?= $(shell git rev-parse --short HEAD 2>/dev/null || echo "unknown")
|
||||
BUILD_DATE ?= $(shell date -u '+%Y-%m-%d_%H:%M:%S')
|
||||
|
||||
# Go linker flags for version injection
|
||||
LDFLAGS=-ldflags "-X github.com/netbirdio/netbird/proxy/pkg/version.Version=$(VERSION) \
|
||||
-X github.com/netbirdio/netbird/proxy/pkg/version.Commit=$(COMMIT) \
|
||||
-X github.com/netbirdio/netbird/proxy/pkg/version.BuildDate=$(BUILD_DATE)"
|
||||
|
||||
# Build the binary
|
||||
build:
|
||||
@echo "Building $(BINARY_NAME)..."
|
||||
@echo "Version: $(VERSION)"
|
||||
@echo "Commit: $(COMMIT)"
|
||||
@echo "BuildDate: $(BUILD_DATE)"
|
||||
@mkdir -p $(BUILD_DIR)
|
||||
GOWORK=off go build $(LDFLAGS) -o $(BUILD_DIR)/$(BINARY_NAME) .
|
||||
@echo "Build complete: $(BUILD_DIR)/$(BINARY_NAME)"
|
||||
|
||||
# Show version information
|
||||
version:
|
||||
@echo "Version: $(VERSION)"
|
||||
@echo "Commit: $(COMMIT)"
|
||||
@echo "BuildDate: $(BUILD_DATE)"
|
||||
|
||||
# Clean build artifacts
|
||||
clean:
|
||||
@echo "Cleaning..."
|
||||
@rm -rf $(BUILD_DIR)
|
||||
@go clean
|
||||
@echo "Clean complete"
|
||||
|
||||
# Run the application (requires NB_PROXY_TARGET_URL to be set)
|
||||
run: build
|
||||
@./$(BUILD_DIR)/$(BINARY_NAME)
|
||||
|
||||
# Run tests
|
||||
test:
|
||||
GOWORK=off go test -v ./...
|
||||
|
||||
# Install dependencies
|
||||
deps:
|
||||
@echo "Installing dependencies..."
|
||||
GOWORK=off go mod download
|
||||
GOWORK=off go mod tidy
|
||||
@echo "Dependencies installed"
|
||||
|
||||
# Format code
|
||||
fmt:
|
||||
@echo "Formatting code..."
|
||||
@go fmt ./...
|
||||
@echo "Format complete"
|
||||
|
||||
# Lint code
|
||||
lint:
|
||||
@echo "Linting code..."
|
||||
@golangci-lint run
|
||||
@echo "Lint complete"
|
||||
|
||||
# Generate protobuf files
|
||||
proto:
|
||||
@echo "Generating protobuf files..."
|
||||
@./scripts/generate-proto.sh
|
||||
|
||||
# Show help
|
||||
help:
|
||||
@echo "Available targets:"
|
||||
@echo " build - Build the binary"
|
||||
@echo " clean - Remove build artifacts"
|
||||
@echo " run - Build and run the application"
|
||||
@echo " test - Run tests"
|
||||
@echo " proto - Generate protobuf files"
|
||||
@echo " deps - Install dependencies"
|
||||
@echo " fmt - Format code"
|
||||
@echo " lint - Lint code"
|
||||
@echo " help - Show this help message"
|
||||
207
proxy/README.md
207
proxy/README.md
@@ -1,177 +1,50 @@
|
||||
# Netbird Reverse Proxy
|
||||
|
||||
A lightweight, configurable reverse proxy server with graceful shutdown support.
|
||||
The NetBird Reverse Proxy is a separate service that can act as a public entrypoint to certain resources within a NetBird network.
|
||||
At a high level, the way that it operates is:
|
||||
- Configured routes are communicated from the Management server to the proxy.
|
||||
- For each route the proxy creates a NetBird connection to the NetBird Peer that hosts the resource.
|
||||
- When traffic hits the proxy at the address and path configured for the proxied resource, the NetBird Proxy brings up a relevant authentication method for that resource.
|
||||
- On successful authentication the proxy will forward traffic onwards to the NetBird Peer.
|
||||
|
||||
## Features
|
||||
Proxy Authentication methods supported are:
|
||||
- No authentication
|
||||
- Oauth2/OIDC
|
||||
- Emailed Magic Link
|
||||
- Simple PIN
|
||||
- HTTP Basic Auth Username and Password
|
||||
|
||||
- Simple reverse proxy with customizable headers
|
||||
- Configuration via environment variables or JSON file
|
||||
- Graceful shutdown with configurable timeout
|
||||
- Structured logging with logrus
|
||||
- Configurable timeouts (read, write, idle)
|
||||
- Health monitoring support
|
||||
## Management Connection
|
||||
|
||||
## Building
|
||||
The Proxy communicates with the Management server over a gRPC connection.
|
||||
Proxies act as clients to the Management server, the following RPCs are used:
|
||||
- Server-side streaming for proxied service updates.
|
||||
- Client-side streaming for proxy logs.
|
||||
|
||||
```bash
|
||||
# Build the binary
|
||||
GOWORK=off go build -o bin/proxy ./cmd/proxy
|
||||
## Authentication
|
||||
|
||||
# Or use make if available
|
||||
make build
|
||||
```
|
||||
When a request hits the Proxy, it looks up the permitted authentication methods for the Host domain.
|
||||
If no authentication methods are registered for the Host domain, then no authentication will be applied (for fully public resources).
|
||||
If any authentication methods are registered for the Host domain, then the Proxy will first serve an authentication page allowing the user to select an authentication method (from the permitted methods) and enter the required information for that authentication method.
|
||||
If the user is successfully authenticated, their request will be forwarded through to the Proxy to be proxied to the relevant Peer.
|
||||
Successful authentication does not guarantee a successful forwarding of the request as there may be failures behind the Proxy, such as with Peer connectivity or the underlying resource.
|
||||
|
||||
## TLS
|
||||
|
||||
Due to the authentication provided, the Proxy uses HTTPS for its endpoint, even if the underlying service is HTTP.
|
||||
Certificate generation can either be via ACME (by default, using Let's Encrypt, but alternative ACME providers can be used) or through certificate files.
|
||||
When not using ACME, the proxy server attempts to load a certificate and key from the files `tls.crt` and `tls.key` in a specified certificate directory.
|
||||
When using ACME, the proxy server will store generated certificates in the specified certificate directory.
|
||||
|
||||
## Configuration
|
||||
|
||||
The proxy can be configured using either environment variables or a JSON configuration file. Environment variables take precedence over file-based configuration.
|
||||
|
||||
### Environment Variables
|
||||
|
||||
| Variable | Description | Default |
|
||||
|----------|-------------|---------|
|
||||
| `NB_PROXY_LISTEN_ADDRESS` | Address to listen on | `:8080` |
|
||||
| `NB_PROXY_TARGET_URL` | Target URL to proxy requests to | **(required)** |
|
||||
| `NB_PROXY_READ_TIMEOUT` | Read timeout duration | `30s` |
|
||||
| `NB_PROXY_WRITE_TIMEOUT` | Write timeout duration | `30s` |
|
||||
| `NB_PROXY_IDLE_TIMEOUT` | Idle timeout duration | `60s` |
|
||||
| `NB_PROXY_SHUTDOWN_TIMEOUT` | Graceful shutdown timeout | `10s` |
|
||||
| `NB_PROXY_LOG_LEVEL` | Log level (debug, info, warn, error) | `info` |
|
||||
|
||||
### Configuration File
|
||||
|
||||
Create a JSON configuration file:
|
||||
|
||||
```json
|
||||
{
|
||||
"listen_address": ":8080",
|
||||
"target_url": "http://localhost:3000",
|
||||
"read_timeout": "30s",
|
||||
"write_timeout": "30s",
|
||||
"idle_timeout": "60s",
|
||||
"shutdown_timeout": "10s",
|
||||
"log_level": "info"
|
||||
}
|
||||
```
|
||||
|
||||
## Usage
|
||||
|
||||
### Using Environment Variables
|
||||
|
||||
```bash
|
||||
export NB_PROXY_TARGET_URL=http://localhost:3000
|
||||
export NB_PROXY_LOG_LEVEL=debug
|
||||
./bin/proxy
|
||||
```
|
||||
|
||||
### Using Configuration File
|
||||
|
||||
```bash
|
||||
./bin/proxy -config config.json
|
||||
```
|
||||
|
||||
### Combining Both
|
||||
|
||||
Environment variables override file configuration:
|
||||
|
||||
```bash
|
||||
export NB_PROXY_LOG_LEVEL=debug
|
||||
./bin/proxy -config config.json
|
||||
```
|
||||
|
||||
### Docker Example
|
||||
|
||||
```bash
|
||||
docker run -e NB_PROXY_TARGET_URL=http://backend:3000 \
|
||||
-e NB_PROXY_LISTEN_ADDRESS=:8080 \
|
||||
-p 8080:8080 \
|
||||
netbird-proxy
|
||||
```
|
||||
|
||||
## Architecture
|
||||
|
||||
The application follows a clean architecture with clear separation of concerns:
|
||||
|
||||
```
|
||||
proxy/
|
||||
├── cmd/
|
||||
│ └── proxy/
|
||||
│ └── main.go # Entry point, CLI handling, signal management
|
||||
├── config.go # Configuration loading and validation
|
||||
├── server.go # Server lifecycle (Start/Stop)
|
||||
├── go.mod # Module dependencies
|
||||
└── README.md
|
||||
```
|
||||
|
||||
### Key Components
|
||||
|
||||
- **config.go**: Handles configuration loading from environment variables and files using the `github.com/caarlos0/env/v11` library
|
||||
- **server.go**: Encapsulates the HTTP server and reverse proxy logic with proper lifecycle management
|
||||
- **cmd/proxy/main.go**: Entry point that orchestrates startup, graceful shutdown, and signal handling
|
||||
|
||||
## Graceful Shutdown
|
||||
|
||||
The server handles SIGINT and SIGTERM signals for graceful shutdown:
|
||||
|
||||
1. Signal received (Ctrl+C or kill command)
|
||||
2. Server stops accepting new connections
|
||||
3. Existing connections are allowed to complete within the shutdown timeout
|
||||
4. Server exits cleanly
|
||||
|
||||
Press `Ctrl+C` to trigger graceful shutdown:
|
||||
|
||||
```bash
|
||||
^C2026-01-13 22:40:00 INFO Received signal: interrupt
|
||||
2026-01-13 22:40:00 INFO Shutting down server gracefully...
|
||||
2026-01-13 22:40:00 INFO Server stopped successfully
|
||||
2026-01-13 22:40:00 INFO Server exited successfully
|
||||
```
|
||||
|
||||
## Headers
|
||||
|
||||
The proxy automatically sets the following headers on proxied requests:
|
||||
|
||||
- `X-Forwarded-Host`: Original request host
|
||||
- `X-Origin-Host`: Target backend host
|
||||
- `X-Real-IP`: Client's remote address
|
||||
|
||||
## Error Handling
|
||||
|
||||
- Invalid backend connections return `502 Bad Gateway`
|
||||
- All proxy errors are logged with details
|
||||
- Configuration errors are reported at startup
|
||||
|
||||
## Development
|
||||
|
||||
### Prerequisites
|
||||
|
||||
- Go 1.25 or higher
|
||||
- Access to `github.com/sirupsen/logrus`
|
||||
- Access to `github.com/caarlos0/env/v11`
|
||||
|
||||
### Testing Locally
|
||||
|
||||
Start a test backend:
|
||||
|
||||
```bash
|
||||
# Terminal 1: Start a simple backend
|
||||
python3 -m http.server 3000
|
||||
```
|
||||
|
||||
Start the proxy:
|
||||
|
||||
```bash
|
||||
# Terminal 2: Start the proxy
|
||||
export NB_PROXY_TARGET_URL=http://localhost:3000
|
||||
./bin/proxy
|
||||
```
|
||||
|
||||
Test the proxy:
|
||||
|
||||
```bash
|
||||
# Terminal 3: Make requests
|
||||
curl http://localhost:8080
|
||||
```
|
||||
|
||||
## License
|
||||
|
||||
Part of the Netbird project.
|
||||
NetBird Proxy deployment configuration is via flags or environment variables, with flags taking precedence over the environment.
|
||||
The following deployment configuration is available:
|
||||
| Flag | Env | Purpose | Default |
|
||||
+------+-----+---------+---------+
|
||||
| `-mgmt` | `NB_PROXY_MANAGEMENT_ADDRESS` | The address of the management server for the proxy to get configuration from. | `"https://api.netbird.io:443"` |
|
||||
| `-addr` | `NB_PROXY_ADDRESS` | The address that the reverse proxy will listen on. | `":443` |
|
||||
| `-cert-dir` | `NB_PROXY_CERTIFICATE_DIRECTORY` | The location that certficates are stored in. | `"./certs"` |
|
||||
| `-acme-certs` | `NB_PROXY_ACME_CERTIFICATES` | Whether to use ACME to generate certificates. | `false` |
|
||||
| `-acme-addr` | `NB_PROXY_ACME_ADDRESS` | The HTTP address the proxy will listen on to respond to HTTP-01 ACME challenges | `":80"` |
|
||||
| `-acme-dir` | `NB_PROXY_ACME_DIRECTORY` | The directory URL of the ACME server to be used | `"https://acme-v02.api.letsencrypt.org/directory"` |
|
||||
|
||||
84
proxy/cmd/proxy/main.go
Normal file
84
proxy/cmd/proxy/main.go
Normal file
@@ -0,0 +1,84 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"flag"
|
||||
"fmt"
|
||||
"log"
|
||||
"log/slog"
|
||||
"os"
|
||||
"runtime"
|
||||
"strings"
|
||||
|
||||
"github.com/netbirdio/netbird/proxy"
|
||||
"golang.org/x/crypto/acme"
|
||||
)
|
||||
|
||||
const DefaultManagementURL = "https://api.netbird.io:443"
|
||||
|
||||
var (
|
||||
// Version is the application version (set via ldflags during build)
|
||||
Version = "dev"
|
||||
|
||||
// Commit is the git commit hash (set via ldflags during build)
|
||||
Commit = "unknown"
|
||||
|
||||
// BuildDate is the build date (set via ldflags during build)
|
||||
BuildDate = "unknown"
|
||||
|
||||
// GoVersion is the Go version used to build the binary
|
||||
GoVersion = runtime.Version()
|
||||
)
|
||||
|
||||
func envBoolOrDefault(key string, def bool) bool {
|
||||
v, exists := os.LookupEnv(key)
|
||||
if !exists {
|
||||
return def
|
||||
}
|
||||
return v == strings.ToLower("true")
|
||||
}
|
||||
|
||||
func envStringOrDefault(key string, def string) string {
|
||||
v, exists := os.LookupEnv(key)
|
||||
if !exists {
|
||||
return def
|
||||
}
|
||||
return v
|
||||
}
|
||||
|
||||
func main() {
|
||||
var (
|
||||
version, acmeCerts bool
|
||||
mgmtAddr, addr, certDir, acmeAddr, acmeDir string
|
||||
)
|
||||
|
||||
flag.BoolVar(&version, "v", false, "Print version and exit")
|
||||
flag.StringVar(&mgmtAddr, "mgmt", envStringOrDefault("NB_PROXY_MANAGEMENT_ADDRESS", DefaultManagementURL), "Management address to connect to.")
|
||||
flag.StringVar(&addr, "addr", envStringOrDefault("NB_PROXY_ADDRESS", ":443"), "Reverse proxy address to listen on.")
|
||||
flag.StringVar(&certDir, "cert-dir", envStringOrDefault("NB_PROXY_CERTIFICATE_DIRECTORY", "./certs"), "Directory to store ")
|
||||
flag.BoolVar(&acmeCerts, "acme-certs", envBoolOrDefault("NB_PROXY_ACME_CERTIFICATES", false), "Generate ACME certificates using HTTP-01 challenges.")
|
||||
flag.StringVar(&acmeAddr, "acme-addr", envStringOrDefault("NB_PROXY_ACME_ADDRESS", ":80"), "HTTP address to listen on, used for ACME HTTP-01 certificate generation.")
|
||||
flag.StringVar(&acmeDir, "acme-dir", envStringOrDefault("NB_PROXY_ACME_DIRECTORY", acme.LetsEncryptURL), "URL of ACME challenge directory.")
|
||||
flag.Parse()
|
||||
|
||||
if version {
|
||||
fmt.Printf("Version: %s, Commit: %s, BuildDate: %s, Go: %s", Version, Commit, BuildDate, GoVersion)
|
||||
os.Exit(0)
|
||||
}
|
||||
|
||||
// Write error logs to stderr.
|
||||
errorLog := slog.New(slog.NewTextHandler(os.Stderr, nil))
|
||||
|
||||
srv := proxy.Server{
|
||||
ErrorLog: errorLog,
|
||||
ManagementAddress: mgmtAddr,
|
||||
CertificateDirectory: certDir,
|
||||
GenerateACMECertificates: acmeCerts,
|
||||
ACMEChallengeAddress: acmeAddr,
|
||||
ACMEDirectory: acmeDir,
|
||||
}
|
||||
|
||||
if err := srv.ListenAndServe(context.TODO(), addr); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
}
|
||||
@@ -1,109 +0,0 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
"github.com/netbirdio/netbird/proxy/pkg/proxy"
|
||||
"github.com/netbirdio/netbird/proxy/pkg/version"
|
||||
)
|
||||
|
||||
var (
|
||||
configFile string
|
||||
rootCmd = &cobra.Command{
|
||||
Use: "proxy",
|
||||
Short: "Netbird Reverse Proxy Server",
|
||||
Long: "A lightweight, configurable reverse proxy server.",
|
||||
SilenceUsage: true,
|
||||
SilenceErrors: true,
|
||||
RunE: run,
|
||||
}
|
||||
)
|
||||
|
||||
func init() {
|
||||
rootCmd.PersistentFlags().StringVarP(&configFile, "config", "c", "", "path to JSON configuration file (optional, can use env vars instead)")
|
||||
|
||||
rootCmd.Version = version.Short()
|
||||
rootCmd.SetVersionTemplate("{{.Version}}\n")
|
||||
}
|
||||
|
||||
// Execute runs the root command
|
||||
func Execute() error {
|
||||
return rootCmd.Execute()
|
||||
}
|
||||
|
||||
func run(cmd *cobra.Command, args []string) error {
|
||||
config, err := proxy.LoadFromFileOrEnv(configFile)
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to load configuration: %v", err)
|
||||
return err
|
||||
}
|
||||
|
||||
setupLogging(config.LogLevel)
|
||||
|
||||
log.Infof("Starting Netbird Proxy - %s", version.Short())
|
||||
log.Debugf("Full version info: %s", version.String())
|
||||
log.Info("Configuration loaded successfully")
|
||||
log.Infof("Listen Address: %s", config.ReverseProxy.ListenAddress)
|
||||
log.Infof("Log Level: %s", config.LogLevel)
|
||||
|
||||
server, err := proxy.NewServer(config)
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to create server: %v", err)
|
||||
return err
|
||||
}
|
||||
|
||||
serverErrors := make(chan error, 1)
|
||||
go func() {
|
||||
if err := server.Start(); err != nil {
|
||||
serverErrors <- err
|
||||
}
|
||||
}()
|
||||
|
||||
quit := make(chan os.Signal, 1)
|
||||
signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
|
||||
|
||||
select {
|
||||
case err := <-serverErrors:
|
||||
log.Fatalf("Server error: %v", err)
|
||||
return err
|
||||
case sig := <-quit:
|
||||
log.Infof("Received signal: %v", sig)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), config.ShutdownTimeout)
|
||||
defer cancel()
|
||||
|
||||
if err := server.Stop(ctx); err != nil {
|
||||
log.Fatalf("Failed to stop server gracefully: %v", err)
|
||||
return err
|
||||
}
|
||||
|
||||
log.Info("Server exited successfully")
|
||||
return nil
|
||||
}
|
||||
|
||||
func setupLogging(level string) {
|
||||
log.SetFormatter(&log.TextFormatter{
|
||||
FullTimestamp: true,
|
||||
TimestampFormat: "2006-01-02 15:04:05",
|
||||
})
|
||||
|
||||
switch level {
|
||||
case "debug":
|
||||
log.SetLevel(log.DebugLevel)
|
||||
case "info":
|
||||
log.SetLevel(log.InfoLevel)
|
||||
case "warn":
|
||||
log.SetLevel(log.WarnLevel)
|
||||
case "error":
|
||||
log.SetLevel(log.ErrorLevel)
|
||||
default:
|
||||
log.SetLevel(log.InfoLevel)
|
||||
}
|
||||
}
|
||||
@@ -1,30 +0,0 @@
|
||||
{
|
||||
"read_timeout": "30s",
|
||||
"write_timeout": "30s",
|
||||
"idle_timeout": "60s",
|
||||
"shutdown_timeout": "10s",
|
||||
"log_level": "info",
|
||||
"grpc_listen_address": ":50051",
|
||||
"proxy_id": "proxy-1",
|
||||
"enable_grpc": true,
|
||||
"reverse_proxy": {
|
||||
"listen_address": ":443",
|
||||
"management_url": "https://api.netbird.io",
|
||||
"http_listen_address": ":80",
|
||||
"cert_mode": "letsencrypt",
|
||||
"tls_email": "your-email@example.com",
|
||||
"cert_cache_dir": "./certs",
|
||||
"oidc_config": {
|
||||
"provider_url": "https://your-oidc-provider.com",
|
||||
"client_id": "your-client-id",
|
||||
"client_secret": "your-client-secret-if-needed",
|
||||
"redirect_url": "http://localhost:80/auth/callback",
|
||||
"scopes": ["openid", "profile", "email"],
|
||||
"jwt_keys_location": "https://your-oidc-provider.com/.well-known/jwks.json",
|
||||
"jwt_issuer": "https://your-oidc-provider.com/",
|
||||
"jwt_audience": ["your-api-identifier-or-client-id"],
|
||||
"jwt_idp_signkey_refresh_enabled": false,
|
||||
"session_cookie_name": "auth_session"
|
||||
}
|
||||
}
|
||||
}
|
||||
22
proxy/go.mod
22
proxy/go.mod
@@ -1,22 +0,0 @@
|
||||
module github.com/netbirdio/netbird/proxy
|
||||
|
||||
go 1.25
|
||||
|
||||
require (
|
||||
github.com/caarlos0/env/v11 v11.3.1
|
||||
github.com/sirupsen/logrus v1.9.3
|
||||
github.com/spf13/cobra v1.10.2
|
||||
golang.org/x/crypto v0.44.0
|
||||
google.golang.org/grpc v1.78.0
|
||||
google.golang.org/protobuf v1.36.11
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/inconshreveable/mousetrap v1.1.0 // indirect
|
||||
github.com/spf13/pflag v1.0.9 // indirect
|
||||
github.com/stretchr/testify v1.11.1 // indirect
|
||||
golang.org/x/net v0.47.0 // indirect
|
||||
golang.org/x/sys v0.38.0 // indirect
|
||||
golang.org/x/text v0.31.0 // indirect
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20251029180050-ab9386a59fda // indirect
|
||||
)
|
||||
65
proxy/go.sum
65
proxy/go.sum
@@ -1,65 +0,0 @@
|
||||
github.com/caarlos0/env/v11 v11.3.1 h1:cArPWC15hWmEt+gWk7YBi7lEXTXCvpaSdCiZE2X5mCA=
|
||||
github.com/caarlos0/env/v11 v11.3.1/go.mod h1:qupehSf/Y0TUTsxKywqRt/vJjN5nz6vauiYEUUr8P4U=
|
||||
github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g=
|
||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI=
|
||||
github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
|
||||
github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag=
|
||||
github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE=
|
||||
github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek=
|
||||
github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps=
|
||||
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
|
||||
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
|
||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
|
||||
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
|
||||
github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ=
|
||||
github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
|
||||
github.com/spf13/cobra v1.10.2 h1:DMTTonx5m65Ic0GOoRY2c16WCbHxOOw6xxezuLaBpcU=
|
||||
github.com/spf13/cobra v1.10.2/go.mod h1:7C1pvHqHw5A4vrJfjNwvOdzYu0Gml16OCs2GRiTUUS4=
|
||||
github.com/spf13/pflag v1.0.9 h1:9exaQaMOCwffKiiiYk6/BndUBv+iRViNW+4lEMi0PvY=
|
||||
github.com/spf13/pflag v1.0.9/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
|
||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
|
||||
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
|
||||
go.opentelemetry.io/auto/sdk v1.2.1 h1:jXsnJ4Lmnqd11kwkBV2LgLoFMZKizbCi5fNZ/ipaZ64=
|
||||
go.opentelemetry.io/auto/sdk v1.2.1/go.mod h1:KRTj+aOaElaLi+wW1kO/DZRXwkF4C5xPbEe3ZiIhN7Y=
|
||||
go.opentelemetry.io/otel v1.38.0 h1:RkfdswUDRimDg0m2Az18RKOsnI8UDzppJAtj01/Ymk8=
|
||||
go.opentelemetry.io/otel v1.38.0/go.mod h1:zcmtmQ1+YmQM9wrNsTGV/q/uyusom3P8RxwExxkZhjM=
|
||||
go.opentelemetry.io/otel/metric v1.38.0 h1:Kl6lzIYGAh5M159u9NgiRkmoMKjvbsKtYRwgfrA6WpA=
|
||||
go.opentelemetry.io/otel/metric v1.38.0/go.mod h1:kB5n/QoRM8YwmUahxvI3bO34eVtQf2i4utNVLr9gEmI=
|
||||
go.opentelemetry.io/otel/sdk v1.38.0 h1:l48sr5YbNf2hpCUj/FoGhW9yDkl+Ma+LrVl8qaM5b+E=
|
||||
go.opentelemetry.io/otel/sdk v1.38.0/go.mod h1:ghmNdGlVemJI3+ZB5iDEuk4bWA3GkTpW+DOoZMYBVVg=
|
||||
go.opentelemetry.io/otel/sdk/metric v1.38.0 h1:aSH66iL0aZqo//xXzQLYozmWrXxyFkBJ6qT5wthqPoM=
|
||||
go.opentelemetry.io/otel/sdk/metric v1.38.0/go.mod h1:dg9PBnW9XdQ1Hd6ZnRz689CbtrUp0wMMs9iPcgT9EZA=
|
||||
go.opentelemetry.io/otel/trace v1.38.0 h1:Fxk5bKrDZJUH+AMyyIXGcFAPah0oRcT+LuNtJrmcNLE=
|
||||
go.opentelemetry.io/otel/trace v1.38.0/go.mod h1:j1P9ivuFsTceSWe1oY+EeW3sc+Pp42sO++GHkg4wwhs=
|
||||
go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg=
|
||||
golang.org/x/crypto v0.44.0 h1:A97SsFvM3AIwEEmTBiaxPPTYpDC47w720rdiiUvgoAU=
|
||||
golang.org/x/crypto v0.44.0/go.mod h1:013i+Nw79BMiQiMsOPcVCB5ZIJbYkerPrGnOa00tvmc=
|
||||
golang.org/x/net v0.47.0 h1:Mx+4dIFzqraBXUugkia1OOvlD6LemFo1ALMHjrXDOhY=
|
||||
golang.org/x/net v0.47.0/go.mod h1:/jNxtkgq5yWUGYkaZGqo27cfGZ1c5Nen03aYrrKpVRU=
|
||||
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc=
|
||||
golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
||||
golang.org/x/text v0.31.0 h1:aC8ghyu4JhP8VojJ2lEHBnochRno1sgL6nEi9WGFGMM=
|
||||
golang.org/x/text v0.31.0/go.mod h1:tKRAlv61yKIjGGHX/4tP1LTbc13YSec1pxVEWXzfoeM=
|
||||
gonum.org/v1/gonum v0.16.0 h1:5+ul4Swaf3ESvrOnidPp4GZbzf0mxVQpDCYUQE7OJfk=
|
||||
gonum.org/v1/gonum v0.16.0/go.mod h1:fef3am4MQ93R2HHpKnLk4/Tbh/s0+wqD5nfa6Pnwy4E=
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20251029180050-ab9386a59fda h1:i/Q+bfisr7gq6feoJnS/DlpdwEL4ihp41fvRiM3Ork0=
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20251029180050-ab9386a59fda/go.mod h1:7i2o+ce6H/6BluujYR+kqX3GKH+dChPTQU19wjRPiGk=
|
||||
google.golang.org/grpc v1.78.0 h1:K1XZG/yGDJnzMdd/uZHAkVqJE+xIDOcmdSFZkBUicNc=
|
||||
google.golang.org/grpc v1.78.0/go.mod h1:I47qjTo4OKbMkjA/aOOwxDIiPSBofUtQUI5EfpWvW7U=
|
||||
google.golang.org/protobuf v1.36.11 h1:fV6ZwhNocDyBLK0dj+fg8ektcVegBBuEolpbTQyBNVE=
|
||||
google.golang.org/protobuf v1.36.11/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
89
proxy/internal/accesslog/logger.go
Normal file
89
proxy/internal/accesslog/logger.go
Normal file
@@ -0,0 +1,89 @@
|
||||
package accesslog
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
)
|
||||
|
||||
type gRPCClient interface {
|
||||
SendAccessLog(ctx context.Context, in *proto.SendAccessLogRequest, opts ...grpc.CallOption) (*proto.SendAccessLogResponse, error)
|
||||
}
|
||||
|
||||
type errorLogger interface {
|
||||
ErrorContext(ctx context.Context, msg string, args ...any)
|
||||
}
|
||||
|
||||
type Logger struct {
|
||||
client gRPCClient
|
||||
errorLog errorLogger
|
||||
}
|
||||
|
||||
func NewLogger(client gRPCClient, errorLog errorLogger) *Logger {
|
||||
if errorLog == nil {
|
||||
errorLog = slog.New(slog.DiscardHandler)
|
||||
}
|
||||
return &Logger{
|
||||
client: client,
|
||||
errorLog: errorLog,
|
||||
}
|
||||
}
|
||||
|
||||
type logEntry struct {
|
||||
ServiceId string
|
||||
Host string
|
||||
Path string
|
||||
DurationMs int64
|
||||
Method string
|
||||
ResponseCode int32
|
||||
SourceIp string
|
||||
AuthMechanism string
|
||||
UserId string
|
||||
AuthSuccess bool
|
||||
}
|
||||
|
||||
func (l *Logger) log(ctx context.Context, log logEntry) {
|
||||
// Fire off the log request in a separate routine.
|
||||
// This increases the possibility of losing a log message
|
||||
// (although it should still get logged in the event of an error),
|
||||
// but it will reduce latency returning the request in the
|
||||
// middleware.
|
||||
// There is also a chance that log messages will arrive at
|
||||
// the server out of order; however, the timestamp should
|
||||
// allow for resolving that on the server.
|
||||
now := timestamppb.Now() // Grab the timestamp before launching the goroutine to try to prevent weird timing issues. This is probably unnecessary.
|
||||
go func() {
|
||||
if _, err := l.client.SendAccessLog(ctx, &proto.SendAccessLogRequest{
|
||||
Log: &proto.AccessLog{
|
||||
Timestamp: now,
|
||||
ServiceId: log.ServiceId,
|
||||
Host: log.Host,
|
||||
Path: log.Path,
|
||||
DurationMs: log.DurationMs,
|
||||
Method: log.Method,
|
||||
ResponseCode: log.ResponseCode,
|
||||
SourceIp: log.SourceIp,
|
||||
AuthMechanism: log.AuthMechanism,
|
||||
UserId: log.UserId,
|
||||
AuthSuccess: log.AuthSuccess,
|
||||
},
|
||||
}); err != nil {
|
||||
// If it fails to send on the gRPC connection, then at least log it to the error log.
|
||||
l.errorLog.ErrorContext(ctx, "Error sending access log on gRPC connection",
|
||||
"service_id", log.ServiceId,
|
||||
"host", log.Host,
|
||||
"path", log.Path,
|
||||
"duration", log.DurationMs,
|
||||
"method", log.Method,
|
||||
"response_code", log.ResponseCode,
|
||||
"source_ip", log.SourceIp,
|
||||
"auth_mechanism", log.AuthMechanism,
|
||||
"user_id", log.UserId,
|
||||
"auth_success", log.AuthSuccess,
|
||||
"error", err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
47
proxy/internal/accesslog/middleware.go
Normal file
47
proxy/internal/accesslog/middleware.go
Normal file
@@ -0,0 +1,47 @@
|
||||
package accesslog
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/netbirdio/netbird/proxy/internal/auth"
|
||||
"github.com/netbirdio/netbird/proxy/internal/proxy"
|
||||
)
|
||||
|
||||
func (l *Logger) Middleware(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Use a response writer wrapper so we can access the status code later.
|
||||
sw := &statusWriter{
|
||||
w: w,
|
||||
status: http.StatusOK, // Default status is OK unless otherwise modified.
|
||||
}
|
||||
|
||||
// Get the source IP before passing the request on as the proxy will modify
|
||||
// headers that we wish to use to gather that information on the request.
|
||||
sourceIp := extractSourceIP(r)
|
||||
|
||||
start := time.Now()
|
||||
next.ServeHTTP(sw, r)
|
||||
duration := time.Since(start)
|
||||
|
||||
host, _, err := net.SplitHostPort(r.Host)
|
||||
if err != nil {
|
||||
// Fallback to just using the full host value.
|
||||
host = r.Host
|
||||
}
|
||||
|
||||
l.log(r.Context(), logEntry{
|
||||
ServiceId: proxy.ServiceIdFromContext(r.Context()),
|
||||
Host: host,
|
||||
Path: r.URL.Path,
|
||||
DurationMs: duration.Milliseconds(),
|
||||
Method: r.Method,
|
||||
ResponseCode: int32(sw.status),
|
||||
SourceIp: sourceIp,
|
||||
AuthMechanism: auth.MethodFromContext(r.Context()).String(),
|
||||
UserId: auth.UserFromContext(r.Context()),
|
||||
AuthSuccess: sw.status != http.StatusUnauthorized && sw.status != http.StatusForbidden,
|
||||
})
|
||||
})
|
||||
}
|
||||
43
proxy/internal/accesslog/requestip.go
Normal file
43
proxy/internal/accesslog/requestip.go
Normal file
@@ -0,0 +1,43 @@
|
||||
package accesslog
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/http"
|
||||
"slices"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// requestIP attempts to extract the source IP from a request.
|
||||
// Adapted from https://husobee.github.io/golang/ip-address/2015/12/17/remote-ip-go.html
|
||||
// with the addition of some newer stdlib functions that are now
|
||||
// available.
|
||||
// The concept here is to look backwards through IP headers until
|
||||
// the first public IP address is found. The hypothesis is that
|
||||
// even if there are multiple IP addresses specified in these headers,
|
||||
// the last public IP should be the hop immediately before reaching
|
||||
// the server and therefore represents the "true" source IP regardless
|
||||
// of the number of intermediate proxies or network hops.
|
||||
func extractSourceIP(r *http.Request) string {
|
||||
for _, h := range []string{"X-Forwarded-For", "X-Real-IP"} {
|
||||
addresses := strings.Split(r.Header.Get(h), ",")
|
||||
// Iterate from right to left until we get a public address
|
||||
// that should be the address right before our proxy.
|
||||
for _, address := range slices.Backward(addresses) {
|
||||
// Trim the address because sometimes clients put whitespace in there.
|
||||
ip := strings.TrimSpace(address)
|
||||
// Parse the IP so that we can easily check whether it is a valid public address.
|
||||
realIP := net.ParseIP(ip)
|
||||
if !realIP.IsGlobalUnicast() || realIP.IsPrivate() || realIP.IsLoopback() {
|
||||
continue
|
||||
}
|
||||
return ip
|
||||
}
|
||||
}
|
||||
// Fallback to the requests RemoteAddr, this is least likely to be correct but
|
||||
// should at least yield something in the event that the above has failed.
|
||||
ip, _, err := net.SplitHostPort(r.RemoteAddr)
|
||||
if err != nil {
|
||||
ip = r.RemoteAddr
|
||||
}
|
||||
return ip
|
||||
}
|
||||
26
proxy/internal/accesslog/statuswriter.go
Normal file
26
proxy/internal/accesslog/statuswriter.go
Normal file
@@ -0,0 +1,26 @@
|
||||
package accesslog
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
)
|
||||
|
||||
// statusWriter is a simple wrapper around an http.ResponseWriter
|
||||
// that captures the setting of the status code via the WriteHeader
|
||||
// function and stores it so that it can be retrieved later.
|
||||
type statusWriter struct {
|
||||
w http.ResponseWriter
|
||||
status int
|
||||
}
|
||||
|
||||
func (w *statusWriter) Header() http.Header {
|
||||
return w.w.Header()
|
||||
}
|
||||
|
||||
func (w *statusWriter) Write(data []byte) (int, error) {
|
||||
return w.w.Write(data)
|
||||
}
|
||||
|
||||
func (w *statusWriter) WriteHeader(status int) {
|
||||
w.status = status
|
||||
w.w.WriteHeader(status)
|
||||
}
|
||||
53
proxy/internal/acme/manager.go
Normal file
53
proxy/internal/acme/manager.go
Normal file
@@ -0,0 +1,53 @@
|
||||
package acme
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
"golang.org/x/crypto/acme"
|
||||
"golang.org/x/crypto/acme/autocert"
|
||||
)
|
||||
|
||||
type Manager struct {
|
||||
*autocert.Manager
|
||||
|
||||
domainsMux sync.RWMutex
|
||||
domains map[string]struct{}
|
||||
}
|
||||
|
||||
func NewManager(certDir, acmeURL string) *Manager {
|
||||
mgr := &Manager{
|
||||
domains: make(map[string]struct{}),
|
||||
}
|
||||
mgr.Manager = &autocert.Manager{
|
||||
Prompt: autocert.AcceptTOS,
|
||||
HostPolicy: mgr.hostPolicy,
|
||||
Cache: autocert.DirCache(certDir),
|
||||
Client: &acme.Client{
|
||||
DirectoryURL: acmeURL,
|
||||
},
|
||||
}
|
||||
return mgr
|
||||
}
|
||||
|
||||
func (mgr *Manager) hostPolicy(_ context.Context, domain string) error {
|
||||
mgr.domainsMux.RLock()
|
||||
defer mgr.domainsMux.RUnlock()
|
||||
if _, exists := mgr.domains[domain]; exists {
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("unknown domain %q", domain)
|
||||
}
|
||||
|
||||
func (mgr *Manager) AddDomain(domain string) {
|
||||
mgr.domainsMux.Lock()
|
||||
defer mgr.domainsMux.Unlock()
|
||||
mgr.domains[domain] = struct{}{}
|
||||
}
|
||||
|
||||
func (mgr *Manager) RemoveDomain(domain string) {
|
||||
mgr.domainsMux.Lock()
|
||||
defer mgr.domainsMux.Unlock()
|
||||
delete(mgr.domains, domain)
|
||||
}
|
||||
4
proxy/internal/auth/auth.gohtml
Normal file
4
proxy/internal/auth/auth.gohtml
Normal file
@@ -0,0 +1,4 @@
|
||||
<!doctype html>
|
||||
{{ range . }}
|
||||
<p>{{ . }}</p>
|
||||
{{ end }}
|
||||
42
proxy/internal/auth/basicauth.go
Normal file
42
proxy/internal/auth/basicauth.go
Normal file
@@ -0,0 +1,42 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"crypto/subtle"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
type BasicAuth struct {
|
||||
username, password string
|
||||
}
|
||||
|
||||
func NewBasicAuth(username string, password string) BasicAuth {
|
||||
return BasicAuth{
|
||||
username: username,
|
||||
password: password,
|
||||
}
|
||||
}
|
||||
|
||||
func (BasicAuth) Type() Method {
|
||||
return MethodBasicAuth
|
||||
}
|
||||
|
||||
func (b BasicAuth) Authenticate(r *http.Request) (string, bool, any) {
|
||||
username, password, ok := r.BasicAuth()
|
||||
if !ok {
|
||||
return "", false, nil
|
||||
}
|
||||
|
||||
usernameMatch := subtle.ConstantTimeCompare([]byte(username), []byte(b.username)) == 1
|
||||
passwordMatch := subtle.ConstantTimeCompare([]byte(password), []byte(b.password)) == 1
|
||||
|
||||
// If authenticated, then return the username.
|
||||
if usernameMatch && passwordMatch {
|
||||
return username, false, nil
|
||||
}
|
||||
|
||||
return "", false, nil
|
||||
}
|
||||
|
||||
func (b BasicAuth) Middleware(next http.Handler) http.Handler {
|
||||
return next
|
||||
}
|
||||
@@ -1,25 +0,0 @@
|
||||
package auth
|
||||
|
||||
import "github.com/netbirdio/netbird/proxy/internal/auth/methods"
|
||||
|
||||
// Config holds the authentication configuration for a route
|
||||
// Only ONE auth method should be configured per route
|
||||
type Config struct {
|
||||
// HTTP Basic authentication (username/password)
|
||||
BasicAuth *methods.BasicAuthConfig
|
||||
|
||||
// PIN authentication
|
||||
PIN *methods.PINConfig
|
||||
|
||||
// Bearer token with JWT validation and OAuth/OIDC flow
|
||||
// When enabled, uses the global OIDCConfig from proxy Config
|
||||
Bearer *methods.BearerConfig
|
||||
}
|
||||
|
||||
// IsEmpty returns true if no auth methods are configured
|
||||
func (c *Config) IsEmpty() bool {
|
||||
if c == nil {
|
||||
return true
|
||||
}
|
||||
return c.BasicAuth == nil && c.PIN == nil && c.Bearer == nil
|
||||
}
|
||||
38
proxy/internal/auth/context.go
Normal file
38
proxy/internal/auth/context.go
Normal file
@@ -0,0 +1,38 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
)
|
||||
|
||||
type requestContextKey string
|
||||
|
||||
const (
|
||||
authMethodKey requestContextKey = "authMethod"
|
||||
authUserKey requestContextKey = "authUser"
|
||||
)
|
||||
|
||||
func withAuthMethod(ctx context.Context, method Method) context.Context {
|
||||
return context.WithValue(ctx, authMethodKey, method)
|
||||
}
|
||||
|
||||
func MethodFromContext(ctx context.Context) Method {
|
||||
v := ctx.Value(authMethodKey)
|
||||
method, ok := v.(Method)
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
return method
|
||||
}
|
||||
|
||||
func withAuthUser(ctx context.Context, userId string) context.Context {
|
||||
return context.WithValue(ctx, authUserKey, userId)
|
||||
}
|
||||
|
||||
func UserFromContext(ctx context.Context) string {
|
||||
v := ctx.Value(authUserKey)
|
||||
userId, ok := v.(string)
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
return userId
|
||||
}
|
||||
@@ -1,25 +0,0 @@
|
||||
package methods
|
||||
|
||||
import (
|
||||
"crypto/subtle"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
// BasicAuthConfig holds HTTP Basic authentication settings
|
||||
type BasicAuthConfig struct {
|
||||
Username string
|
||||
Password string
|
||||
}
|
||||
|
||||
// Validate checks Basic Auth credentials from the request
|
||||
func (c *BasicAuthConfig) Validate(r *http.Request) bool {
|
||||
username, password, ok := r.BasicAuth()
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
usernameMatch := subtle.ConstantTimeCompare([]byte(username), []byte(c.Username)) == 1
|
||||
passwordMatch := subtle.ConstantTimeCompare([]byte(password), []byte(c.Password)) == 1
|
||||
|
||||
return usernameMatch && passwordMatch
|
||||
}
|
||||
@@ -1,8 +0,0 @@
|
||||
package methods
|
||||
|
||||
// BearerConfig holds JWT/OAuth/OIDC bearer token authentication settings
|
||||
// The actual OIDC/JWT configuration comes from the global proxy Config.OIDCConfig
|
||||
// This just enables Bearer auth for a specific route
|
||||
type BearerConfig struct {
|
||||
Enabled bool
|
||||
}
|
||||
@@ -1,32 +0,0 @@
|
||||
package methods
|
||||
|
||||
import (
|
||||
"crypto/subtle"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
const (
|
||||
// DefaultPINHeader is the default header name for PIN authentication
|
||||
DefaultPINHeader = "X-PIN"
|
||||
)
|
||||
|
||||
// PINConfig holds PIN authentication settings
|
||||
type PINConfig struct {
|
||||
PIN string
|
||||
Header string
|
||||
}
|
||||
|
||||
// Validate checks PIN from the request header
|
||||
func (c *PINConfig) Validate(r *http.Request) bool {
|
||||
header := c.Header
|
||||
if header == "" {
|
||||
header = DefaultPINHeader
|
||||
}
|
||||
|
||||
providedPIN := r.Header.Get(header)
|
||||
if providedPIN == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
return subtle.ConstantTimeCompare([]byte(providedPIN), []byte(c.PIN)) == 1
|
||||
}
|
||||
@@ -1,298 +1,198 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"crypto/rand"
|
||||
_ "embed"
|
||||
"encoding/base64"
|
||||
"html/template"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/proxy/internal/auth/oidc"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Middleware wraps an HTTP handler with authentication middleware
|
||||
//go:embed auth.gohtml
|
||||
var authTemplate string
|
||||
|
||||
type Method string
|
||||
|
||||
var (
|
||||
MethodBasicAuth Method = "basic"
|
||||
MethodPIN Method = "pin"
|
||||
MethodBearer Method = "bearer"
|
||||
)
|
||||
|
||||
func (m Method) String() string {
|
||||
return string(m)
|
||||
}
|
||||
|
||||
const (
|
||||
sessionCookieName = "nb_session"
|
||||
sessionExpiration = 24 * time.Hour
|
||||
)
|
||||
|
||||
type session struct {
|
||||
UserID string
|
||||
Method Method
|
||||
CreatedAt time.Time
|
||||
}
|
||||
|
||||
type Scheme interface {
|
||||
Type() Method
|
||||
// Authenticate should check the passed request and determine whether
|
||||
// it represents an authenticated user request. If it does not, then
|
||||
// an empty string should indicate an unauthenticated request which
|
||||
// will be rejected; optionally, it can also return any data that should
|
||||
// be included in a UI template when prompting the user to authenticate.
|
||||
// If the request is authenticated, then a user id should be returned
|
||||
// along with a boolean indicating whether a redirect is needed to clean
|
||||
// up authentication artifacts from the URLs query.
|
||||
Authenticate(*http.Request) (userid string, needsRedirect bool, promptData any)
|
||||
// Middleware is applied within the outer auth middleware, but they will
|
||||
// be applied after authentication if no scheme has authenticated a
|
||||
// request.
|
||||
// If no scheme Middleware blocks the request processing, then the auth
|
||||
// middleware will then present the user with the auth UI.
|
||||
Middleware(http.Handler) http.Handler
|
||||
}
|
||||
|
||||
type Middleware struct {
|
||||
next http.Handler
|
||||
config *Config
|
||||
routeID string
|
||||
rejectResponse func(w http.ResponseWriter, r *http.Request)
|
||||
oidcHandler *oidc.Handler // OIDC handler for OAuth flow (contains config and JWT validator)
|
||||
domainsMux sync.RWMutex
|
||||
domains map[string][]Scheme
|
||||
sessionsMux sync.RWMutex
|
||||
sessions map[string]*session
|
||||
}
|
||||
|
||||
// authResult holds the result of an authentication attempt
|
||||
type authResult struct {
|
||||
authenticated bool
|
||||
method string
|
||||
userID string
|
||||
func NewMiddleware() *Middleware {
|
||||
mw := &Middleware{
|
||||
domains: make(map[string][]Scheme),
|
||||
sessions: make(map[string]*session),
|
||||
}
|
||||
// TODO: goroutine is leaked here.
|
||||
go mw.cleanupSessions()
|
||||
return mw
|
||||
}
|
||||
|
||||
// ServeHTTP implements the http.Handler interface
|
||||
func (m *Middleware) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
if m.config.IsEmpty() {
|
||||
m.allowWithoutAuth(w, r)
|
||||
return
|
||||
}
|
||||
// Protect applies authentication middleware to the passed handler.
|
||||
// For each incoming request it will be checked against the middleware's
|
||||
// internal list of protected domains.
|
||||
// If the Host domain in the inbound request is not present, then it will
|
||||
// simply be passed through.
|
||||
// However, if the Host domain is present, then the specified authentication
|
||||
// schemes for that domain will be applied to the request.
|
||||
// In the event that no authentication schemes are defined for the domain,
|
||||
// then the request will also be simply passed through.
|
||||
func (mw *Middleware) Protect(next http.Handler) http.Handler {
|
||||
tmpl := template.Must(template.New("auth").Parse(authTemplate))
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
mw.domainsMux.RLock()
|
||||
schemes, exists := mw.domains[r.Host]
|
||||
mw.domainsMux.RUnlock()
|
||||
|
||||
result := m.authenticate(w, r)
|
||||
if result == nil {
|
||||
// Authentication triggered a redirect (e.g., OIDC flow)
|
||||
return
|
||||
}
|
||||
// Domains that are not configured here or have no authentication schemes applied should simply pass through.
|
||||
if !exists || len(schemes) == 0 {
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
if !result.authenticated {
|
||||
m.rejectRequest(w, r)
|
||||
return
|
||||
}
|
||||
// Check for an existing session to avoid users having to authenticate for every request.
|
||||
// TODO: This does not work if you are load balancing across multiple proxy servers.
|
||||
if cookie, err := r.Cookie(sessionCookieName); err == nil {
|
||||
mw.sessionsMux.RLock()
|
||||
sess, ok := mw.sessions[cookie.Value]
|
||||
mw.sessionsMux.RUnlock()
|
||||
if ok {
|
||||
ctx := withAuthMethod(r.Context(), sess.Method)
|
||||
ctx = withAuthUser(ctx, sess.UserID)
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
m.continueWithAuth(w, r, result)
|
||||
// Try to authenticate with each scheme.
|
||||
methods := make(map[Method]any)
|
||||
for _, s := range schemes {
|
||||
userid, needsRedirect, promptData := s.Authenticate(r)
|
||||
if userid != "" {
|
||||
mw.createSession(w, r, userid, s.Type())
|
||||
if needsRedirect {
|
||||
// Clean the path and redirect to the naked URL.
|
||||
// This is intended to prevent leaking potentially
|
||||
// sensitive query parameters for some authentication
|
||||
// methods such as OIDC.
|
||||
http.Redirect(w, r, r.URL.Path, http.StatusFound)
|
||||
return
|
||||
}
|
||||
ctx := withAuthMethod(r.Context(), s.Type())
|
||||
ctx = withAuthUser(ctx, userid)
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
return
|
||||
}
|
||||
methods[s.Type()] = promptData
|
||||
}
|
||||
|
||||
// The handler is passed through the scheme middlewares,
|
||||
// if none of them intercept the request, then this handler will
|
||||
// be called and present the user with the authentication page.
|
||||
handler := http.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if err := tmpl.Execute(w, methods); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadGateway)
|
||||
}
|
||||
}))
|
||||
|
||||
// No authentication succeeded. Apply the scheme handlers.
|
||||
for _, s := range schemes {
|
||||
handler = s.Middleware(handler)
|
||||
}
|
||||
|
||||
// Run the unauthenticated request against the scheme handlers and the final UI handler.
|
||||
handler.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
// allowWithoutAuth allows requests when no authentication is configured
|
||||
func (m *Middleware) allowWithoutAuth(w http.ResponseWriter, r *http.Request) {
|
||||
log.WithFields(log.Fields{
|
||||
"route_id": m.routeID,
|
||||
"auth_method": "none",
|
||||
"path": r.URL.Path,
|
||||
}).Debug("No authentication configured, allowing request")
|
||||
r.Header.Set("X-Auth-Method", "none")
|
||||
m.next.ServeHTTP(w, r)
|
||||
func (mw *Middleware) AddDomain(domain string, schemes []Scheme) {
|
||||
mw.domainsMux.Lock()
|
||||
defer mw.domainsMux.Unlock()
|
||||
mw.domains[domain] = schemes
|
||||
}
|
||||
|
||||
// authenticate attempts to authenticate the request using configured methods
|
||||
// Returns nil if a redirect occurred (e.g., OIDC flow initiated)
|
||||
func (m *Middleware) authenticate(w http.ResponseWriter, r *http.Request) *authResult {
|
||||
if result := m.tryBasicAuth(r); result.authenticated {
|
||||
return result
|
||||
}
|
||||
|
||||
if result := m.tryPINAuth(r); result.authenticated {
|
||||
return result
|
||||
}
|
||||
|
||||
return m.tryBearerAuth(w, r)
|
||||
func (mw *Middleware) RemoveDomain(domain string) {
|
||||
mw.domainsMux.Lock()
|
||||
defer mw.domainsMux.Unlock()
|
||||
delete(mw.domains, domain)
|
||||
}
|
||||
|
||||
// tryBasicAuth attempts Basic authentication
|
||||
func (m *Middleware) tryBasicAuth(r *http.Request) *authResult {
|
||||
if m.config.BasicAuth == nil {
|
||||
return &authResult{}
|
||||
}
|
||||
func (mw *Middleware) createSession(w http.ResponseWriter, r *http.Request, userID string, method Method) {
|
||||
// Generate a random sessionID
|
||||
b := make([]byte, 32)
|
||||
_, _ = rand.Read(b)
|
||||
sessionID := base64.URLEncoding.EncodeToString(b)
|
||||
|
||||
if !m.config.BasicAuth.Validate(r) {
|
||||
return &authResult{}
|
||||
mw.sessionsMux.Lock()
|
||||
mw.sessions[sessionID] = &session{
|
||||
UserID: userID,
|
||||
Method: method,
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
mw.sessionsMux.Unlock()
|
||||
|
||||
result := &authResult{
|
||||
authenticated: true,
|
||||
method: "basic",
|
||||
}
|
||||
|
||||
if username, _, ok := r.BasicAuth(); ok {
|
||||
result.userID = username
|
||||
}
|
||||
|
||||
return result
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: sessionCookieName,
|
||||
Value: sessionID,
|
||||
HttpOnly: true, // This cookie is only for proxy access, so no scripts should touch it.
|
||||
Secure: true, // The proxy only accepts TLS traffic regardless of the service proxied behind.
|
||||
SameSite: http.SameSiteLaxMode, // TODO: might this actually be strict mode?
|
||||
})
|
||||
}
|
||||
|
||||
// tryPINAuth attempts PIN authentication
|
||||
func (m *Middleware) tryPINAuth(r *http.Request) *authResult {
|
||||
if m.config.PIN == nil {
|
||||
return &authResult{}
|
||||
}
|
||||
|
||||
if !m.config.PIN.Validate(r) {
|
||||
return &authResult{}
|
||||
}
|
||||
|
||||
return &authResult{
|
||||
authenticated: true,
|
||||
method: "pin",
|
||||
userID: "pin_user",
|
||||
}
|
||||
}
|
||||
|
||||
// tryBearerAuth attempts Bearer token authentication with JWT validation
|
||||
// Returns nil if OIDC redirect occurred
|
||||
func (m *Middleware) tryBearerAuth(w http.ResponseWriter, r *http.Request) *authResult {
|
||||
if m.config.Bearer == nil || m.oidcHandler == nil {
|
||||
return &authResult{}
|
||||
}
|
||||
|
||||
cookieName := m.oidcHandler.SessionCookieName()
|
||||
|
||||
if m.handleAuthTokenParameter(w, r, cookieName) {
|
||||
return nil
|
||||
}
|
||||
|
||||
if result := m.trySessionCookie(r, cookieName); result.authenticated {
|
||||
return result
|
||||
}
|
||||
|
||||
if result := m.tryAuthorizationHeader(r); result.authenticated {
|
||||
return result
|
||||
}
|
||||
|
||||
m.oidcHandler.RedirectToProvider(w, r, m.routeID)
|
||||
return nil
|
||||
}
|
||||
|
||||
// handleAuthTokenParameter processes the _auth_token query parameter from OIDC callback
|
||||
// Returns true if a redirect occurred
|
||||
func (m *Middleware) handleAuthTokenParameter(w http.ResponseWriter, r *http.Request, cookieName string) bool {
|
||||
authToken := r.URL.Query().Get("_auth_token")
|
||||
if authToken == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
log.WithFields(log.Fields{
|
||||
"route_id": m.routeID,
|
||||
"host": r.Host,
|
||||
}).Info("Found auth token in query parameter, setting cookie and redirecting")
|
||||
|
||||
if !m.oidcHandler.ValidateJWT(authToken) {
|
||||
log.WithFields(log.Fields{
|
||||
"route_id": m.routeID,
|
||||
}).Warn("Invalid token in query parameter")
|
||||
return false
|
||||
}
|
||||
|
||||
cookie := &http.Cookie{
|
||||
Name: cookieName,
|
||||
Value: authToken,
|
||||
Path: "/",
|
||||
MaxAge: 3600, // 1 hour
|
||||
HttpOnly: true,
|
||||
Secure: false, // Set to false for HTTP testing, true for HTTPS in production
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
}
|
||||
http.SetCookie(w, cookie)
|
||||
|
||||
// Redirect to same URL without the token parameter
|
||||
redirectURL := m.buildCleanRedirectURL(r)
|
||||
|
||||
log.WithFields(log.Fields{
|
||||
"route_id": m.routeID,
|
||||
"redirect_url": redirectURL,
|
||||
}).Debug("Redirecting to clean URL after setting cookie")
|
||||
|
||||
http.Redirect(w, r, redirectURL, http.StatusFound)
|
||||
return true
|
||||
}
|
||||
|
||||
// buildCleanRedirectURL builds a redirect URL without the _auth_token parameter
|
||||
func (m *Middleware) buildCleanRedirectURL(r *http.Request) string {
|
||||
cleanURL := *r.URL
|
||||
q := cleanURL.Query()
|
||||
q.Del("_auth_token")
|
||||
cleanURL.RawQuery = q.Encode()
|
||||
|
||||
scheme := "http"
|
||||
if r.TLS != nil {
|
||||
scheme = "https"
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%s://%s%s", scheme, r.Host, cleanURL.String())
|
||||
}
|
||||
|
||||
// trySessionCookie attempts authentication using a session cookie
|
||||
func (m *Middleware) trySessionCookie(r *http.Request, cookieName string) *authResult {
|
||||
log.WithFields(log.Fields{
|
||||
"route_id": m.routeID,
|
||||
"cookie_name": cookieName,
|
||||
"host": r.Host,
|
||||
"path": r.URL.Path,
|
||||
}).Debug("Checking for session cookie")
|
||||
|
||||
cookie, err := r.Cookie(cookieName)
|
||||
if err != nil || cookie.Value == "" {
|
||||
log.WithFields(log.Fields{
|
||||
"route_id": m.routeID,
|
||||
"error": err,
|
||||
}).Debug("No session cookie found")
|
||||
return &authResult{}
|
||||
}
|
||||
|
||||
log.WithFields(log.Fields{
|
||||
"route_id": m.routeID,
|
||||
"cookie_name": cookieName,
|
||||
}).Debug("Session cookie found, validating JWT")
|
||||
|
||||
if !m.oidcHandler.ValidateJWT(cookie.Value) {
|
||||
log.WithFields(log.Fields{
|
||||
"route_id": m.routeID,
|
||||
}).Debug("JWT validation failed for session cookie")
|
||||
return &authResult{}
|
||||
}
|
||||
|
||||
return &authResult{
|
||||
authenticated: true,
|
||||
method: "bearer_session",
|
||||
userID: m.oidcHandler.ExtractUserID(cookie.Value),
|
||||
}
|
||||
}
|
||||
|
||||
// tryAuthorizationHeader attempts authentication using the Authorization header
|
||||
func (m *Middleware) tryAuthorizationHeader(r *http.Request) *authResult {
|
||||
authHeader := r.Header.Get("Authorization")
|
||||
if authHeader == "" || !strings.HasPrefix(authHeader, "Bearer ") {
|
||||
return &authResult{}
|
||||
}
|
||||
|
||||
token := strings.TrimPrefix(authHeader, "Bearer ")
|
||||
if !m.oidcHandler.ValidateJWT(token) {
|
||||
return &authResult{}
|
||||
}
|
||||
|
||||
return &authResult{
|
||||
authenticated: true,
|
||||
method: "bearer",
|
||||
userID: m.oidcHandler.ExtractUserID(token),
|
||||
}
|
||||
}
|
||||
|
||||
// rejectRequest rejects an unauthenticated request
|
||||
func (m *Middleware) rejectRequest(w http.ResponseWriter, r *http.Request) {
|
||||
log.WithFields(log.Fields{
|
||||
"route_id": m.routeID,
|
||||
"path": r.URL.Path,
|
||||
}).Warn("Authentication failed")
|
||||
|
||||
if m.rejectResponse != nil {
|
||||
m.rejectResponse(w, r)
|
||||
} else {
|
||||
w.Header().Set("WWW-Authenticate", `Bearer realm="Restricted"`)
|
||||
http.Error(w, "Unauthorized", http.StatusUnauthorized)
|
||||
}
|
||||
}
|
||||
|
||||
// continueWithAuth continues the request with authenticated user info
|
||||
func (m *Middleware) continueWithAuth(w http.ResponseWriter, r *http.Request, result *authResult) {
|
||||
log.WithFields(log.Fields{
|
||||
"route_id": m.routeID,
|
||||
"auth_method": result.method,
|
||||
"user_id": result.userID,
|
||||
"path": r.URL.Path,
|
||||
}).Debug("Authentication successful")
|
||||
|
||||
// TODO: Find other means of auth logging than headers
|
||||
r.Header.Set("X-Auth-Method", result.method)
|
||||
r.Header.Set("X-Auth-User-ID", result.userID)
|
||||
|
||||
// Continue to next handler
|
||||
m.next.ServeHTTP(w, r)
|
||||
}
|
||||
|
||||
// Wrap wraps an HTTP handler with authentication middleware
|
||||
func Wrap(next http.Handler, authConfig *Config, routeID string, rejectResponse func(w http.ResponseWriter, r *http.Request), oidcHandler *oidc.Handler) http.Handler {
|
||||
if authConfig == nil {
|
||||
authConfig = &Config{}
|
||||
}
|
||||
|
||||
return &Middleware{
|
||||
next: next,
|
||||
config: authConfig,
|
||||
routeID: routeID,
|
||||
rejectResponse: rejectResponse,
|
||||
oidcHandler: oidcHandler,
|
||||
func (mw *Middleware) cleanupSessions() {
|
||||
for range time.Tick(time.Minute) {
|
||||
cutoff := time.Now().Add(-sessionExpiration)
|
||||
mw.sessionsMux.Lock()
|
||||
for id, sess := range mw.sessions {
|
||||
if sess.CreatedAt.Before(cutoff) {
|
||||
delete(mw.sessions, id)
|
||||
}
|
||||
}
|
||||
mw.sessionsMux.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
202
proxy/internal/auth/oidc.go
Normal file
202
proxy/internal/auth/oidc.go
Normal file
@@ -0,0 +1,202 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/coreos/go-oidc/v3/oidc"
|
||||
"golang.org/x/oauth2"
|
||||
)
|
||||
|
||||
const stateExpiration = 10 * time.Minute
|
||||
|
||||
// OIDCConfig holds configuration for OIDC authentication
|
||||
type OIDCConfig struct {
|
||||
OIDCProviderURL string
|
||||
OIDCClientID string
|
||||
OIDCClientSecret string
|
||||
OIDCRedirectURL string
|
||||
OIDCScopes []string
|
||||
}
|
||||
|
||||
// oidcState stores CSRF state with expiration
|
||||
type oidcState struct {
|
||||
OriginalURL string
|
||||
CreatedAt time.Time
|
||||
}
|
||||
|
||||
// OIDC implements the Scheme interface for JWT/OIDC authentication
|
||||
type OIDC struct {
|
||||
verifier *oidc.IDTokenVerifier
|
||||
oauthConfig *oauth2.Config
|
||||
states map[string]*oidcState
|
||||
statesMux sync.RWMutex
|
||||
}
|
||||
|
||||
// NewOIDC creates a new OIDC authentication scheme
|
||||
func NewOIDC(ctx context.Context, cfg OIDCConfig) (*OIDC, error) {
|
||||
if cfg.OIDCProviderURL == "" || cfg.OIDCClientID == "" {
|
||||
return nil, fmt.Errorf("OIDC provider URL and client ID are required")
|
||||
}
|
||||
|
||||
scopes := cfg.OIDCScopes
|
||||
if len(scopes) == 0 {
|
||||
scopes = []string{oidc.ScopeOpenID, "profile", "email"}
|
||||
}
|
||||
|
||||
provider, err := oidc.NewProvider(ctx, cfg.OIDCProviderURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create OIDC provider: %w", err)
|
||||
}
|
||||
|
||||
o := &OIDC{
|
||||
verifier: provider.Verifier(&oidc.Config{
|
||||
ClientID: cfg.OIDCClientID,
|
||||
}),
|
||||
oauthConfig: &oauth2.Config{
|
||||
ClientID: cfg.OIDCClientID,
|
||||
ClientSecret: cfg.OIDCClientSecret,
|
||||
RedirectURL: cfg.OIDCRedirectURL,
|
||||
Scopes: scopes,
|
||||
Endpoint: provider.Endpoint(),
|
||||
},
|
||||
states: make(map[string]*oidcState),
|
||||
}
|
||||
|
||||
go o.cleanupStates()
|
||||
|
||||
return o, nil
|
||||
}
|
||||
|
||||
func (*OIDC) Type() Method {
|
||||
return MethodBearer
|
||||
}
|
||||
|
||||
func (o *OIDC) Authenticate(r *http.Request) (string, bool, any) {
|
||||
// Try Authorization: Bearer <token> header
|
||||
if auth := r.Header.Get("Authorization"); strings.HasPrefix(auth, "Bearer ") {
|
||||
if userID := o.validateToken(r.Context(), strings.TrimPrefix(auth, "Bearer ")); userID != "" {
|
||||
return userID, false, nil
|
||||
}
|
||||
}
|
||||
|
||||
// Try _auth_token query parameter (from OIDC callback redirect)
|
||||
if token := r.URL.Query().Get("_auth_token"); token != "" {
|
||||
if userID := o.validateToken(r.Context(), token); userID != "" {
|
||||
return userID, true, nil // Redirect needed to clean up URL
|
||||
}
|
||||
}
|
||||
|
||||
// If the request is not authenticated, return a redirect URL for the UI to
|
||||
// route the user through if they select OIDC login.
|
||||
b := make([]byte, 32)
|
||||
_, _ = rand.Read(b)
|
||||
state := base64.URLEncoding.EncodeToString(b)
|
||||
|
||||
// TODO: this does not work if you are load balancing across multiple proxy servers.
|
||||
o.statesMux.Lock()
|
||||
o.states[state] = &oidcState{OriginalURL: fmt.Sprintf("https://%s%s", r.Host, r.URL), CreatedAt: time.Now()}
|
||||
o.statesMux.Unlock()
|
||||
|
||||
return "", false, o.oauthConfig.AuthCodeURL(state)
|
||||
}
|
||||
|
||||
// Middleware returns an http.Handler that handles OIDC callback and flow initiation.
|
||||
func (o *OIDC) Middleware(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Handle OIDC callback
|
||||
if r.URL.Path == "/oauth/callback" {
|
||||
o.handleCallback(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
// validateToken validates a JWT ID token and returns the user ID (subject)
|
||||
func (o *OIDC) validateToken(ctx context.Context, token string) string {
|
||||
if o.verifier == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
idToken, err := o.verifier.Verify(ctx, token)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
return idToken.Subject
|
||||
}
|
||||
|
||||
// handleCallback processes the OIDC callback
|
||||
func (o *OIDC) handleCallback(w http.ResponseWriter, r *http.Request) {
|
||||
code := r.URL.Query().Get("code")
|
||||
state := r.URL.Query().Get("state")
|
||||
|
||||
if code == "" || state == "" {
|
||||
http.Error(w, "Invalid callback parameters", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Verify and consume state
|
||||
o.statesMux.Lock()
|
||||
st, ok := o.states[state]
|
||||
if ok {
|
||||
delete(o.states, state)
|
||||
}
|
||||
o.statesMux.Unlock()
|
||||
|
||||
if !ok {
|
||||
http.Error(w, "Invalid or expired state", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Exchange code for token
|
||||
token, err := o.oauthConfig.Exchange(r.Context(), code)
|
||||
if err != nil {
|
||||
slog.Error("Token exchange failed", "error", err)
|
||||
http.Error(w, "Authentication failed", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
// Prefer ID token if available
|
||||
idToken := token.AccessToken
|
||||
if id, ok := token.Extra("id_token").(string); ok && id != "" {
|
||||
idToken = id
|
||||
}
|
||||
|
||||
// Redirect back to original URL with token
|
||||
origURL, err := url.Parse(st.OriginalURL)
|
||||
if err != nil {
|
||||
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
q := origURL.Query()
|
||||
q.Set("_auth_token", idToken)
|
||||
origURL.RawQuery = q.Encode()
|
||||
|
||||
http.Redirect(w, r, origURL.String(), http.StatusFound)
|
||||
}
|
||||
|
||||
// cleanupStates periodically removes expired states
|
||||
func (o *OIDC) cleanupStates() {
|
||||
for range time.Tick(time.Minute) {
|
||||
cutoff := time.Now().Add(-stateExpiration)
|
||||
o.statesMux.Lock()
|
||||
for k, v := range o.states {
|
||||
if v.CreatedAt.Before(cutoff) {
|
||||
delete(o.states, k)
|
||||
}
|
||||
}
|
||||
o.statesMux.Unlock()
|
||||
}
|
||||
}
|
||||
@@ -1,17 +0,0 @@
|
||||
package oidc
|
||||
|
||||
// Config holds the global OIDC/OAuth configuration
|
||||
type Config struct {
|
||||
ProviderURL string `env:"NB_OIDC_PROVIDER_URL" json:"provider_url"`
|
||||
ClientID string `env:"NB_OIDC_CLIENT_ID" json:"client_id"`
|
||||
ClientSecret string `env:"NB_OIDC_CLIENT_SECRET" json:"client_secret"`
|
||||
RedirectURL string `env:"NB_OIDC_REDIRECT_URL" json:"redirect_url"`
|
||||
Scopes []string `env:"NB_OIDC_SCOPES" json:"scopes"`
|
||||
|
||||
JWTKeysLocation string `env:"NB_OIDC_JWT_KEYS_LOCATION" json:"jwt_keys_location"`
|
||||
JWTIssuer string `env:"NB_OIDC_JWT_ISSUER" json:"jwt_issuer"`
|
||||
JWTAudience []string `env:"NB_OIDC_JWT_AUDIENCE" json:"jwt_audience"`
|
||||
JWTIdpSignkeyRefreshEnabled bool `env:"NB_OIDC_JWT_IDP_SIGNKEY_REFRESH_ENABLED" json:"jwt_idp_signkey_refresh_enabled"`
|
||||
|
||||
SessionCookieName string `env:"NB_OIDC_SESSION_COOKIE_NAME" json:"session_cookie_name"`
|
||||
}
|
||||
@@ -1,285 +0,0 @@
|
||||
package oidc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/shared/auth/jwt"
|
||||
)
|
||||
|
||||
// Handler manages OIDC authentication flow
|
||||
type Handler struct {
|
||||
config *Config
|
||||
stateStore *StateStore
|
||||
jwtValidator *jwt.Validator
|
||||
}
|
||||
|
||||
// NewHandler creates a new OIDC handler
|
||||
func NewHandler(config *Config, stateStore *StateStore) *Handler {
|
||||
var jwtValidator *jwt.Validator
|
||||
if config.JWTKeysLocation != "" {
|
||||
jwtValidator = jwt.NewValidator(
|
||||
config.JWTIssuer,
|
||||
config.JWTAudience,
|
||||
config.JWTKeysLocation,
|
||||
config.JWTIdpSignkeyRefreshEnabled,
|
||||
)
|
||||
}
|
||||
|
||||
return &Handler{
|
||||
config: config,
|
||||
stateStore: stateStore,
|
||||
jwtValidator: jwtValidator,
|
||||
}
|
||||
}
|
||||
|
||||
// RedirectToProvider initiates the OAuth/OIDC authentication flow by redirecting to the provider
|
||||
func (h *Handler) RedirectToProvider(w http.ResponseWriter, r *http.Request, routeID string) {
|
||||
// Generate random state for CSRF protection
|
||||
state, err := generateRandomString(32)
|
||||
if err != nil {
|
||||
log.WithError(err).Error("Failed to generate OIDC state")
|
||||
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Store state with original URL for redirect after auth
|
||||
// Include the full URL with scheme and host
|
||||
scheme := "http"
|
||||
if r.TLS != nil {
|
||||
scheme = "https"
|
||||
}
|
||||
originalURL := fmt.Sprintf("%s://%s%s", scheme, r.Host, r.URL.String())
|
||||
h.stateStore.Store(state, originalURL, routeID)
|
||||
|
||||
// Default scopes if not configured
|
||||
scopes := h.config.Scopes
|
||||
if len(scopes) == 0 {
|
||||
scopes = []string{"openid", "profile", "email"}
|
||||
}
|
||||
|
||||
authURL, err := url.Parse(h.config.ProviderURL)
|
||||
if err != nil {
|
||||
log.WithError(err).Error("Invalid OIDC provider URL")
|
||||
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Append /authorize if it doesn't exist (common OIDC endpoint)
|
||||
if !strings.HasSuffix(authURL.Path, "/authorize") && !strings.HasSuffix(authURL.Path, "/auth") {
|
||||
authURL.Path = strings.TrimSuffix(authURL.Path, "/") + "/authorize"
|
||||
}
|
||||
|
||||
params := url.Values{}
|
||||
params.Set("client_id", h.config.ClientID)
|
||||
params.Set("redirect_uri", h.config.RedirectURL)
|
||||
params.Set("response_type", "code")
|
||||
params.Set("scope", strings.Join(scopes, " "))
|
||||
params.Set("state", state)
|
||||
|
||||
if len(h.config.JWTAudience) > 0 && h.config.JWTAudience[0] != h.config.ClientID {
|
||||
params.Set("audience", h.config.JWTAudience[0])
|
||||
}
|
||||
|
||||
authURL.RawQuery = params.Encode()
|
||||
|
||||
log.WithFields(log.Fields{
|
||||
"route_id": routeID,
|
||||
"provider_url": authURL.String(),
|
||||
"redirect_url": h.config.RedirectURL,
|
||||
"state": state,
|
||||
}).Info("Redirecting to OIDC provider for authentication")
|
||||
|
||||
http.Redirect(w, r, authURL.String(), http.StatusFound)
|
||||
}
|
||||
|
||||
// HandleCallback creates an HTTP handler for the OIDC callback endpoint
|
||||
func (h *Handler) HandleCallback() http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
// Get authorization code and state from query parameters
|
||||
code := r.URL.Query().Get("code")
|
||||
state := r.URL.Query().Get("state")
|
||||
|
||||
if code == "" || state == "" {
|
||||
log.Error("Missing code or state in OIDC callback")
|
||||
http.Error(w, "Invalid callback parameters", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Verify state to prevent CSRF
|
||||
oidcSt, ok := h.stateStore.Get(state)
|
||||
if !ok {
|
||||
log.Error("Invalid or expired OIDC state")
|
||||
http.Error(w, "Invalid or expired state parameter", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Delete state to prevent reuse
|
||||
h.stateStore.Delete(state)
|
||||
|
||||
// Exchange authorization code for token
|
||||
token, err := h.exchangeCodeForToken(code)
|
||||
if err != nil {
|
||||
log.WithError(err).Error("Failed to exchange code for token")
|
||||
http.Error(w, "Authentication failed", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
// Parse the original URL to add the token as a query parameter
|
||||
origURL, err := url.Parse(oidcSt.OriginalURL)
|
||||
if err != nil {
|
||||
log.WithError(err).Error("Failed to parse original URL")
|
||||
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Add token as query parameter so the original domain can set its own cookie
|
||||
// We use a special parameter name that the auth middleware will look for
|
||||
q := origURL.Query()
|
||||
q.Set("_auth_token", token)
|
||||
origURL.RawQuery = q.Encode()
|
||||
|
||||
log.WithFields(log.Fields{
|
||||
"route_id": oidcSt.RouteID,
|
||||
"original_url": oidcSt.OriginalURL,
|
||||
"redirect_url": origURL.String(),
|
||||
"callback_host": r.Host,
|
||||
}).Info("OIDC authentication successful, redirecting with token parameter")
|
||||
|
||||
// Redirect back to original URL with token parameter
|
||||
http.Redirect(w, r, origURL.String(), http.StatusFound)
|
||||
}
|
||||
}
|
||||
|
||||
// exchangeCodeForToken exchanges an authorization code for an access token
|
||||
func (h *Handler) exchangeCodeForToken(code string) (string, error) {
|
||||
// Build token endpoint URL
|
||||
tokenURL, err := url.Parse(h.config.ProviderURL)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("invalid OIDC provider URL: %w", err)
|
||||
}
|
||||
|
||||
// Auth0 uses /oauth/token, standard OIDC uses /token
|
||||
// Check if path already contains token endpoint
|
||||
if !strings.Contains(tokenURL.Path, "/token") {
|
||||
tokenURL.Path = strings.TrimSuffix(tokenURL.Path, "/") + "/oauth/token"
|
||||
}
|
||||
|
||||
// Build request body
|
||||
data := url.Values{}
|
||||
data.Set("grant_type", "authorization_code")
|
||||
data.Set("code", code)
|
||||
data.Set("redirect_uri", h.config.RedirectURL)
|
||||
data.Set("client_id", h.config.ClientID)
|
||||
|
||||
// Only include client_secret if it's provided (not needed for public/SPA clients)
|
||||
if h.config.ClientSecret != "" {
|
||||
data.Set("client_secret", h.config.ClientSecret)
|
||||
}
|
||||
|
||||
// Make token exchange request
|
||||
resp, err := http.PostForm(tokenURL.String(), data)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("token exchange request failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
return "", fmt.Errorf("token exchange failed with status %d: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
// Parse response
|
||||
var tokenResp struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
TokenType string `json:"token_type"`
|
||||
ExpiresIn int `json:"expires_in"`
|
||||
IDToken string `json:"id_token"`
|
||||
}
|
||||
|
||||
if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil {
|
||||
return "", fmt.Errorf("failed to decode token response: %w", err)
|
||||
}
|
||||
|
||||
if tokenResp.AccessToken == "" {
|
||||
return "", fmt.Errorf("no access token in response")
|
||||
}
|
||||
|
||||
// Return the ID token if available (contains user claims), otherwise access token
|
||||
if tokenResp.IDToken != "" {
|
||||
return tokenResp.IDToken, nil
|
||||
}
|
||||
|
||||
return tokenResp.AccessToken, nil
|
||||
}
|
||||
|
||||
// ValidateJWT validates a JWT token
|
||||
func (h *Handler) ValidateJWT(tokenString string) bool {
|
||||
if h.jwtValidator == nil {
|
||||
log.Error("JWT validation failed: JWT validator not initialized")
|
||||
return false
|
||||
}
|
||||
|
||||
// Validate the token
|
||||
ctx := context.Background()
|
||||
parsedToken, err := h.jwtValidator.ValidateAndParse(ctx, tokenString)
|
||||
if err != nil {
|
||||
log.WithError(err).Error("JWT validation failed")
|
||||
// Try to parse token without validation to see what's in it
|
||||
parts := strings.Split(tokenString, ".")
|
||||
if len(parts) == 3 {
|
||||
// Decode payload (middle part)
|
||||
payload, err := base64.RawURLEncoding.DecodeString(parts[1])
|
||||
if err == nil {
|
||||
log.WithFields(log.Fields{
|
||||
"payload": string(payload),
|
||||
}).Debug("Token payload for debugging")
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// Token is valid if parsedToken is not nil and Valid is true
|
||||
return parsedToken != nil && parsedToken.Valid
|
||||
}
|
||||
|
||||
// ExtractUserID extracts the user ID from a JWT token
|
||||
func (h *Handler) ExtractUserID(tokenString string) string {
|
||||
if h.jwtValidator == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Parse the token
|
||||
ctx := context.Background()
|
||||
parsedToken, err := h.jwtValidator.ValidateAndParse(ctx, tokenString)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
// parsedToken is already *jwtgo.Token from ValidateAndParse
|
||||
// Create extractor to get user auth info
|
||||
extractor := jwt.NewClaimsExtractor()
|
||||
userAuth, err := extractor.ToUserAuth(parsedToken)
|
||||
if err != nil {
|
||||
log.WithError(err).Debug("Failed to extract user ID from JWT")
|
||||
return ""
|
||||
}
|
||||
|
||||
return userAuth.UserId
|
||||
}
|
||||
|
||||
// SessionCookieName returns the configured session cookie name or default
|
||||
func (h *Handler) SessionCookieName() string {
|
||||
if h.config.SessionCookieName != "" {
|
||||
return h.config.SessionCookieName
|
||||
}
|
||||
return "auth_session"
|
||||
}
|
||||
@@ -1,10 +0,0 @@
|
||||
package oidc
|
||||
|
||||
import "time"
|
||||
|
||||
// State represents stored OIDC state information for CSRF protection
|
||||
type State struct {
|
||||
OriginalURL string
|
||||
CreatedAt time.Time
|
||||
RouteID string
|
||||
}
|
||||
@@ -1,65 +0,0 @@
|
||||
package oidc
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
// StateExpiration is how long OIDC state tokens are valid
|
||||
StateExpiration = 10 * time.Minute
|
||||
)
|
||||
|
||||
// StateStore manages OIDC state tokens for CSRF protection
|
||||
type StateStore struct {
|
||||
mu sync.RWMutex
|
||||
states map[string]*State
|
||||
}
|
||||
|
||||
// NewStateStore creates a new OIDC state store
|
||||
func NewStateStore() *StateStore {
|
||||
return &StateStore{
|
||||
states: make(map[string]*State),
|
||||
}
|
||||
}
|
||||
|
||||
// Store saves a state token with associated metadata
|
||||
func (s *StateStore) Store(stateToken, originalURL, routeID string) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
s.states[stateToken] = &State{
|
||||
OriginalURL: originalURL,
|
||||
CreatedAt: time.Now(),
|
||||
RouteID: routeID,
|
||||
}
|
||||
|
||||
s.cleanup()
|
||||
}
|
||||
|
||||
// Get retrieves a state by token
|
||||
func (s *StateStore) Get(stateToken string) (*State, bool) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
st, ok := s.states[stateToken]
|
||||
return st, ok
|
||||
}
|
||||
|
||||
// Delete removes a state token
|
||||
func (s *StateStore) Delete(stateToken string) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
delete(s.states, stateToken)
|
||||
}
|
||||
|
||||
// cleanup removes expired state tokens (must be called with lock held)
|
||||
func (s *StateStore) cleanup() {
|
||||
cutoff := time.Now().Add(-StateExpiration)
|
||||
for k, v := range s.states {
|
||||
if v.CreatedAt.Before(cutoff) {
|
||||
delete(s.states, k)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,15 +0,0 @@
|
||||
package oidc
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
)
|
||||
|
||||
// generateRandomString generates a cryptographically secure random string of the specified length
|
||||
func generateRandomString(length int) (string, error) {
|
||||
bytes := make([]byte, length)
|
||||
if _, err := rand.Read(bytes); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return base64.URLEncoding.EncodeToString(bytes)[:length], nil
|
||||
}
|
||||
45
proxy/internal/auth/pin.go
Normal file
45
proxy/internal/auth/pin.go
Normal file
@@ -0,0 +1,45 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"crypto/subtle"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
const (
|
||||
userId = "pin-user"
|
||||
formId = "pin"
|
||||
)
|
||||
|
||||
type Pin struct {
|
||||
pin string
|
||||
}
|
||||
|
||||
func NewPin(pin string) Pin {
|
||||
return Pin{
|
||||
pin: pin,
|
||||
}
|
||||
}
|
||||
|
||||
func (Pin) Type() Method {
|
||||
return MethodPIN
|
||||
}
|
||||
|
||||
// Authenticate attempts to authenticate the request using a form
|
||||
// value passed in the request.
|
||||
// If authentication fails, the required HTTP form ID is returned
|
||||
// so that it can be injected into a request from the UI so that
|
||||
// authentication may be successful.
|
||||
func (p Pin) Authenticate(r *http.Request) (string, bool, any) {
|
||||
pin := r.FormValue(formId)
|
||||
|
||||
// Compare the passed pin with the expected pin.
|
||||
if subtle.ConstantTimeCompare([]byte(pin), []byte(p.pin)) == 1 {
|
||||
return userId, false, nil
|
||||
}
|
||||
|
||||
return "", false, formId
|
||||
}
|
||||
|
||||
func (p Pin) Middleware(next http.Handler) http.Handler {
|
||||
return next
|
||||
}
|
||||
24
proxy/internal/proxy/context.go
Normal file
24
proxy/internal/proxy/context.go
Normal file
@@ -0,0 +1,24 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
)
|
||||
|
||||
type requestContextKey string
|
||||
|
||||
const (
|
||||
serviceIdKey requestContextKey = "serviceId"
|
||||
)
|
||||
|
||||
func withServiceId(ctx context.Context, serviceId string) context.Context {
|
||||
return context.WithValue(ctx, serviceIdKey, serviceId)
|
||||
}
|
||||
|
||||
func ServiceIdFromContext(ctx context.Context) string {
|
||||
v := ctx.Value(serviceIdKey)
|
||||
serviceId, ok := v.(string)
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
return serviceId
|
||||
}
|
||||
44
proxy/internal/proxy/reverseproxy.go
Normal file
44
proxy/internal/proxy/reverseproxy.go
Normal file
@@ -0,0 +1,44 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
"sync"
|
||||
)
|
||||
|
||||
type ReverseProxy struct {
|
||||
transport http.RoundTripper
|
||||
mappingsMux sync.RWMutex
|
||||
mappings map[string]Mapping
|
||||
}
|
||||
|
||||
// NewReverseProxy configures a new NetBird ReverseProxy.
|
||||
// This is a wrapper around an httputil.ReverseProxy set
|
||||
// to dynamically route requests based on internal mapping
|
||||
// between requested URLs and targets.
|
||||
// The internal mappings can be modified using the AddMapping
|
||||
// and RemoveMapping functions.
|
||||
func NewReverseProxy(transport http.RoundTripper) *ReverseProxy {
|
||||
return &ReverseProxy{
|
||||
transport: transport,
|
||||
mappings: make(map[string]Mapping),
|
||||
}
|
||||
}
|
||||
|
||||
func (p *ReverseProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
target, serviceId, exists := p.findTargetForRequest(r)
|
||||
if !exists {
|
||||
// No mapping found so return an error here.
|
||||
// TODO: prettier error page.
|
||||
http.Error(w, http.StatusText(http.StatusNotFound), http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
// Set the serviceId in the context for later retrieval.
|
||||
ctx := withServiceId(r.Context(), serviceId)
|
||||
|
||||
// Set up a reverse proxy using the transport and then use it to serve the request.
|
||||
proxy := httputil.NewSingleHostReverseProxy(target)
|
||||
proxy.Transport = p.transport
|
||||
proxy.ServeHTTP(w, r.WithContext(ctx))
|
||||
}
|
||||
62
proxy/internal/proxy/servicemapping.go
Normal file
62
proxy/internal/proxy/servicemapping.go
Normal file
@@ -0,0 +1,62 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/url"
|
||||
"sort"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type Mapping struct {
|
||||
ID string
|
||||
Host string
|
||||
Paths map[string]*url.URL
|
||||
}
|
||||
|
||||
func (p *ReverseProxy) findTargetForRequest(req *http.Request) (*url.URL, string, bool) {
|
||||
p.mappingsMux.RLock()
|
||||
if p.mappings == nil {
|
||||
p.mappingsMux.RUnlock()
|
||||
p.mappingsMux.Lock()
|
||||
defer p.mappingsMux.Unlock()
|
||||
p.mappings = make(map[string]Mapping)
|
||||
// There cannot be any loaded Mappings as we have only just initialized.
|
||||
return nil, "", false
|
||||
}
|
||||
defer p.mappingsMux.RUnlock()
|
||||
m, exists := p.mappings[req.Host]
|
||||
if !exists {
|
||||
return nil, "", false
|
||||
}
|
||||
|
||||
// Sort paths by length (longest first) in a naive attempt to match the most specific route first.
|
||||
paths := make([]string, 0, len(m.Paths))
|
||||
for path := range m.Paths {
|
||||
paths = append(paths, path)
|
||||
}
|
||||
sort.Slice(paths, func(i, j int) bool {
|
||||
return len(paths[i]) > len(paths[j])
|
||||
})
|
||||
|
||||
for _, path := range paths {
|
||||
if strings.HasPrefix(req.URL.Path, path) {
|
||||
return m.Paths[path], m.ID, true
|
||||
}
|
||||
}
|
||||
return nil, "", false
|
||||
}
|
||||
|
||||
func (p *ReverseProxy) AddMapping(m Mapping) {
|
||||
p.mappingsMux.Lock()
|
||||
defer p.mappingsMux.Unlock()
|
||||
if p.mappings == nil {
|
||||
p.mappings = make(map[string]Mapping)
|
||||
}
|
||||
p.mappings[m.Host] = m
|
||||
}
|
||||
|
||||
func (p *ReverseProxy) RemoveMapping(m Mapping) {
|
||||
p.mappingsMux.Lock()
|
||||
defer p.mappingsMux.Unlock()
|
||||
delete(p.mappings, m.Host)
|
||||
}
|
||||
@@ -1,28 +0,0 @@
|
||||
package certmanager
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
// Manager defines the interface for certificate management
|
||||
type Manager interface {
|
||||
// IsEnabled returns whether certificate management is enabled
|
||||
IsEnabled() bool
|
||||
|
||||
// AddDomain adds a domain to the allowed hosts list
|
||||
AddDomain(domain string)
|
||||
|
||||
// RemoveDomain removes a domain from the allowed hosts list
|
||||
RemoveDomain(domain string)
|
||||
|
||||
// IssueCertificate eagerly issues a certificate for a domain
|
||||
IssueCertificate(ctx context.Context, domain string) error
|
||||
|
||||
// TLSConfig returns the TLS configuration for the HTTPS server
|
||||
TLSConfig() *tls.Config
|
||||
|
||||
// HTTPHandler returns the HTTP handler for ACME challenges (or fallback)
|
||||
HTTPHandler(fallback http.Handler) http.Handler
|
||||
}
|
||||
@@ -1,111 +0,0 @@
|
||||
package certmanager
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/crypto/acme/autocert"
|
||||
)
|
||||
|
||||
// LetsEncryptManager handles TLS certificate issuance via Let's Encrypt
|
||||
type LetsEncryptManager struct {
|
||||
autocertManager *autocert.Manager
|
||||
allowedHosts map[string]bool
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// LetsEncryptConfig holds Let's Encrypt certificate manager configuration
|
||||
type LetsEncryptConfig struct {
|
||||
// Email for Let's Encrypt registration (required)
|
||||
Email string
|
||||
|
||||
// CertCacheDir is the directory to cache certificates
|
||||
CertCacheDir string
|
||||
}
|
||||
|
||||
// NewLetsEncrypt creates a new Let's Encrypt certificate manager
|
||||
func NewLetsEncrypt(config LetsEncryptConfig) *LetsEncryptManager {
|
||||
m := &LetsEncryptManager{
|
||||
allowedHosts: make(map[string]bool),
|
||||
}
|
||||
|
||||
m.autocertManager = &autocert.Manager{
|
||||
Prompt: autocert.AcceptTOS,
|
||||
HostPolicy: m.hostPolicy,
|
||||
Cache: autocert.DirCache(config.CertCacheDir),
|
||||
Email: config.Email,
|
||||
RenewBefore: 0, // Use default 30 days prior to expiration
|
||||
}
|
||||
|
||||
log.Info("Let's Encrypt certificate manager initialized")
|
||||
return m
|
||||
}
|
||||
|
||||
// IsEnabled returns whether certificate management is enabled
|
||||
func (m *LetsEncryptManager) IsEnabled() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// AddDomain adds a domain to the allowed hosts list
|
||||
func (m *LetsEncryptManager) AddDomain(domain string) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.allowedHosts[domain] = true
|
||||
log.Infof("Added domain to Let's Encrypt manager: %s", domain)
|
||||
}
|
||||
|
||||
// RemoveDomain removes a domain from the allowed hosts list
|
||||
func (m *LetsEncryptManager) RemoveDomain(domain string) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
delete(m.allowedHosts, domain)
|
||||
log.Infof("Removed domain from Let's Encrypt manager: %s", domain)
|
||||
}
|
||||
|
||||
// IssueCertificate eagerly issues a Let's Encrypt certificate for a domain
|
||||
func (m *LetsEncryptManager) IssueCertificate(ctx context.Context, domain string) error {
|
||||
log.Infof("Issuing Let's Encrypt certificate for domain: %s", domain)
|
||||
|
||||
hello := &tls.ClientHelloInfo{
|
||||
ServerName: domain,
|
||||
}
|
||||
|
||||
cert, err := m.autocertManager.GetCertificate(hello)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to issue certificate for domain %s: %w", domain, err)
|
||||
}
|
||||
|
||||
log.Infof("Successfully issued Let's Encrypt certificate for domain: %s (expires: %s)",
|
||||
domain, cert.Leaf.NotAfter.Format(time.RFC3339))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// TLSConfig returns the TLS configuration for the HTTPS server
|
||||
func (m *LetsEncryptManager) TLSConfig() *tls.Config {
|
||||
return m.autocertManager.TLSConfig()
|
||||
}
|
||||
|
||||
// HTTPHandler returns the HTTP handler for ACME challenges
|
||||
func (m *LetsEncryptManager) HTTPHandler(fallback http.Handler) http.Handler {
|
||||
return m.autocertManager.HTTPHandler(fallback)
|
||||
}
|
||||
|
||||
// hostPolicy validates that a requested host is in the allowed hosts list
|
||||
func (m *LetsEncryptManager) hostPolicy(ctx context.Context, host string) error {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
if m.allowedHosts[host] {
|
||||
log.Debugf("ACME challenge accepted for domain: %s", host)
|
||||
return nil
|
||||
}
|
||||
|
||||
log.Warnf("ACME challenge rejected for unconfigured domain: %s", host)
|
||||
return fmt.Errorf("host %s not configured", host)
|
||||
}
|
||||
@@ -1,157 +0,0 @@
|
||||
package certmanager
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
"crypto/rand"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// SelfSignedManager handles self-signed certificate generation for local testing
|
||||
type SelfSignedManager struct {
|
||||
certificates map[string]*tls.Certificate // domain -> certificate cache
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// NewSelfSigned creates a new self-signed certificate manager
|
||||
func NewSelfSigned() *SelfSignedManager {
|
||||
log.Info("Self-signed certificate manager initialized")
|
||||
return &SelfSignedManager{
|
||||
certificates: make(map[string]*tls.Certificate),
|
||||
}
|
||||
}
|
||||
|
||||
// IsEnabled returns whether certificate management is enabled
|
||||
func (m *SelfSignedManager) IsEnabled() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// AddDomain adds a domain to the manager (no-op for self-signed, but maintains interface)
|
||||
func (m *SelfSignedManager) AddDomain(domain string) {
|
||||
log.Infof("Added domain to self-signed manager: %s", domain)
|
||||
}
|
||||
|
||||
// RemoveDomain removes a domain from the manager
|
||||
func (m *SelfSignedManager) RemoveDomain(domain string) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
delete(m.certificates, domain)
|
||||
log.Infof("Removed domain from self-signed manager: %s", domain)
|
||||
}
|
||||
|
||||
// IssueCertificate generates and caches a self-signed certificate for a domain
|
||||
func (m *SelfSignedManager) IssueCertificate(ctx context.Context, domain string) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
if _, exists := m.certificates[domain]; exists {
|
||||
log.Debugf("Self-signed certificate already exists for domain: %s", domain)
|
||||
return nil
|
||||
}
|
||||
|
||||
cert, err := m.generateCertificate(domain)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
m.certificates[domain] = cert
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// TLSConfig returns the TLS configuration for the HTTPS server
|
||||
func (m *SelfSignedManager) TLSConfig() *tls.Config {
|
||||
return &tls.Config{
|
||||
GetCertificate: m.getCertificate,
|
||||
}
|
||||
}
|
||||
|
||||
// HTTPHandler returns the fallback handler (no ACME challenges for self-signed)
|
||||
func (m *SelfSignedManager) HTTPHandler(fallback http.Handler) http.Handler {
|
||||
return fallback
|
||||
}
|
||||
|
||||
// getCertificate returns a self-signed certificate for the requested domain
|
||||
func (m *SelfSignedManager) getCertificate(hello *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
||||
m.mu.RLock()
|
||||
cert, exists := m.certificates[hello.ServerName]
|
||||
m.mu.RUnlock()
|
||||
|
||||
if exists {
|
||||
return cert, nil
|
||||
}
|
||||
|
||||
log.Infof("Generating self-signed certificate on-demand for: %s", hello.ServerName)
|
||||
|
||||
newCert, err := m.generateCertificate(hello.ServerName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
m.mu.Lock()
|
||||
m.certificates[hello.ServerName] = newCert
|
||||
m.mu.Unlock()
|
||||
|
||||
return newCert, nil
|
||||
}
|
||||
|
||||
// generateCertificate generates a self-signed certificate for a domain
|
||||
func (m *SelfSignedManager) generateCertificate(domain string) (*tls.Certificate, error) {
|
||||
priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate private key: %w", err)
|
||||
}
|
||||
|
||||
notBefore := time.Now()
|
||||
notAfter := notBefore.Add(365 * 24 * time.Hour) // Valid for 1 year
|
||||
|
||||
serialNumber, err := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate serial number: %w", err)
|
||||
}
|
||||
|
||||
template := x509.Certificate{
|
||||
SerialNumber: serialNumber,
|
||||
Subject: pkix.Name{
|
||||
Organization: []string{"NetBird Local Development"},
|
||||
CommonName: domain,
|
||||
},
|
||||
NotBefore: notBefore,
|
||||
NotAfter: notAfter,
|
||||
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
|
||||
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
|
||||
BasicConstraintsValid: true,
|
||||
DNSNames: []string{domain},
|
||||
}
|
||||
|
||||
certDER, err := x509.CreateCertificate(rand.Reader, &template, &template, &priv.PublicKey, priv)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create certificate: %w", err)
|
||||
}
|
||||
|
||||
cert, err := x509.ParseCertificate(certDER)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse certificate: %w", err)
|
||||
}
|
||||
|
||||
tlsCert := &tls.Certificate{
|
||||
Certificate: [][]byte{certDER},
|
||||
PrivateKey: priv,
|
||||
Leaf: cert,
|
||||
}
|
||||
|
||||
log.Infof("Generated self-signed certificate for domain: %s (expires: %s)",
|
||||
domain, cert.NotAfter.Format(time.RFC3339))
|
||||
|
||||
return tlsCert, nil
|
||||
}
|
||||
@@ -1,103 +0,0 @@
|
||||
package reverseproxy
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
|
||||
"github.com/netbirdio/netbird/client/embed"
|
||||
"github.com/netbirdio/netbird/proxy/internal/auth"
|
||||
"github.com/netbirdio/netbird/proxy/internal/auth/oidc"
|
||||
)
|
||||
|
||||
// Config holds the reverse proxy configuration
|
||||
type Config struct {
|
||||
// ListenAddress is the address to listen on for HTTPS (default ":443")
|
||||
ListenAddress string `env:"NB_REVERSE_PROXY_LISTEN_ADDRESS" envDefault:":443" json:"listen_address"`
|
||||
|
||||
// ManagementURL is the URL of the management server
|
||||
ManagementURL string `env:"NB_REVERSE_PROXY_MANAGEMENT_URL" json:"management_url"`
|
||||
|
||||
// HTTPListenAddress is the address for HTTP (default ":80")
|
||||
// Used for ACME challenges (Let's Encrypt HTTP-01 challenge)
|
||||
HTTPListenAddress string `env:"NB_REVERSE_PROXY_HTTP_LISTEN_ADDRESS" envDefault:":80" json:"http_listen_address"`
|
||||
|
||||
// CertMode specifies certificate mode: "letsencrypt" or "selfsigned" (default: "letsencrypt")
|
||||
// "letsencrypt" - Uses Let's Encrypt for production certificates (requires public domain)
|
||||
// "selfsigned" - Generates self-signed certificates for local testing
|
||||
CertMode string `env:"NB_REVERSE_PROXY_CERT_MODE" envDefault:"letsencrypt" json:"cert_mode"`
|
||||
|
||||
// TLSEmail is the email for Let's Encrypt registration (required for letsencrypt mode)
|
||||
TLSEmail string `env:"NB_REVERSE_PROXY_TLS_EMAIL" json:"tls_email"`
|
||||
|
||||
// CertCacheDir is the directory to cache certificates (for letsencrypt mode, default "./certs")
|
||||
CertCacheDir string `env:"NB_REVERSE_PROXY_CERT_CACHE_DIR" envDefault:"./certs" json:"cert_cache_dir"`
|
||||
|
||||
// OIDCConfig is the global OIDC/OAuth configuration for authentication
|
||||
// This is shared across all routes that use Bearer authentication
|
||||
// If nil, routes with Bearer auth will fail to initialize
|
||||
OIDCConfig *oidc.Config `json:"oidc_config"`
|
||||
}
|
||||
|
||||
// RouteConfig defines a routing configuration
|
||||
type RouteConfig struct {
|
||||
// ID is a unique identifier for this route
|
||||
ID string
|
||||
|
||||
// Domain is the domain to listen on (e.g., "example.com" or "*" for all)
|
||||
Domain string
|
||||
|
||||
// PathMappings defines paths that should be forwarded to specific ports
|
||||
// Key is the path prefix (e.g., "/", "/api", "/admin")
|
||||
// Value is the target IP:port (e.g., "192.168.1.100:3000")
|
||||
// Must have at least one entry. Use "/" or "" for the default/catch-all route.
|
||||
PathMappings map[string]string
|
||||
|
||||
SetupKey string
|
||||
nbClient *embed.Client
|
||||
|
||||
// AuthConfig is optional authentication configuration for this route
|
||||
// Configure ONE of: BasicAuth, PIN, or Bearer (JWT/OIDC)
|
||||
// If nil, requests pass through without authentication
|
||||
AuthConfig *auth.Config
|
||||
|
||||
// AuthRejectResponse is an optional custom response for authentication failures
|
||||
// If nil, returns 401 Unauthorized with WWW-Authenticate header
|
||||
AuthRejectResponse func(w http.ResponseWriter, r *http.Request)
|
||||
}
|
||||
|
||||
// routeEntry represents a compiled route with its proxy
|
||||
type routeEntry struct {
|
||||
routeConfig *RouteConfig
|
||||
path string
|
||||
target string
|
||||
proxy *httputil.ReverseProxy
|
||||
handler http.Handler // handler wraps proxy with middleware (auth, logging, etc.)
|
||||
}
|
||||
|
||||
// RequestDataCallback is called for each proxied request with metrics
|
||||
type RequestDataCallback func(data RequestData)
|
||||
|
||||
// RequestData contains metrics for a proxied request
|
||||
type RequestData struct {
|
||||
ServiceID string
|
||||
Host string
|
||||
Path string
|
||||
DurationMs int64
|
||||
Method string
|
||||
ResponseCode int32
|
||||
SourceIP string
|
||||
AuthMechanism string
|
||||
UserID string
|
||||
AuthSuccess bool
|
||||
}
|
||||
|
||||
// responseWriter wraps http.ResponseWriter to capture status code
|
||||
type responseWriter struct {
|
||||
http.ResponseWriter
|
||||
statusCode int
|
||||
}
|
||||
|
||||
func (rw *responseWriter) WriteHeader(code int) {
|
||||
rw.statusCode = code
|
||||
rw.ResponseWriter.WriteHeader(code)
|
||||
}
|
||||
@@ -1,224 +0,0 @@
|
||||
package reverseproxy
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
"net/url"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/proxy/internal/auth"
|
||||
)
|
||||
|
||||
// buildHandler creates the main HTTP handler with router for static endpoints
|
||||
func (p *Proxy) buildHandler() http.Handler {
|
||||
router := mux.NewRouter()
|
||||
|
||||
// Register static endpoints
|
||||
router.HandleFunc("/auth/callback", p.handleOIDCCallback).Methods("GET")
|
||||
|
||||
// Catch-all handler for dynamic proxy routing
|
||||
router.PathPrefix("/").HandlerFunc(p.handleProxyRequest)
|
||||
|
||||
return router
|
||||
}
|
||||
|
||||
// handleProxyRequest handles all dynamic proxy requests
|
||||
func (p *Proxy) handleProxyRequest(w http.ResponseWriter, r *http.Request) {
|
||||
startTime := time.Now()
|
||||
|
||||
routeEntry := p.findRoute(r.Host, r.URL.Path)
|
||||
if routeEntry == nil {
|
||||
log.Warnf("No route found for host=%s path=%s", r.Host, r.URL.Path)
|
||||
http.NotFound(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
rw := &responseWriter{
|
||||
ResponseWriter: w,
|
||||
statusCode: http.StatusOK,
|
||||
}
|
||||
|
||||
routeEntry.handler.ServeHTTP(rw, r)
|
||||
|
||||
if p.requestCallback != nil {
|
||||
duration := time.Since(startTime)
|
||||
|
||||
host := r.Host
|
||||
if idx := strings.LastIndex(host, ":"); idx != -1 {
|
||||
host = host[:idx]
|
||||
}
|
||||
|
||||
// TODO: extract logging data
|
||||
authMechanism := r.Header.Get("X-Auth-Method")
|
||||
if authMechanism == "" {
|
||||
authMechanism = "none"
|
||||
}
|
||||
userID := r.Header.Get("X-Auth-User-ID")
|
||||
authSuccess := rw.statusCode != http.StatusUnauthorized && rw.statusCode != http.StatusForbidden
|
||||
sourceIP := extractSourceIP(r)
|
||||
|
||||
data := RequestData{
|
||||
ServiceID: routeEntry.routeConfig.ID,
|
||||
Host: host,
|
||||
Path: r.URL.Path,
|
||||
DurationMs: duration.Milliseconds(),
|
||||
Method: r.Method,
|
||||
ResponseCode: int32(rw.statusCode),
|
||||
SourceIP: sourceIP,
|
||||
AuthMechanism: authMechanism,
|
||||
UserID: userID,
|
||||
AuthSuccess: authSuccess,
|
||||
}
|
||||
|
||||
p.requestCallback(data)
|
||||
}
|
||||
}
|
||||
|
||||
// findRoute finds the matching route for a given host and path
|
||||
func (p *Proxy) findRoute(host, path string) *routeEntry {
|
||||
p.mu.RLock()
|
||||
defer p.mu.RUnlock()
|
||||
|
||||
if idx := strings.LastIndex(host, ":"); idx != -1 {
|
||||
host = host[:idx]
|
||||
}
|
||||
|
||||
routeConfig, exists := p.routes[host]
|
||||
if !exists {
|
||||
return nil
|
||||
}
|
||||
|
||||
var entries []*routeEntry
|
||||
|
||||
for routePath, target := range routeConfig.PathMappings {
|
||||
proxy := p.createProxy(routeConfig, target)
|
||||
|
||||
handler := auth.Wrap(proxy, routeConfig.AuthConfig, routeConfig.ID, routeConfig.AuthRejectResponse, p.oidcHandler)
|
||||
|
||||
if routeConfig.AuthConfig != nil && !routeConfig.AuthConfig.IsEmpty() {
|
||||
var authType string
|
||||
if routeConfig.AuthConfig.BasicAuth != nil {
|
||||
authType = "basic_auth"
|
||||
} else if routeConfig.AuthConfig.PIN != nil {
|
||||
authType = "pin"
|
||||
} else if routeConfig.AuthConfig.Bearer != nil {
|
||||
authType = "bearer_jwt"
|
||||
}
|
||||
log.WithFields(log.Fields{
|
||||
"route_id": routeConfig.ID,
|
||||
"auth_type": authType,
|
||||
}).Debug("Auth middleware enabled for route")
|
||||
} else {
|
||||
log.WithFields(log.Fields{
|
||||
"route_id": routeConfig.ID,
|
||||
}).Debug("No authentication configured for route")
|
||||
}
|
||||
|
||||
entries = append(entries, &routeEntry{
|
||||
routeConfig: routeConfig,
|
||||
path: routePath,
|
||||
target: target,
|
||||
proxy: proxy,
|
||||
handler: handler,
|
||||
})
|
||||
}
|
||||
|
||||
// Sort by path specificity (longest first)
|
||||
sort.Slice(entries, func(i, j int) bool {
|
||||
pi, pj := entries[i].path, entries[j].path
|
||||
// Empty string or "/" goes last (catch-all)
|
||||
if pi == "" || pi == "/" {
|
||||
return false
|
||||
}
|
||||
if pj == "" || pj == "/" {
|
||||
return true
|
||||
}
|
||||
return len(pi) > len(pj)
|
||||
})
|
||||
|
||||
// Find first matching entry
|
||||
for _, entry := range entries {
|
||||
if entry.path == "" || entry.path == "/" {
|
||||
// Catch-all route
|
||||
return entry
|
||||
}
|
||||
if strings.HasPrefix(path, entry.path) {
|
||||
return entry
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// createProxy creates a reverse proxy for a target with the route's connection
|
||||
func (p *Proxy) createProxy(routeConfig *RouteConfig, target string) *httputil.ReverseProxy {
|
||||
targetURL, err := url.Parse("http://" + target)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to parse target URL %s: %v", target, err)
|
||||
return &httputil.ReverseProxy{
|
||||
Director: func(req *http.Request) {},
|
||||
ErrorHandler: func(w http.ResponseWriter, r *http.Request, err error) {
|
||||
http.Error(w, "Bad Gateway", http.StatusBadGateway)
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
proxy := httputil.NewSingleHostReverseProxy(targetURL)
|
||||
|
||||
proxy.Transport = &http.Transport{
|
||||
DialContext: routeConfig.nbClient.DialContext,
|
||||
MaxIdleConns: 1,
|
||||
MaxIdleConnsPerHost: 1,
|
||||
IdleConnTimeout: 0,
|
||||
DisableKeepAlives: false,
|
||||
TLSHandshakeTimeout: 10 * time.Second,
|
||||
ExpectContinueTimeout: 1 * time.Second,
|
||||
}
|
||||
|
||||
proxy.ErrorHandler = func(w http.ResponseWriter, r *http.Request, err error) {
|
||||
log.Errorf("Proxy error for %s%s: %v", r.Host, r.URL.Path, err)
|
||||
http.Error(w, "Bad Gateway", http.StatusBadGateway)
|
||||
}
|
||||
|
||||
return proxy
|
||||
}
|
||||
|
||||
// handleOIDCCallback handles the global /auth/callback endpoint for all routes
|
||||
func (p *Proxy) handleOIDCCallback(w http.ResponseWriter, r *http.Request) {
|
||||
if p.oidcHandler == nil {
|
||||
log.Error("OIDC callback received but no OIDC handler configured")
|
||||
http.Error(w, "Authentication not configured", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
handler := p.oidcHandler.HandleCallback()
|
||||
handler(w, r)
|
||||
}
|
||||
|
||||
// extractSourceIP extracts the source IP from the request
|
||||
func extractSourceIP(r *http.Request) string {
|
||||
// Try X-Forwarded-For header first
|
||||
if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
|
||||
// Take the first IP in the list
|
||||
if idx := strings.Index(xff, ","); idx != -1 {
|
||||
return strings.TrimSpace(xff[:idx])
|
||||
}
|
||||
return strings.TrimSpace(xff)
|
||||
}
|
||||
|
||||
// Try X-Real-IP header
|
||||
if xri := r.Header.Get("X-Real-IP"); xri != "" {
|
||||
return strings.TrimSpace(xri)
|
||||
}
|
||||
|
||||
// Fall back to RemoteAddr
|
||||
if idx := strings.LastIndex(r.RemoteAddr, ":"); idx != -1 {
|
||||
return r.RemoteAddr[:idx]
|
||||
}
|
||||
return r.RemoteAddr
|
||||
}
|
||||
@@ -1,169 +0,0 @@
|
||||
package reverseproxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"sync"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/proxy/internal/auth/oidc"
|
||||
"github.com/netbirdio/netbird/proxy/internal/reverseproxy/certmanager"
|
||||
)
|
||||
|
||||
// Proxy wraps a reverse proxy with dynamic routing
|
||||
type Proxy struct {
|
||||
config Config
|
||||
mu sync.RWMutex
|
||||
routes map[string]*RouteConfig // key is host/domain (for fast O(1) lookup)
|
||||
server *http.Server
|
||||
httpServer *http.Server
|
||||
certManager certmanager.Manager
|
||||
isRunning bool
|
||||
requestCallback RequestDataCallback
|
||||
oidcHandler *oidc.Handler
|
||||
}
|
||||
|
||||
// New creates a new reverse proxy
|
||||
func New(config Config) (*Proxy, error) {
|
||||
if config.ListenAddress == "" {
|
||||
config.ListenAddress = ":443"
|
||||
}
|
||||
if config.HTTPListenAddress == "" {
|
||||
config.HTTPListenAddress = ":80"
|
||||
}
|
||||
if config.CertCacheDir == "" {
|
||||
config.CertCacheDir = "./certs"
|
||||
}
|
||||
|
||||
if config.CertMode == "" {
|
||||
config.CertMode = "letsencrypt"
|
||||
}
|
||||
|
||||
if config.CertMode == "letsencrypt" && config.TLSEmail == "" {
|
||||
return nil, fmt.Errorf("TLSEmail is required for letsencrypt mode")
|
||||
}
|
||||
|
||||
if config.OIDCConfig != nil && config.OIDCConfig.SessionCookieName == "" {
|
||||
config.OIDCConfig.SessionCookieName = "auth_session"
|
||||
}
|
||||
|
||||
var certMgr certmanager.Manager
|
||||
if config.CertMode == "selfsigned" {
|
||||
// HTTPS with self-signed certificates (for local testing)
|
||||
certMgr = certmanager.NewSelfSigned()
|
||||
} else {
|
||||
// HTTPS with Let's Encrypt (for production)
|
||||
certMgr = certmanager.NewLetsEncrypt(certmanager.LetsEncryptConfig{
|
||||
Email: config.TLSEmail,
|
||||
CertCacheDir: config.CertCacheDir,
|
||||
})
|
||||
}
|
||||
|
||||
p := &Proxy{
|
||||
config: config,
|
||||
routes: make(map[string]*RouteConfig),
|
||||
certManager: certMgr,
|
||||
isRunning: false,
|
||||
}
|
||||
|
||||
if config.OIDCConfig != nil {
|
||||
stateStore := oidc.NewStateStore()
|
||||
p.oidcHandler = oidc.NewHandler(config.OIDCConfig, stateStore)
|
||||
}
|
||||
|
||||
return p, nil
|
||||
}
|
||||
|
||||
// Start starts the reverse proxy server (non-blocking)
|
||||
func (p *Proxy) Start() error {
|
||||
p.mu.Lock()
|
||||
if p.isRunning {
|
||||
p.mu.Unlock()
|
||||
return fmt.Errorf("reverse proxy already running")
|
||||
}
|
||||
p.isRunning = true
|
||||
p.mu.Unlock()
|
||||
|
||||
handler := p.buildHandler()
|
||||
|
||||
return p.startHTTPS(handler)
|
||||
}
|
||||
|
||||
// startHTTPS starts the proxy with HTTPS
|
||||
func (p *Proxy) startHTTPS(handler http.Handler) error {
|
||||
p.httpServer = &http.Server{
|
||||
Addr: p.config.HTTPListenAddress,
|
||||
Handler: p.certManager.HTTPHandler(nil),
|
||||
}
|
||||
|
||||
go func() {
|
||||
log.Infof("Starting HTTP server on %s", p.config.HTTPListenAddress)
|
||||
if err := p.httpServer.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
||||
log.Errorf("HTTP server error: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
p.server = &http.Server{
|
||||
Addr: p.config.ListenAddress,
|
||||
Handler: handler,
|
||||
TLSConfig: p.certManager.TLSConfig(),
|
||||
}
|
||||
|
||||
go func() {
|
||||
log.Infof("Starting HTTPS reverse proxy server on %s", p.config.ListenAddress)
|
||||
if err := p.server.ListenAndServeTLS("", ""); err != nil && err != http.ErrServerClosed {
|
||||
log.Errorf("HTTPS server failed: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stop gracefully stops the reverse proxy server
|
||||
func (p *Proxy) Stop(ctx context.Context) error {
|
||||
p.mu.Lock()
|
||||
if !p.isRunning {
|
||||
p.mu.Unlock()
|
||||
return fmt.Errorf("reverse proxy not running")
|
||||
}
|
||||
p.isRunning = false
|
||||
p.mu.Unlock()
|
||||
|
||||
log.Info("Stopping reverse proxy server...")
|
||||
|
||||
if p.httpServer != nil {
|
||||
if err := p.httpServer.Shutdown(ctx); err != nil {
|
||||
log.Errorf("Error shutting down HTTP server: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
if p.server != nil {
|
||||
if err := p.server.Shutdown(ctx); err != nil {
|
||||
return fmt.Errorf("error shutting down server: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
log.Info("Reverse proxy server stopped")
|
||||
return nil
|
||||
}
|
||||
|
||||
// IsRunning returns whether the proxy is running
|
||||
func (p *Proxy) IsRunning() bool {
|
||||
p.mu.RLock()
|
||||
defer p.mu.RUnlock()
|
||||
return p.isRunning
|
||||
}
|
||||
|
||||
// SetRequestCallback sets the callback for request metrics
|
||||
func (p *Proxy) SetRequestCallback(callback RequestDataCallback) {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
p.requestCallback = callback
|
||||
}
|
||||
|
||||
// GetConfig returns the proxy configuration
|
||||
func (p *Proxy) GetConfig() Config {
|
||||
return p.config
|
||||
}
|
||||
@@ -1,149 +0,0 @@
|
||||
package reverseproxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/embed"
|
||||
)
|
||||
|
||||
const (
|
||||
clientStartupTimeout = 30 * time.Second
|
||||
)
|
||||
|
||||
// AddRoute adds a new route to the proxy
|
||||
func (p *Proxy) AddRoute(route *RouteConfig) error {
|
||||
if route == nil {
|
||||
return fmt.Errorf("route cannot be nil")
|
||||
}
|
||||
if route.ID == "" {
|
||||
return fmt.Errorf("route ID is required")
|
||||
}
|
||||
if route.Domain == "" {
|
||||
return fmt.Errorf("route Domain is required")
|
||||
}
|
||||
if len(route.PathMappings) == 0 {
|
||||
return fmt.Errorf("route must have at least one path mapping")
|
||||
}
|
||||
if route.SetupKey == "" {
|
||||
return fmt.Errorf("route setup key is required")
|
||||
}
|
||||
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
if _, exists := p.routes[route.Domain]; exists {
|
||||
return fmt.Errorf("route for domain %s already exists", route.Domain)
|
||||
}
|
||||
|
||||
client, err := embed.New(embed.Options{DeviceName: fmt.Sprintf("ingress-%s", route.ID), ManagementURL: p.config.ManagementURL, SetupKey: route.SetupKey, LogOutput: io.Discard})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create embedded client for route %s: %v", route.ID, err)
|
||||
}
|
||||
|
||||
ctx, _ := context.WithTimeout(context.Background(), clientStartupTimeout)
|
||||
err = client.Start(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to start embedded client for route %s: %v", route.ID, err)
|
||||
}
|
||||
|
||||
route.nbClient = client
|
||||
|
||||
p.routes[route.Domain] = route
|
||||
|
||||
p.certManager.AddDomain(route.Domain)
|
||||
|
||||
log.WithFields(log.Fields{
|
||||
"route_id": route.ID,
|
||||
"domain": route.Domain,
|
||||
"paths": len(route.PathMappings),
|
||||
}).Info("Added route")
|
||||
|
||||
go func(domain string) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
if err := p.certManager.IssueCertificate(ctx, domain); err != nil {
|
||||
log.Errorf("Failed to issue certificate: %v", err)
|
||||
// TODO: Better error feedback mechanism
|
||||
}
|
||||
}(route.Domain)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemoveRoute removes a route by domain
|
||||
func (p *Proxy) RemoveRoute(domain string) error {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
if _, exists := p.routes[domain]; !exists {
|
||||
return fmt.Errorf("route for domain %s not found", domain)
|
||||
}
|
||||
|
||||
delete(p.routes, domain)
|
||||
|
||||
p.certManager.RemoveDomain(domain)
|
||||
|
||||
log.Infof("Removed route for domain: %s", domain)
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateRoute updates an existing route
|
||||
func (p *Proxy) UpdateRoute(route *RouteConfig) error {
|
||||
if route == nil {
|
||||
return fmt.Errorf("route cannot be nil")
|
||||
}
|
||||
if route.ID == "" {
|
||||
return fmt.Errorf("route ID is required")
|
||||
}
|
||||
if route.Domain == "" {
|
||||
return fmt.Errorf("route Domain is required")
|
||||
}
|
||||
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
if _, exists := p.routes[route.Domain]; !exists {
|
||||
return fmt.Errorf("route for domain %s not found", route.Domain)
|
||||
}
|
||||
|
||||
p.routes[route.Domain] = route
|
||||
|
||||
log.WithFields(log.Fields{
|
||||
"route_id": route.ID,
|
||||
"domain": route.Domain,
|
||||
"paths": len(route.PathMappings),
|
||||
}).Info("Updated route")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ListRoutes returns a list of all configured domains
|
||||
func (p *Proxy) ListRoutes() []string {
|
||||
p.mu.RLock()
|
||||
defer p.mu.RUnlock()
|
||||
|
||||
domains := make([]string, 0, len(p.routes))
|
||||
for domain := range p.routes {
|
||||
domains = append(domains, domain)
|
||||
}
|
||||
return domains
|
||||
}
|
||||
|
||||
// GetRoute returns a route configuration by domain
|
||||
func (p *Proxy) GetRoute(domain string) (*RouteConfig, error) {
|
||||
p.mu.RLock()
|
||||
defer p.mu.RUnlock()
|
||||
|
||||
route, exists := p.routes[domain]
|
||||
if !exists {
|
||||
return nil, fmt.Errorf("route for domain %s not found", domain)
|
||||
}
|
||||
|
||||
return route, nil
|
||||
}
|
||||
61
proxy/internal/roundtrip/netbird.go
Normal file
61
proxy/internal/roundtrip/netbird.go
Normal file
@@ -0,0 +1,61 @@
|
||||
package roundtrip
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"sync"
|
||||
|
||||
"github.com/netbirdio/netbird/client/embed"
|
||||
)
|
||||
|
||||
const deviceNamePrefix = "ingress-"
|
||||
|
||||
// NetBird provides an http.RoundTripper implementation
|
||||
// backed by underlying NetBird connections.
|
||||
type NetBird struct {
|
||||
mgmtAddr string
|
||||
|
||||
clientsMux sync.RWMutex
|
||||
clients map[string]*http.Client
|
||||
}
|
||||
|
||||
func NewNetBird(mgmtAddr string) *NetBird {
|
||||
return &NetBird{
|
||||
mgmtAddr: mgmtAddr,
|
||||
clients: make(map[string]*http.Client),
|
||||
}
|
||||
}
|
||||
|
||||
func (n *NetBird) AddPeer(domain, key string) error {
|
||||
client, err := embed.New(embed.Options{
|
||||
DeviceName: deviceNamePrefix + domain,
|
||||
ManagementURL: n.mgmtAddr,
|
||||
SetupKey: key,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("create netbird client: %w", err)
|
||||
}
|
||||
n.clientsMux.Lock()
|
||||
defer n.clientsMux.Unlock()
|
||||
n.clients[domain] = client.NewHTTPClient()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (n *NetBird) RemovePeer(domain string) {
|
||||
n.clientsMux.Lock()
|
||||
defer n.clientsMux.Unlock()
|
||||
delete(n.clients, domain)
|
||||
}
|
||||
|
||||
func (n *NetBird) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
n.clientsMux.RLock()
|
||||
client, exists := n.clients[req.Host]
|
||||
// Immediately unlock after retrieval here rather than defer to avoid
|
||||
// the call to client.Do blocking other clients being used whilst one
|
||||
// is in use.
|
||||
n.clientsMux.RUnlock()
|
||||
if !exists {
|
||||
return nil, fmt.Errorf("no peer connection found for host: %s", req.Host)
|
||||
}
|
||||
return client.Do(req)
|
||||
}
|
||||
@@ -1,13 +0,0 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"os"
|
||||
|
||||
"github.com/netbirdio/netbird/proxy/cmd"
|
||||
)
|
||||
|
||||
func main() {
|
||||
if err := cmd.Execute(); err != nil {
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
@@ -1,325 +0,0 @@
|
||||
package grpc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/credentials/insecure"
|
||||
"google.golang.org/grpc/keepalive"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
const (
|
||||
reconnectInterval = 5 * time.Second
|
||||
proxyVersion = "0.1.0"
|
||||
)
|
||||
|
||||
// ServiceUpdateHandler is called when services are added/updated/removed
|
||||
type ServiceUpdateHandler func(update *proto.ServiceUpdate) error
|
||||
|
||||
// Client manages the gRPC connection to management server
|
||||
type Client struct {
|
||||
proxyID string
|
||||
managementURL string
|
||||
conn *grpc.ClientConn
|
||||
stream proto.ProxyService_StreamClient
|
||||
serviceUpdateHandler ServiceUpdateHandler
|
||||
accessLogChan chan *proto.ProxyRequestData
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
mu sync.RWMutex
|
||||
connected bool
|
||||
}
|
||||
|
||||
// ClientConfig holds client configuration
|
||||
type ClientConfig struct {
|
||||
ProxyID string
|
||||
ManagementURL string
|
||||
ServiceUpdateHandler ServiceUpdateHandler
|
||||
}
|
||||
|
||||
// NewClient creates a new gRPC client for proxy-management communication
|
||||
func NewClient(config ClientConfig) *Client {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
return &Client{
|
||||
proxyID: config.ProxyID,
|
||||
managementURL: config.ManagementURL,
|
||||
serviceUpdateHandler: config.ServiceUpdateHandler,
|
||||
accessLogChan: make(chan *proto.ProxyRequestData, 1000),
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
}
|
||||
}
|
||||
|
||||
// Start connects to management server and maintains connection
|
||||
func (c *Client) Start() error {
|
||||
go c.connectionLoop()
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stop closes the connection
|
||||
func (c *Client) Stop() error {
|
||||
c.cancel()
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
if c.stream != nil {
|
||||
// Try to close stream gracefully
|
||||
_ = c.stream.CloseSend()
|
||||
}
|
||||
|
||||
if c.conn != nil {
|
||||
return c.conn.Close()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// SendAccessLog queues an access log to be sent to management
|
||||
func (c *Client) SendAccessLog(log *proto.ProxyRequestData) {
|
||||
select {
|
||||
case c.accessLogChan <- log:
|
||||
default:
|
||||
// Channel full, drop log
|
||||
}
|
||||
}
|
||||
|
||||
// IsConnected returns whether client is connected to management
|
||||
func (c *Client) IsConnected() bool {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
return c.connected
|
||||
}
|
||||
|
||||
// connectionLoop maintains connection to management server
|
||||
func (c *Client) connectionLoop() {
|
||||
for {
|
||||
select {
|
||||
case <-c.ctx.Done():
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
log.Infof("Connecting to management server at %s", c.managementURL)
|
||||
|
||||
if err := c.connect(); err != nil {
|
||||
log.Errorf("Failed to connect to management: %v", err)
|
||||
c.setConnected(false)
|
||||
|
||||
select {
|
||||
case <-c.ctx.Done():
|
||||
return
|
||||
case <-time.After(reconnectInterval):
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
// Handle connection
|
||||
if err := c.handleConnection(); err != nil {
|
||||
log.Errorf("Connection error: %v", err)
|
||||
c.setConnected(false)
|
||||
}
|
||||
|
||||
// Reconnect after delay
|
||||
select {
|
||||
case <-c.ctx.Done():
|
||||
return
|
||||
case <-time.After(reconnectInterval):
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// connect establishes connection to management server
|
||||
func (c *Client) connect() error {
|
||||
// Strip scheme from URL if present (gRPC doesn't use http:// or https://)
|
||||
target := c.managementURL
|
||||
target = strings.TrimPrefix(target, "http://")
|
||||
target = strings.TrimPrefix(target, "https://")
|
||||
|
||||
// Create gRPC connection
|
||||
opts := []grpc.DialOption{
|
||||
grpc.WithTransportCredentials(insecure.NewCredentials()), // TODO: Add TLS
|
||||
grpc.WithKeepaliveParams(keepalive.ClientParameters{
|
||||
Time: 20 * time.Second,
|
||||
Timeout: 10 * time.Second,
|
||||
PermitWithoutStream: true,
|
||||
}),
|
||||
}
|
||||
|
||||
conn, err := grpc.Dial(target, opts...)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to dial: %w", err)
|
||||
}
|
||||
|
||||
c.mu.Lock()
|
||||
c.conn = conn
|
||||
c.mu.Unlock()
|
||||
|
||||
// Create stream
|
||||
client := proto.NewProxyServiceClient(conn)
|
||||
stream, err := client.Stream(c.ctx)
|
||||
if err != nil {
|
||||
conn.Close()
|
||||
return fmt.Errorf("failed to create stream: %w", err)
|
||||
}
|
||||
|
||||
c.mu.Lock()
|
||||
c.stream = stream
|
||||
c.mu.Unlock()
|
||||
|
||||
// Send ProxyHello
|
||||
hello := &proto.ProxyMessage{
|
||||
Payload: &proto.ProxyMessage_Hello{
|
||||
Hello: &proto.ProxyHello{
|
||||
ProxyId: c.proxyID,
|
||||
Version: proxyVersion,
|
||||
StartedAt: timestamppb.Now(),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
if err := stream.Send(hello); err != nil {
|
||||
conn.Close()
|
||||
return fmt.Errorf("failed to send hello: %w", err)
|
||||
}
|
||||
|
||||
c.setConnected(true)
|
||||
log.Info("Successfully connected to management server")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// handleConnection manages the active connection
|
||||
func (c *Client) handleConnection() error {
|
||||
errChan := make(chan error, 2)
|
||||
|
||||
// Start sender goroutine
|
||||
go c.sender(errChan)
|
||||
|
||||
// Start receiver goroutine
|
||||
go c.receiver(errChan)
|
||||
|
||||
// Wait for error
|
||||
return <-errChan
|
||||
}
|
||||
|
||||
// sender sends messages to management
|
||||
func (c *Client) sender(errChan chan<- error) {
|
||||
for {
|
||||
select {
|
||||
case <-c.ctx.Done():
|
||||
errChan <- c.ctx.Err()
|
||||
return
|
||||
|
||||
case accessLog := <-c.accessLogChan:
|
||||
msg := &proto.ProxyMessage{
|
||||
Payload: &proto.ProxyMessage_RequestData{
|
||||
RequestData: accessLog,
|
||||
},
|
||||
}
|
||||
|
||||
c.mu.RLock()
|
||||
stream := c.stream
|
||||
c.mu.RUnlock()
|
||||
|
||||
if stream == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if err := stream.Send(msg); err != nil {
|
||||
log.Errorf("Failed to send access log: %v", err)
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// receiver receives messages from management
|
||||
func (c *Client) receiver(errChan chan<- error) {
|
||||
for {
|
||||
c.mu.RLock()
|
||||
stream := c.stream
|
||||
c.mu.RUnlock()
|
||||
|
||||
if stream == nil {
|
||||
errChan <- fmt.Errorf("stream is nil")
|
||||
return
|
||||
}
|
||||
|
||||
msg, err := stream.Recv()
|
||||
if err == io.EOF {
|
||||
log.Info("Management server closed connection")
|
||||
errChan <- io.EOF
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
log.Errorf("Failed to receive: %v", err)
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
|
||||
// Handle message
|
||||
switch payload := msg.GetPayload().(type) {
|
||||
case *proto.ManagementMessage_Snapshot:
|
||||
c.handleSnapshot(payload.Snapshot)
|
||||
case *proto.ManagementMessage_Update:
|
||||
c.handleServiceUpdate(payload.Update)
|
||||
default:
|
||||
log.Warnf("Received unknown message type")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// handleSnapshot processes initial services snapshot
|
||||
func (c *Client) handleSnapshot(snapshot *proto.ServicesSnapshot) {
|
||||
log.Infof("Received services snapshot with %d services", len(snapshot.Services))
|
||||
|
||||
if c.serviceUpdateHandler == nil {
|
||||
log.Warn("No service update handler configured")
|
||||
return
|
||||
}
|
||||
|
||||
// Process each service as a CREATED update
|
||||
for _, service := range snapshot.Services {
|
||||
update := &proto.ServiceUpdate{
|
||||
Type: proto.ServiceUpdate_CREATED,
|
||||
Service: service,
|
||||
ServiceId: service.Id,
|
||||
}
|
||||
|
||||
if err := c.serviceUpdateHandler(update); err != nil {
|
||||
log.Errorf("Failed to handle service %s: %v", service.Id, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// handleServiceUpdate processes incremental service update
|
||||
func (c *Client) handleServiceUpdate(update *proto.ServiceUpdate) {
|
||||
log.Infof("Received service update: %s %s", update.Type, update.ServiceId)
|
||||
|
||||
if c.serviceUpdateHandler == nil {
|
||||
log.Warn("No service update handler configured")
|
||||
return
|
||||
}
|
||||
|
||||
if err := c.serviceUpdateHandler(update); err != nil {
|
||||
log.Errorf("Failed to handle service update: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// setConnected updates connected status
|
||||
func (c *Client) setConnected(connected bool) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
c.connected = connected
|
||||
}
|
||||
@@ -1,214 +0,0 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"reflect"
|
||||
"time"
|
||||
|
||||
"github.com/caarlos0/env/v11"
|
||||
|
||||
"github.com/netbirdio/netbird/proxy/internal/reverseproxy"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrFailedToParseConfig = errors.New("failed to parse config from env")
|
||||
)
|
||||
|
||||
// Duration is a time.Duration that can be unmarshaled from JSON as a string
|
||||
type Duration time.Duration
|
||||
|
||||
// UnmarshalJSON implements json.Unmarshaler for Duration
|
||||
func (d *Duration) UnmarshalJSON(b []byte) error {
|
||||
var s string
|
||||
if err := json.Unmarshal(b, &s); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
parsed, err := time.ParseDuration(s)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
*d = Duration(parsed)
|
||||
return nil
|
||||
}
|
||||
|
||||
// MarshalJSON implements json.Marshaler for Duration
|
||||
func (d Duration) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(time.Duration(d).String())
|
||||
}
|
||||
|
||||
// ToDuration converts Duration to time.Duration
|
||||
func (d Duration) ToDuration() time.Duration {
|
||||
return time.Duration(d)
|
||||
}
|
||||
|
||||
// Config holds the configuration for the reverse proxy server
|
||||
type Config struct {
|
||||
// ReadTimeout is the maximum duration for reading the entire request, including the body
|
||||
ReadTimeout time.Duration `env:"NB_PROXY_READ_TIMEOUT" envDefault:"30s" json:"read_timeout"`
|
||||
|
||||
// WriteTimeout is the maximum duration before timing out writes of the response
|
||||
WriteTimeout time.Duration `env:"NB_PROXY_WRITE_TIMEOUT" envDefault:"30s" json:"write_timeout"`
|
||||
|
||||
// IdleTimeout is the maximum amount of time to wait for the next request when keep-alives are enabled
|
||||
IdleTimeout time.Duration `env:"NB_PROXY_IDLE_TIMEOUT" envDefault:"60s" json:"idle_timeout"`
|
||||
|
||||
// ShutdownTimeout is the maximum duration to wait for graceful shutdown
|
||||
ShutdownTimeout time.Duration `env:"NB_PROXY_SHUTDOWN_TIMEOUT" envDefault:"10s" json:"shutdown_timeout"`
|
||||
|
||||
// LogLevel sets the logging verbosity (debug, info, warn, error)
|
||||
LogLevel string `env:"NB_PROXY_LOG_LEVEL" envDefault:"info" json:"log_level"`
|
||||
|
||||
// GRPCListenAddress is the address for the gRPC control server
|
||||
GRPCListenAddress string `env:"NB_PROXY_GRPC_LISTEN_ADDRESS" envDefault:":50051" json:"grpc_listen_address"`
|
||||
|
||||
// ProxyID is a unique identifier for this proxy instance
|
||||
ProxyID string `env:"NB_PROXY_ID" envDefault:"" json:"proxy_id"`
|
||||
|
||||
// EnableGRPC enables the gRPC control server
|
||||
EnableGRPC bool `env:"NB_PROXY_ENABLE_GRPC" envDefault:"false" json:"enable_grpc"`
|
||||
|
||||
// Reverse Proxy Configuration
|
||||
ReverseProxy reverseproxy.Config `json:"reverse_proxy"`
|
||||
}
|
||||
|
||||
// ParseAndLoad parses configuration from environment variables
|
||||
func ParseAndLoad() (Config, error) {
|
||||
var cfg Config
|
||||
|
||||
if err := env.Parse(&cfg); err != nil {
|
||||
return cfg, fmt.Errorf("%w: %s", ErrFailedToParseConfig, err)
|
||||
}
|
||||
|
||||
if err := cfg.Validate(); err != nil {
|
||||
return cfg, fmt.Errorf("invalid config: %w", err)
|
||||
}
|
||||
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
// LoadFromFile reads configuration from a JSON file
|
||||
func LoadFromFile(path string) (Config, error) {
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return Config{}, fmt.Errorf("failed to read config file: %w", err)
|
||||
}
|
||||
|
||||
var cfg Config
|
||||
if err := json.Unmarshal(data, &cfg); err != nil {
|
||||
return Config{}, fmt.Errorf("failed to parse config file: %w", err)
|
||||
}
|
||||
|
||||
if err := cfg.Validate(); err != nil {
|
||||
return Config{}, fmt.Errorf("invalid config: %w", err)
|
||||
}
|
||||
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
// LoadFromFileOrEnv loads configuration from a file if path is provided, otherwise from environment variables
|
||||
func LoadFromFileOrEnv(configPath string) (Config, error) {
|
||||
var cfg Config
|
||||
|
||||
// If config file is provided, load it first
|
||||
if configPath != "" {
|
||||
fileCfg, err := LoadFromFile(configPath)
|
||||
if err != nil {
|
||||
return Config{}, fmt.Errorf("failed to load config from file: %w", err)
|
||||
}
|
||||
cfg = fileCfg
|
||||
} else {
|
||||
if err := env.Parse(&cfg); err != nil {
|
||||
return Config{}, fmt.Errorf("%w: %s", ErrFailedToParseConfig, err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := cfg.Validate(); err != nil {
|
||||
return Config{}, fmt.Errorf("invalid config: %w", err)
|
||||
}
|
||||
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
// UnmarshalJSON implements custom JSON unmarshaling with automatic duration parsing
|
||||
func (c *Config) UnmarshalJSON(data []byte) error {
|
||||
var raw map[string]interface{}
|
||||
if err := json.Unmarshal(data, &raw); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
val := reflect.ValueOf(c).Elem()
|
||||
typ := val.Type()
|
||||
|
||||
for i := 0; i < val.NumField(); i++ {
|
||||
field := val.Field(i)
|
||||
fieldType := typ.Field(i)
|
||||
|
||||
jsonTag := fieldType.Tag.Get("json")
|
||||
if jsonTag == "" || jsonTag == "-" {
|
||||
continue
|
||||
}
|
||||
|
||||
jsonFieldName := jsonTag
|
||||
if idx := len(jsonTag); idx > 0 {
|
||||
for j, c := range jsonTag {
|
||||
if c == ',' {
|
||||
jsonFieldName = jsonTag[:j]
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
rawValue, exists := raw[jsonFieldName]
|
||||
if !exists {
|
||||
continue
|
||||
}
|
||||
|
||||
if field.Type() == reflect.TypeOf(time.Duration(0)) {
|
||||
if strValue, ok := rawValue.(string); ok {
|
||||
duration, err := time.ParseDuration(strValue)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid duration for field %s: %w", jsonFieldName, err)
|
||||
}
|
||||
field.Set(reflect.ValueOf(duration))
|
||||
} else {
|
||||
return fmt.Errorf("field %s must be a duration string", jsonFieldName)
|
||||
}
|
||||
} else {
|
||||
fieldData, err := json.Marshal(rawValue)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal field %s: %w", jsonFieldName, err)
|
||||
}
|
||||
|
||||
if field.CanSet() {
|
||||
newVal := reflect.New(field.Type())
|
||||
if err := json.Unmarshal(fieldData, newVal.Interface()); err != nil {
|
||||
return fmt.Errorf("failed to unmarshal field %s: %w", jsonFieldName, err)
|
||||
}
|
||||
field.Set(newVal.Elem())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Validate checks if the configuration is valid
|
||||
func (c *Config) Validate() error {
|
||||
validLogLevels := map[string]bool{
|
||||
"debug": true,
|
||||
"info": true,
|
||||
"warn": true,
|
||||
"error": true,
|
||||
}
|
||||
|
||||
if !validLogLevels[c.LogLevel] {
|
||||
return fmt.Errorf("invalid log_level: %s (must be debug, info, warn, or error)", c.LogLevel)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -1,592 +0,0 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
|
||||
"github.com/netbirdio/netbird/proxy/internal/auth"
|
||||
"github.com/netbirdio/netbird/proxy/internal/auth/methods"
|
||||
"github.com/netbirdio/netbird/proxy/internal/reverseproxy"
|
||||
grpcpkg "github.com/netbirdio/netbird/proxy/pkg/grpc"
|
||||
pb "github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
// Server represents the reverse proxy server with integrated gRPC client
|
||||
type Server struct {
|
||||
config Config
|
||||
grpcClient *grpcpkg.Client
|
||||
proxy *reverseproxy.Proxy
|
||||
|
||||
mu sync.RWMutex
|
||||
isRunning bool
|
||||
grpcRunning bool
|
||||
|
||||
shutdownCtx context.Context
|
||||
cancelFunc context.CancelFunc
|
||||
|
||||
// Statistics for gRPC reporting
|
||||
stats *Stats
|
||||
|
||||
// Track exposed services and their peer configs
|
||||
exposedServices map[string]*ExposedServiceConfig
|
||||
}
|
||||
|
||||
// Stats holds proxy statistics
|
||||
type Stats struct {
|
||||
mu sync.RWMutex
|
||||
totalRequests uint64
|
||||
activeConns uint64
|
||||
bytesSent uint64
|
||||
bytesReceived uint64
|
||||
}
|
||||
|
||||
// ExposedServiceConfig holds the configuration for an exposed service
|
||||
type ExposedServiceConfig struct {
|
||||
ServiceID string
|
||||
PeerConfig *PeerConfig
|
||||
UpstreamConfig *UpstreamConfig
|
||||
}
|
||||
|
||||
// PeerConfig holds WireGuard peer configuration
|
||||
type PeerConfig struct {
|
||||
PeerID string
|
||||
PublicKey string
|
||||
AllowedIPs []string
|
||||
Endpoint string
|
||||
TunnelIP string // The WireGuard tunnel IP to route traffic to
|
||||
}
|
||||
|
||||
// UpstreamConfig holds reverse proxy upstream configuration
|
||||
type UpstreamConfig struct {
|
||||
Domain string
|
||||
PathMappings map[string]string // path -> port mapping (relative to tunnel IP)
|
||||
}
|
||||
|
||||
// NewServer creates a new reverse proxy server instance
|
||||
func NewServer(config Config) (*Server, error) {
|
||||
if err := config.Validate(); err != nil {
|
||||
return nil, fmt.Errorf("invalid configuration: %w", err)
|
||||
}
|
||||
|
||||
shutdownCtx, cancelFunc := context.WithCancel(context.Background())
|
||||
|
||||
server := &Server{
|
||||
config: config,
|
||||
isRunning: false,
|
||||
grpcRunning: false,
|
||||
shutdownCtx: shutdownCtx,
|
||||
cancelFunc: cancelFunc,
|
||||
stats: &Stats{},
|
||||
exposedServices: make(map[string]*ExposedServiceConfig),
|
||||
}
|
||||
|
||||
proxy, err := reverseproxy.New(config.ReverseProxy)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create reverse proxy: %w", err)
|
||||
}
|
||||
server.proxy = proxy
|
||||
|
||||
if config.ReverseProxy.ManagementURL == "" {
|
||||
return nil, fmt.Errorf("management URL is required")
|
||||
}
|
||||
|
||||
grpcClient := grpcpkg.NewClient(grpcpkg.ClientConfig{
|
||||
ProxyID: config.ProxyID,
|
||||
ManagementURL: config.ReverseProxy.ManagementURL,
|
||||
ServiceUpdateHandler: server.handleServiceUpdate,
|
||||
})
|
||||
server.grpcClient = grpcClient
|
||||
|
||||
// Set request data callback to send access logs to management
|
||||
proxy.SetRequestCallback(func(data reverseproxy.RequestData) {
|
||||
accessLog := &pb.ProxyRequestData{
|
||||
Timestamp: timestamppb.Now(),
|
||||
ServiceId: data.ServiceID,
|
||||
Host: data.Host,
|
||||
Path: data.Path,
|
||||
DurationMs: data.DurationMs,
|
||||
Method: data.Method,
|
||||
ResponseCode: data.ResponseCode,
|
||||
SourceIp: data.SourceIP,
|
||||
AuthMechanism: data.AuthMechanism,
|
||||
UserId: data.UserID,
|
||||
AuthSuccess: data.AuthSuccess,
|
||||
}
|
||||
server.grpcClient.SendAccessLog(accessLog)
|
||||
})
|
||||
|
||||
return server, nil
|
||||
}
|
||||
|
||||
// Start starts the reverse proxy server and optionally the gRPC control server
|
||||
func (s *Server) Start() error {
|
||||
s.mu.Lock()
|
||||
if s.isRunning {
|
||||
s.mu.Unlock()
|
||||
return fmt.Errorf("server is already running")
|
||||
}
|
||||
s.isRunning = true
|
||||
s.mu.Unlock()
|
||||
|
||||
log.Infof("Starting proxy reverse proxy server on %s", s.config.ReverseProxy.ListenAddress)
|
||||
|
||||
if err := s.proxy.Start(); err != nil {
|
||||
s.mu.Lock()
|
||||
s.isRunning = false
|
||||
s.mu.Unlock()
|
||||
return fmt.Errorf("failed to start reverse proxy: %w", err)
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
s.grpcRunning = true
|
||||
s.mu.Unlock()
|
||||
|
||||
if err := s.grpcClient.Start(); err != nil {
|
||||
s.mu.Lock()
|
||||
s.isRunning = false
|
||||
s.grpcRunning = false
|
||||
s.mu.Unlock()
|
||||
return fmt.Errorf("failed to start gRPC client: %w", err)
|
||||
}
|
||||
|
||||
log.Info("Proxy started and connected to management")
|
||||
log.Info("Waiting for service configurations from management...")
|
||||
|
||||
<-s.shutdownCtx.Done()
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stop gracefully shuts down both proxy and gRPC servers
|
||||
func (s *Server) Stop(ctx context.Context) error {
|
||||
s.mu.Lock()
|
||||
if !s.isRunning {
|
||||
s.mu.Unlock()
|
||||
return fmt.Errorf("server is not running")
|
||||
}
|
||||
s.mu.Unlock()
|
||||
|
||||
log.Info("Shutting down servers gracefully...")
|
||||
|
||||
// If no context provided, use the server's shutdown timeout
|
||||
if ctx == nil {
|
||||
var cancel context.CancelFunc
|
||||
ctx, cancel = context.WithTimeout(context.Background(), s.config.ShutdownTimeout)
|
||||
defer cancel()
|
||||
}
|
||||
|
||||
var proxyErr, grpcErr error
|
||||
|
||||
// Stop gRPC client first
|
||||
if s.grpcRunning {
|
||||
if err := s.grpcClient.Stop(); err != nil {
|
||||
grpcErr = fmt.Errorf("gRPC client shutdown failed: %w", err)
|
||||
log.Error(grpcErr)
|
||||
}
|
||||
s.mu.Lock()
|
||||
s.grpcRunning = false
|
||||
s.mu.Unlock()
|
||||
}
|
||||
|
||||
// Shutdown reverse proxy
|
||||
if err := s.proxy.Stop(ctx); err != nil {
|
||||
proxyErr = fmt.Errorf("reverse proxy shutdown failed: %w", err)
|
||||
log.Error(proxyErr)
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
s.isRunning = false
|
||||
s.mu.Unlock()
|
||||
|
||||
if proxyErr != nil {
|
||||
return proxyErr
|
||||
}
|
||||
if grpcErr != nil {
|
||||
return grpcErr
|
||||
}
|
||||
|
||||
log.Info("All servers stopped successfully")
|
||||
return nil
|
||||
}
|
||||
|
||||
// IsRunning returns whether the server is currently running
|
||||
func (s *Server) IsRunning() bool {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
return s.isRunning
|
||||
}
|
||||
|
||||
// GetConfig returns a copy of the server configuration
|
||||
func (s *Server) GetConfig() Config {
|
||||
return s.config
|
||||
}
|
||||
|
||||
// handleServiceUpdate processes service updates from management
|
||||
func (s *Server) handleServiceUpdate(update *pb.ServiceUpdate) error {
|
||||
log.WithFields(log.Fields{
|
||||
"service_id": update.ServiceId,
|
||||
"type": update.Type.String(),
|
||||
}).Info("Received service update from management")
|
||||
|
||||
switch update.Type {
|
||||
case pb.ServiceUpdate_CREATED:
|
||||
if update.Service == nil {
|
||||
return fmt.Errorf("service config is nil for CREATED update")
|
||||
}
|
||||
return s.addServiceFromProto(update.Service)
|
||||
|
||||
case pb.ServiceUpdate_UPDATED:
|
||||
if update.Service == nil {
|
||||
return fmt.Errorf("service config is nil for UPDATED update")
|
||||
}
|
||||
return s.updateServiceFromProto(update.Service)
|
||||
|
||||
case pb.ServiceUpdate_REMOVED:
|
||||
return s.removeService(update.ServiceId)
|
||||
|
||||
default:
|
||||
return fmt.Errorf("unknown service update type: %v", update.Type)
|
||||
}
|
||||
}
|
||||
|
||||
// addServiceFromProto adds a service from proto config
|
||||
func (s *Server) addServiceFromProto(serviceConfig *pb.ExposedServiceConfig) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
// Check if service already exists
|
||||
if _, exists := s.exposedServices[serviceConfig.Id]; exists {
|
||||
log.Warnf("Service %s already exists, updating instead", serviceConfig.Id)
|
||||
return s.updateServiceFromProtoLocked(serviceConfig)
|
||||
}
|
||||
|
||||
log.WithFields(log.Fields{
|
||||
"service_id": serviceConfig.Id,
|
||||
"domain": serviceConfig.Domain,
|
||||
}).Info("Adding service from management")
|
||||
|
||||
// Convert proto auth config to internal auth config
|
||||
var authConfig *auth.Config
|
||||
if serviceConfig.Auth != nil {
|
||||
authConfig = convertProtoAuthConfig(serviceConfig.Auth)
|
||||
}
|
||||
|
||||
// Add route to proxy
|
||||
route := &reverseproxy.RouteConfig{
|
||||
ID: serviceConfig.Id,
|
||||
Domain: serviceConfig.Domain,
|
||||
PathMappings: serviceConfig.PathMappings,
|
||||
AuthConfig: authConfig,
|
||||
SetupKey: serviceConfig.SetupKey,
|
||||
}
|
||||
|
||||
if err := s.proxy.AddRoute(route); err != nil {
|
||||
return fmt.Errorf("failed to add route: %w", err)
|
||||
}
|
||||
|
||||
// Store service config (simplified, no peer config for now)
|
||||
s.exposedServices[serviceConfig.Id] = &ExposedServiceConfig{
|
||||
ServiceID: serviceConfig.Id,
|
||||
UpstreamConfig: &UpstreamConfig{
|
||||
Domain: serviceConfig.Domain,
|
||||
PathMappings: serviceConfig.PathMappings,
|
||||
},
|
||||
}
|
||||
|
||||
log.Infof("Service %s added successfully", serviceConfig.Id)
|
||||
return nil
|
||||
}
|
||||
|
||||
// updateServiceFromProto updates an existing service from proto config
|
||||
func (s *Server) updateServiceFromProto(serviceConfig *pb.ExposedServiceConfig) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
return s.updateServiceFromProtoLocked(serviceConfig)
|
||||
}
|
||||
|
||||
func (s *Server) updateServiceFromProtoLocked(serviceConfig *pb.ExposedServiceConfig) error {
|
||||
log.WithFields(log.Fields{
|
||||
"service_id": serviceConfig.Id,
|
||||
"domain": serviceConfig.Domain,
|
||||
}).Info("Updating service from management")
|
||||
|
||||
// Convert proto auth config to internal auth config
|
||||
var authConfig *auth.Config
|
||||
if serviceConfig.Auth != nil {
|
||||
authConfig = convertProtoAuthConfig(serviceConfig.Auth)
|
||||
}
|
||||
|
||||
// Update route in proxy
|
||||
route := &reverseproxy.RouteConfig{
|
||||
ID: serviceConfig.Id,
|
||||
Domain: serviceConfig.Domain,
|
||||
PathMappings: serviceConfig.PathMappings,
|
||||
AuthConfig: authConfig,
|
||||
SetupKey: serviceConfig.SetupKey,
|
||||
}
|
||||
|
||||
if err := s.proxy.UpdateRoute(route); err != nil {
|
||||
return fmt.Errorf("failed to update route: %w", err)
|
||||
}
|
||||
|
||||
// Update service config
|
||||
s.exposedServices[serviceConfig.Id] = &ExposedServiceConfig{
|
||||
ServiceID: serviceConfig.Id,
|
||||
UpstreamConfig: &UpstreamConfig{
|
||||
Domain: serviceConfig.Domain,
|
||||
PathMappings: serviceConfig.PathMappings,
|
||||
},
|
||||
}
|
||||
|
||||
log.Infof("Service %s updated successfully", serviceConfig.Id)
|
||||
return nil
|
||||
}
|
||||
|
||||
// removeService removes a service
|
||||
func (s *Server) removeService(serviceID string) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
log.WithFields(log.Fields{
|
||||
"service_id": serviceID,
|
||||
}).Info("Removing service from management")
|
||||
|
||||
// Remove route from proxy
|
||||
if err := s.proxy.RemoveRoute(serviceID); err != nil {
|
||||
return fmt.Errorf("failed to remove route: %w", err)
|
||||
}
|
||||
|
||||
// Remove service config
|
||||
delete(s.exposedServices, serviceID)
|
||||
|
||||
log.Infof("Service %s removed successfully", serviceID)
|
||||
return nil
|
||||
}
|
||||
|
||||
// convertProtoAuthConfig converts proto auth config to internal auth config
|
||||
func convertProtoAuthConfig(protoAuth *pb.AuthConfig) *auth.Config {
|
||||
authConfig := &auth.Config{}
|
||||
|
||||
switch authType := protoAuth.AuthType.(type) {
|
||||
case *pb.AuthConfig_BasicAuth:
|
||||
authConfig.BasicAuth = &methods.BasicAuthConfig{
|
||||
Username: authType.BasicAuth.Username,
|
||||
Password: authType.BasicAuth.Password,
|
||||
}
|
||||
case *pb.AuthConfig_PinAuth:
|
||||
authConfig.PIN = &methods.PINConfig{
|
||||
PIN: authType.PinAuth.Pin,
|
||||
Header: authType.PinAuth.Header,
|
||||
}
|
||||
case *pb.AuthConfig_BearerAuth:
|
||||
authConfig.Bearer = &methods.BearerConfig{
|
||||
Enabled: authType.BearerAuth.Enabled,
|
||||
}
|
||||
}
|
||||
|
||||
return authConfig
|
||||
}
|
||||
|
||||
// Exposed Service Handlers (deprecated - keeping for backwards compatibility)
|
||||
|
||||
// handleExposedServiceCreated handles the creation of a new exposed service
|
||||
func (s *Server) handleExposedServiceCreated(serviceID string, peerConfig *PeerConfig, upstreamConfig *UpstreamConfig) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
// Check if service already exists
|
||||
if _, exists := s.exposedServices[serviceID]; exists {
|
||||
return fmt.Errorf("exposed service %s already exists", serviceID)
|
||||
}
|
||||
|
||||
log.WithFields(log.Fields{
|
||||
"service_id": serviceID,
|
||||
"peer_id": peerConfig.PeerID,
|
||||
"tunnel_ip": peerConfig.TunnelIP,
|
||||
"domain": upstreamConfig.Domain,
|
||||
}).Info("Creating exposed service")
|
||||
|
||||
// TODO: Create WireGuard tunnel for peer
|
||||
// 1. Initialize WireGuard interface if not already done
|
||||
// 2. Add peer configuration:
|
||||
// - Public key: peerConfig.PublicKey
|
||||
// - Endpoint: peerConfig.Endpoint
|
||||
// - Allowed IPs: peerConfig.AllowedIPs
|
||||
// - Persistent keepalive: 25 seconds
|
||||
// 3. Bring up the WireGuard interface
|
||||
// 4. Verify tunnel connectivity to peerConfig.TunnelIP
|
||||
// Example pseudo-code:
|
||||
// wgClient.AddPeer(&wireguard.PeerConfig{
|
||||
// PublicKey: peerConfig.PublicKey,
|
||||
// Endpoint: peerConfig.Endpoint,
|
||||
// AllowedIPs: peerConfig.AllowedIPs,
|
||||
// PersistentKeepalive: 25,
|
||||
// })
|
||||
|
||||
// Build path mappings with tunnel IP
|
||||
pathMappings := make(map[string]string)
|
||||
for path, port := range upstreamConfig.PathMappings {
|
||||
// Combine tunnel IP with port
|
||||
target := fmt.Sprintf("%s:%s", peerConfig.TunnelIP, port)
|
||||
pathMappings[path] = target
|
||||
}
|
||||
|
||||
// Add route to proxy
|
||||
route := &reverseproxy.RouteConfig{
|
||||
ID: serviceID,
|
||||
Domain: upstreamConfig.Domain,
|
||||
PathMappings: pathMappings,
|
||||
}
|
||||
|
||||
if err := s.proxy.AddRoute(route); err != nil {
|
||||
return fmt.Errorf("failed to add route: %w", err)
|
||||
}
|
||||
|
||||
// Store service config
|
||||
s.exposedServices[serviceID] = &ExposedServiceConfig{
|
||||
ServiceID: serviceID,
|
||||
PeerConfig: peerConfig,
|
||||
UpstreamConfig: upstreamConfig,
|
||||
}
|
||||
|
||||
log.Infof("Exposed service %s created successfully", serviceID)
|
||||
return nil
|
||||
}
|
||||
|
||||
// handleExposedServiceUpdated handles updates to an existing exposed service
|
||||
func (s *Server) handleExposedServiceUpdated(serviceID string, peerConfig *PeerConfig, upstreamConfig *UpstreamConfig) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
// Check if service exists
|
||||
if _, exists := s.exposedServices[serviceID]; !exists {
|
||||
return fmt.Errorf("exposed service %s not found", serviceID)
|
||||
}
|
||||
|
||||
log.WithFields(log.Fields{
|
||||
"service_id": serviceID,
|
||||
"peer_id": peerConfig.PeerID,
|
||||
"tunnel_ip": peerConfig.TunnelIP,
|
||||
"domain": upstreamConfig.Domain,
|
||||
}).Info("Updating exposed service")
|
||||
|
||||
// TODO: Update WireGuard tunnel if peer config changed
|
||||
|
||||
// Build path mappings with tunnel IP
|
||||
pathMappings := make(map[string]string)
|
||||
for path, port := range upstreamConfig.PathMappings {
|
||||
target := fmt.Sprintf("%s:%s", peerConfig.TunnelIP, port)
|
||||
pathMappings[path] = target
|
||||
}
|
||||
|
||||
// Update route in proxy
|
||||
route := &reverseproxy.RouteConfig{
|
||||
ID: serviceID,
|
||||
Domain: upstreamConfig.Domain,
|
||||
PathMappings: pathMappings,
|
||||
}
|
||||
|
||||
if err := s.proxy.UpdateRoute(route); err != nil {
|
||||
return fmt.Errorf("failed to update route: %w", err)
|
||||
}
|
||||
|
||||
// Update service config
|
||||
s.exposedServices[serviceID] = &ExposedServiceConfig{
|
||||
ServiceID: serviceID,
|
||||
PeerConfig: peerConfig,
|
||||
UpstreamConfig: upstreamConfig,
|
||||
}
|
||||
|
||||
log.Infof("Exposed service %s updated successfully", serviceID)
|
||||
return nil
|
||||
}
|
||||
|
||||
// handleExposedServiceRemoved handles the removal of an exposed service
|
||||
func (s *Server) handleExposedServiceRemoved(serviceID string) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
// Check if service exists
|
||||
if _, exists := s.exposedServices[serviceID]; !exists {
|
||||
return fmt.Errorf("exposed service %s not found", serviceID)
|
||||
}
|
||||
|
||||
log.WithFields(log.Fields{
|
||||
"service_id": serviceID,
|
||||
}).Info("Removing exposed service")
|
||||
|
||||
// Remove route from proxy
|
||||
if err := s.proxy.RemoveRoute(serviceID); err != nil {
|
||||
return fmt.Errorf("failed to remove route: %w", err)
|
||||
}
|
||||
|
||||
// TODO: Remove WireGuard tunnel for peer
|
||||
|
||||
// Remove service config
|
||||
delete(s.exposedServices, serviceID)
|
||||
|
||||
log.Infof("Exposed service %s removed successfully", serviceID)
|
||||
return nil
|
||||
}
|
||||
|
||||
// ListExposedServices returns a list of all exposed service IDs
|
||||
func (s *Server) ListExposedServices() []string {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
services := make([]string, 0, len(s.exposedServices))
|
||||
for id := range s.exposedServices {
|
||||
services = append(services, id)
|
||||
}
|
||||
return services
|
||||
}
|
||||
|
||||
// GetExposedService returns the configuration for a specific exposed service
|
||||
func (s *Server) GetExposedService(serviceID string) (*ExposedServiceConfig, error) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
service, exists := s.exposedServices[serviceID]
|
||||
if !exists {
|
||||
return nil, fmt.Errorf("exposed service %s not found", serviceID)
|
||||
}
|
||||
|
||||
return service, nil
|
||||
}
|
||||
|
||||
// Stats methods
|
||||
|
||||
func (s *Stats) IncrementRequests() {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.totalRequests++
|
||||
}
|
||||
|
||||
func (s *Stats) IncrementActiveConns() {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.activeConns++
|
||||
}
|
||||
|
||||
func (s *Stats) DecrementActiveConns() {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if s.activeConns > 0 {
|
||||
s.activeConns--
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Stats) AddBytesSent(bytes uint64) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.bytesSent += bytes
|
||||
}
|
||||
|
||||
func (s *Stats) AddBytesReceived(bytes uint64) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.bytesReceived += bytes
|
||||
}
|
||||
@@ -1,56 +0,0 @@
|
||||
package version
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"runtime"
|
||||
)
|
||||
|
||||
var (
|
||||
// Version is the application version (set via ldflags during build)
|
||||
Version = "dev"
|
||||
|
||||
// Commit is the git commit hash (set via ldflags during build)
|
||||
Commit = "unknown"
|
||||
|
||||
// BuildDate is the build date (set via ldflags during build)
|
||||
BuildDate = "unknown"
|
||||
|
||||
// GoVersion is the Go version used to build the binary
|
||||
GoVersion = runtime.Version()
|
||||
)
|
||||
|
||||
// Info contains version information
|
||||
type Info struct {
|
||||
Version string `json:"version"`
|
||||
Commit string `json:"commit"`
|
||||
BuildDate string `json:"build_date"`
|
||||
GoVersion string `json:"NewSingleHostReverseProxygo_version"`
|
||||
OS string `json:"os"`
|
||||
Arch string `json:"arch"`
|
||||
}
|
||||
|
||||
// Get returns the version information
|
||||
func Get() Info {
|
||||
return Info{
|
||||
Version: Version,
|
||||
Commit: Commit,
|
||||
BuildDate: BuildDate,
|
||||
GoVersion: GoVersion,
|
||||
OS: runtime.GOOS,
|
||||
Arch: runtime.GOARCH,
|
||||
}
|
||||
}
|
||||
|
||||
// String returns a formatted version string
|
||||
func String() string {
|
||||
return fmt.Sprintf("Version: %s, Commit: %s, BuildDate: %s, Go: %s",
|
||||
Version, Commit, BuildDate, GoVersion)
|
||||
}
|
||||
|
||||
// Short returns a short version string
|
||||
func Short() string {
|
||||
if Version == "dev" {
|
||||
return fmt.Sprintf("%s (%s)", Version, Commit[:7])
|
||||
}
|
||||
return Version
|
||||
}
|
||||
@@ -1,31 +0,0 @@
|
||||
#!/bin/bash
|
||||
|
||||
set -e
|
||||
|
||||
# Check if protoc is installed
|
||||
if ! command -v protoc &> /dev/null; then
|
||||
echo "Error: protoc is not installed"
|
||||
echo "Install with: apt-get install -y protobuf-compiler"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Check if protoc-gen-go is installed
|
||||
if ! command -v protoc-gen-go &> /dev/null; then
|
||||
echo "Installing protoc-gen-go..."
|
||||
go install google.golang.org/protobuf/cmd/protoc-gen-go@latest
|
||||
fi
|
||||
|
||||
# Check if protoc-gen-go-grpc is installed
|
||||
if ! command -v protoc-gen-go-grpc &> /dev/null; then
|
||||
echo "Installing protoc-gen-go-grpc..."
|
||||
go install google.golang.org/grpc/cmd/protoc-gen-go-grpc@latest
|
||||
fi
|
||||
|
||||
echo "Generating protobuf files..."
|
||||
|
||||
# Generate Go code from proto files
|
||||
protoc --go_out=. --go_opt=paths=source_relative \
|
||||
--go-grpc_out=. --go-grpc_opt=paths=source_relative \
|
||||
pkg/grpc/proto/proxy.proto
|
||||
|
||||
echo "Proto generation complete!"
|
||||
277
proxy/server.go
Normal file
277
proxy/server.go
Normal file
@@ -0,0 +1,277 @@
|
||||
// Package proxy runs a NetBird proxy server.
|
||||
// It attempts to do everything it needs to do within the context
|
||||
// of a single request to the server to try to reduce the amount
|
||||
// of concurrency coordination that is required. However, it does
|
||||
// run two additional routines in an error group for handling
|
||||
// updates from the management server and running a separate
|
||||
// HTTP server to handle ACME HTTP-01 challenges (if configured).
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
"github.com/cloudflare/backoff"
|
||||
"github.com/netbirdio/netbird/proxy/internal/accesslog"
|
||||
"github.com/netbirdio/netbird/proxy/internal/acme"
|
||||
"github.com/netbirdio/netbird/proxy/internal/auth"
|
||||
"github.com/netbirdio/netbird/proxy/internal/proxy"
|
||||
"github.com/netbirdio/netbird/proxy/internal/roundtrip"
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
"google.golang.org/grpc"
|
||||
)
|
||||
|
||||
type errorLog interface {
|
||||
Error(msg string, args ...any)
|
||||
ErrorContext(ctx context.Context, msg string, args ...any)
|
||||
}
|
||||
|
||||
type Server struct {
|
||||
mgmtConn *grpc.ClientConn
|
||||
proxy *proxy.ReverseProxy
|
||||
netbird *roundtrip.NetBird
|
||||
acme *acme.Manager
|
||||
auth *auth.Middleware
|
||||
http *http.Server
|
||||
https *http.Server
|
||||
|
||||
ErrorLog errorLog
|
||||
ManagementAddress string
|
||||
CertificateDirectory string
|
||||
GenerateACMECertificates bool
|
||||
ACMEChallengeAddress string
|
||||
ACMEDirectory string
|
||||
}
|
||||
|
||||
func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) {
|
||||
if s.ErrorLog == nil {
|
||||
// If no ErrorLog is specified, then just discard the log output.
|
||||
s.ErrorLog = slog.New(slog.DiscardHandler)
|
||||
}
|
||||
|
||||
// The very first thing to do should be to connect to the Management server.
|
||||
// Without this connection, the Proxy cannot do anything.
|
||||
s.mgmtConn, err = grpc.NewClient(s.ManagementAddress)
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not create management connection: %w", err)
|
||||
}
|
||||
mgmtClient := proto.NewProxyServiceClient(s.mgmtConn)
|
||||
go s.newManagementMappingWorker(ctx, mgmtClient)
|
||||
|
||||
// Initialize the netbird client, this is required to build peer connections
|
||||
// to proxy over.
|
||||
s.netbird = roundtrip.NewNetBird(s.ManagementAddress)
|
||||
|
||||
// When generating ACME certificates, start a challenge server.
|
||||
tlsConfig := &tls.Config{}
|
||||
if s.GenerateACMECertificates {
|
||||
s.acme = acme.NewManager(s.CertificateDirectory, s.ACMEDirectory)
|
||||
s.http = &http.Server{
|
||||
Addr: s.ACMEChallengeAddress,
|
||||
Handler: s.acme.HTTPHandler(nil),
|
||||
}
|
||||
go func() {
|
||||
if err := s.http.ListenAndServe(); err != nil {
|
||||
// Rather than retry, log the issue periodically so that hopefully someone notices and fixes the issue.
|
||||
for range time.Tick(10 * time.Second) {
|
||||
s.ErrorLog.ErrorContext(ctx, "ACME HTTP-01 challenge server error", "error", err)
|
||||
}
|
||||
}
|
||||
}()
|
||||
tlsConfig = s.acme.TLSConfig()
|
||||
} else {
|
||||
// Otherwise pull some certificates from expected locations.
|
||||
cert, err := tls.LoadX509KeyPair(
|
||||
filepath.Join(s.CertificateDirectory, "tls.crt"),
|
||||
filepath.Join(s.CertificateDirectory, "tls.key"),
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("load provided certificate: %w", err)
|
||||
}
|
||||
tlsConfig.Certificates = append(tlsConfig.Certificates, cert)
|
||||
}
|
||||
|
||||
// Configure the reverse proxy using NetBird's HTTP Client Transport for proxying.
|
||||
s.proxy = proxy.NewReverseProxy(s.netbird)
|
||||
|
||||
// Configure the authentication middleware.
|
||||
s.auth = auth.NewMiddleware()
|
||||
|
||||
// Configure Access logs to management server.
|
||||
accessLog := accesslog.NewLogger(mgmtClient, s.ErrorLog)
|
||||
|
||||
// Finally, start the reverse proxy.
|
||||
s.https = &http.Server{
|
||||
Addr: addr,
|
||||
Handler: s.auth.Protect(accessLog.Middleware(s.proxy)),
|
||||
TLSConfig: tlsConfig,
|
||||
}
|
||||
return s.https.ListenAndServeTLS("", "")
|
||||
}
|
||||
|
||||
func (s *Server) newManagementMappingWorker(ctx context.Context, client proto.ProxyServiceClient) func() {
|
||||
b := backoff.New(0, 0)
|
||||
return func() {
|
||||
for {
|
||||
mappingClient, err := client.GetMappingUpdate(ctx, &proto.GetMappingUpdateRequest{})
|
||||
if err != nil {
|
||||
backoffDuration := b.Duration()
|
||||
s.ErrorLog.ErrorContext(ctx, "Unable to create mapping client to management server, retrying connection after backoff.",
|
||||
"backoff", backoffDuration,
|
||||
"error", err)
|
||||
time.Sleep(backoffDuration)
|
||||
continue
|
||||
}
|
||||
err = s.handleMappingStream(ctx, mappingClient)
|
||||
backoffDuration := b.Duration()
|
||||
switch {
|
||||
case errors.Is(err, context.Canceled),
|
||||
errors.Is(err, context.DeadlineExceeded):
|
||||
// Context is telling us that it is time to quit so gracefully exit here.
|
||||
// No need to log the error as it is a parent context causing this return.
|
||||
return
|
||||
case err != nil:
|
||||
// Log the error and then retry the connection.
|
||||
s.ErrorLog.ErrorContext(ctx, "Error processing mapping stream from management server, retrying connection after backoff.",
|
||||
"backoff", backoffDuration,
|
||||
"error", err)
|
||||
default:
|
||||
// TODO: should this really be at error level? Maybe, if you start getting lots of these this could be an indication of connectivity issues.
|
||||
s.ErrorLog.ErrorContext(ctx, "Management mapping connection terminated by the server, retrying connection after backoff.",
|
||||
"backoff", backoffDuration)
|
||||
}
|
||||
time.Sleep(backoffDuration)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) handleMappingStream(ctx context.Context, mappingClient proto.ProxyService_GetMappingUpdateClient) error {
|
||||
for {
|
||||
// Check for context completion to gracefully shutdown.
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
// Shutting down.
|
||||
return ctx.Err()
|
||||
default:
|
||||
msg, err := mappingClient.Recv()
|
||||
switch {
|
||||
case errors.Is(err, io.EOF):
|
||||
// Mapping connection gracefully terminated by server.
|
||||
return nil
|
||||
case err != nil:
|
||||
// Something has gone horribly wrong, return and hope the parent retries the connection.
|
||||
return fmt.Errorf("receive msg: %w", err)
|
||||
}
|
||||
|
||||
// Process msg updates sequentially to avoid conflict, so block
|
||||
// additional receiving until this processing is completed.
|
||||
for _, mapping := range msg.GetMapping() {
|
||||
switch mapping.GetType() {
|
||||
case proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED:
|
||||
if err := s.addMapping(ctx, mapping); err != nil {
|
||||
// TODO: Retry this? Or maybe notify the management server that this mapping has failed?
|
||||
s.ErrorLog.ErrorContext(ctx, "Error adding new mapping, ignoring this mapping and continuing processing.",
|
||||
"service_id", mapping.GetId(),
|
||||
"domain", mapping.GetDomain(),
|
||||
"error", err)
|
||||
}
|
||||
case proto.ProxyMappingUpdateType_UPDATE_TYPE_MODIFIED:
|
||||
s.updateMapping(ctx, mapping)
|
||||
case proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED:
|
||||
s.removeMapping(mapping)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) addMapping(ctx context.Context, mapping *proto.ProxyMapping) error {
|
||||
if err := s.netbird.AddPeer(mapping.GetDomain(), mapping.GetSetupKey()); err != nil {
|
||||
return fmt.Errorf("create peer for domain %q: %w", mapping.GetDomain(), err)
|
||||
}
|
||||
if s.acme != nil {
|
||||
s.acme.AddDomain(mapping.GetDomain())
|
||||
}
|
||||
|
||||
// Pass the mapping through to the update function to avoid duplicating the
|
||||
// setup, currently update is simply a subset of this function, so this
|
||||
// separation makes sense...to me at least.
|
||||
s.updateMapping(ctx, mapping)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) updateMapping(ctx context.Context, mapping *proto.ProxyMapping) {
|
||||
// Very simple implementation here, we don't touch the existing peer
|
||||
// connection or any existing TLS configuration, we simply overwrite
|
||||
// the auth and proxy mappings.
|
||||
// Note: this does require the management server to always send a
|
||||
// full mapping rather than deltas during a modification.
|
||||
var schemes []auth.Scheme
|
||||
if mapping.GetAuth().GetBasic().GetEnabled() {
|
||||
schemes = append(schemes, auth.NewBasicAuth(
|
||||
mapping.GetAuth().GetBasic().GetUsername(),
|
||||
mapping.GetAuth().GetBasic().GetPassword(),
|
||||
))
|
||||
}
|
||||
if mapping.GetAuth().GetPin().GetEnabled() {
|
||||
schemes = append(schemes, auth.NewPin(
|
||||
mapping.GetAuth().GetPin().GetPin(),
|
||||
))
|
||||
}
|
||||
if mapping.GetAuth().GetOidc().GetEnabled() {
|
||||
oidc := mapping.GetAuth().GetOidc()
|
||||
scheme, err := auth.NewOIDC(ctx, auth.OIDCConfig{
|
||||
OIDCProviderURL: oidc.GetOidcProviderUrl(),
|
||||
OIDCClientID: oidc.GetOidcClientId(),
|
||||
OIDCClientSecret: oidc.GetOidcClientSecret(),
|
||||
OIDCRedirectURL: oidc.GetOidcRedirectUrl(),
|
||||
OIDCScopes: oidc.GetOidcScopes(),
|
||||
})
|
||||
if err != nil {
|
||||
s.ErrorLog.Error("Failed to create OIDC scheme", "error", err)
|
||||
} else {
|
||||
schemes = append(schemes, scheme)
|
||||
}
|
||||
}
|
||||
s.auth.AddDomain(mapping.GetDomain(), schemes)
|
||||
s.proxy.AddMapping(s.protoToMapping(mapping))
|
||||
}
|
||||
|
||||
func (s *Server) removeMapping(mapping *proto.ProxyMapping) {
|
||||
s.netbird.RemovePeer(mapping.GetDomain())
|
||||
if s.acme != nil {
|
||||
s.acme.RemoveDomain(mapping.GetDomain())
|
||||
}
|
||||
s.auth.RemoveDomain(mapping.GetDomain())
|
||||
s.proxy.RemoveMapping(s.protoToMapping(mapping))
|
||||
}
|
||||
|
||||
func (s *Server) protoToMapping(mapping *proto.ProxyMapping) proxy.Mapping {
|
||||
paths := make(map[string]*url.URL)
|
||||
for _, pathMapping := range mapping.GetPath() {
|
||||
targetURL, err := url.Parse(pathMapping.GetTarget())
|
||||
if err != nil {
|
||||
// TODO: Should we warn management about this so it can be bubbled up to a user to reconfigure?
|
||||
s.ErrorLog.Error("Error parsing target URL for path, this path will be ignored but other paths will still be configured.",
|
||||
"service_id", mapping.GetId(),
|
||||
"domain", mapping.GetDomain(),
|
||||
"path", pathMapping.GetPath(),
|
||||
"target", pathMapping.GetTarget(),
|
||||
"error", err)
|
||||
}
|
||||
paths[pathMapping.GetPath()] = targetURL
|
||||
}
|
||||
return proxy.Mapping{
|
||||
ID: mapping.GetId(),
|
||||
Host: mapping.GetDomain(),
|
||||
Paths: paths,
|
||||
}
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -9,27 +9,81 @@ import "google/protobuf/timestamp.proto";
|
||||
// ProxyService - Management is the SERVER, Proxy is the CLIENT
|
||||
// Proxy initiates connection to management
|
||||
service ProxyService {
|
||||
// Bidirectional stream for proxy-management communication
|
||||
rpc Stream(stream ProxyMessage) returns (stream ManagementMessage);
|
||||
rpc GetMappingUpdate(GetMappingUpdateRequest) returns (stream GetMappingUpdateResponse);
|
||||
|
||||
rpc SendAccessLog(SendAccessLogRequest) returns (SendAccessLogResponse);
|
||||
}
|
||||
|
||||
// Messages FROM Proxy TO Management
|
||||
message ProxyMessage {
|
||||
oneof payload {
|
||||
ProxyHello hello = 1; // First message on connect
|
||||
ProxyRequestData request_data = 2; // Real-time access logs
|
||||
}
|
||||
}
|
||||
|
||||
// Proxy identification on connect
|
||||
message ProxyHello {
|
||||
// GetMappingUpdateRequest is sent to initialise a mapping stream.
|
||||
message GetMappingUpdateRequest {
|
||||
string proxy_id = 1;
|
||||
string version = 2;
|
||||
google.protobuf.Timestamp started_at = 3;
|
||||
}
|
||||
|
||||
// Access log from proxy to management
|
||||
message ProxyRequestData {
|
||||
// GetMappingUpdateResponse contains zero or more ProxyMappings.
|
||||
// No mappings may be sent to test the liveness of the Proxy.
|
||||
// Mappings that are sent should be interpreted by the Proxy appropriately.
|
||||
message GetMappingUpdateResponse {
|
||||
repeated ProxyMapping mapping = 1;
|
||||
}
|
||||
|
||||
enum ProxyMappingUpdateType {
|
||||
UPDATE_TYPE_CREATED = 0;
|
||||
UPDATE_TYPE_MODIFIED = 1;
|
||||
UPDATE_TYPE_REMOVED = 2;
|
||||
}
|
||||
|
||||
message PathMapping {
|
||||
string path = 1;
|
||||
string target = 2;
|
||||
}
|
||||
|
||||
message Authentication {
|
||||
HTTPBasic basic = 1;
|
||||
Pin pin = 2;
|
||||
OIDC oidc = 3;
|
||||
}
|
||||
|
||||
message HTTPBasic {
|
||||
bool enabled = 1;
|
||||
string username = 2;
|
||||
string password = 3;
|
||||
}
|
||||
|
||||
message Pin {
|
||||
bool enabled = 1;
|
||||
string pin = 2;
|
||||
}
|
||||
|
||||
message OIDC {
|
||||
bool enabled = 1;
|
||||
string oidc_provider_url = 2;
|
||||
string oidc_client_id = 3;
|
||||
string oidc_client_secret = 4;
|
||||
string oidc_redirect_url = 5;
|
||||
repeated string oidc_scopes = 6;
|
||||
string session_cookie_name = 7;
|
||||
}
|
||||
|
||||
message ProxyMapping {
|
||||
ProxyMappingUpdateType type = 1;
|
||||
string id = 2;
|
||||
string domain = 3;
|
||||
repeated PathMapping path = 4;
|
||||
string setup_key = 5;
|
||||
Authentication auth = 6;
|
||||
}
|
||||
|
||||
// SendAccessLogRequest consists of one or more AccessLogs from a Proxy.
|
||||
message SendAccessLogRequest {
|
||||
AccessLog log = 1;
|
||||
}
|
||||
|
||||
// SendAccessLogResponse is intentionally empty to allow for future expansion.
|
||||
message SendAccessLogResponse {}
|
||||
|
||||
message AccessLog {
|
||||
google.protobuf.Timestamp timestamp = 1;
|
||||
string service_id = 2;
|
||||
string host = 3;
|
||||
@@ -42,61 +96,3 @@ message ProxyRequestData {
|
||||
string user_id = 10;
|
||||
bool auth_success = 11;
|
||||
}
|
||||
|
||||
// Messages FROM Management TO Proxy
|
||||
message ManagementMessage {
|
||||
oneof payload {
|
||||
ServicesSnapshot snapshot = 1; // Full snapshot on initial connect
|
||||
ServiceUpdate update = 2; // Incremental service update
|
||||
}
|
||||
}
|
||||
|
||||
// Full snapshot of all services for this proxy
|
||||
message ServicesSnapshot {
|
||||
repeated ExposedServiceConfig services = 1;
|
||||
google.protobuf.Timestamp timestamp = 2;
|
||||
}
|
||||
|
||||
// Incremental service update
|
||||
message ServiceUpdate {
|
||||
enum UpdateType {
|
||||
CREATED = 0;
|
||||
UPDATED = 1;
|
||||
REMOVED = 2;
|
||||
}
|
||||
UpdateType type = 1;
|
||||
ExposedServiceConfig service = 2; // Set for CREATED and UPDATED
|
||||
string service_id = 3; // Service ID (always set)
|
||||
}
|
||||
|
||||
// Exposed service configuration
|
||||
message ExposedServiceConfig {
|
||||
string id = 1;
|
||||
string domain = 2;
|
||||
map<string, string> path_mappings = 3; // path -> target
|
||||
string setup_key = 4;
|
||||
AuthConfig auth = 5;
|
||||
}
|
||||
|
||||
// Authentication configuration
|
||||
message AuthConfig {
|
||||
oneof auth_type {
|
||||
BasicAuthConfig basic_auth = 1;
|
||||
PinAuthConfig pin_auth = 2;
|
||||
BearerAuthConfig bearer_auth = 3;
|
||||
}
|
||||
}
|
||||
|
||||
message BasicAuthConfig {
|
||||
string username = 1;
|
||||
string password = 2;
|
||||
}
|
||||
|
||||
message PinAuthConfig {
|
||||
string pin = 1;
|
||||
string header = 2;
|
||||
}
|
||||
|
||||
message BearerAuthConfig {
|
||||
bool enabled = 1;
|
||||
}
|
||||
@@ -18,8 +18,8 @@ const _ = grpc.SupportPackageIsVersion7
|
||||
//
|
||||
// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream.
|
||||
type ProxyServiceClient interface {
|
||||
// Bidirectional stream for proxy-management communication
|
||||
Stream(ctx context.Context, opts ...grpc.CallOption) (ProxyService_StreamClient, error)
|
||||
GetMappingUpdate(ctx context.Context, in *GetMappingUpdateRequest, opts ...grpc.CallOption) (ProxyService_GetMappingUpdateClient, error)
|
||||
SendAccessLog(ctx context.Context, in *SendAccessLogRequest, opts ...grpc.CallOption) (*SendAccessLogResponse, error)
|
||||
}
|
||||
|
||||
type proxyServiceClient struct {
|
||||
@@ -30,43 +30,53 @@ func NewProxyServiceClient(cc grpc.ClientConnInterface) ProxyServiceClient {
|
||||
return &proxyServiceClient{cc}
|
||||
}
|
||||
|
||||
func (c *proxyServiceClient) Stream(ctx context.Context, opts ...grpc.CallOption) (ProxyService_StreamClient, error) {
|
||||
stream, err := c.cc.NewStream(ctx, &ProxyService_ServiceDesc.Streams[0], "/management.ProxyService/Stream", opts...)
|
||||
func (c *proxyServiceClient) GetMappingUpdate(ctx context.Context, in *GetMappingUpdateRequest, opts ...grpc.CallOption) (ProxyService_GetMappingUpdateClient, error) {
|
||||
stream, err := c.cc.NewStream(ctx, &ProxyService_ServiceDesc.Streams[0], "/management.ProxyService/GetMappingUpdate", opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
x := &proxyServiceStreamClient{stream}
|
||||
x := &proxyServiceGetMappingUpdateClient{stream}
|
||||
if err := x.ClientStream.SendMsg(in); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := x.ClientStream.CloseSend(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return x, nil
|
||||
}
|
||||
|
||||
type ProxyService_StreamClient interface {
|
||||
Send(*ProxyMessage) error
|
||||
Recv() (*ManagementMessage, error)
|
||||
type ProxyService_GetMappingUpdateClient interface {
|
||||
Recv() (*GetMappingUpdateResponse, error)
|
||||
grpc.ClientStream
|
||||
}
|
||||
|
||||
type proxyServiceStreamClient struct {
|
||||
type proxyServiceGetMappingUpdateClient struct {
|
||||
grpc.ClientStream
|
||||
}
|
||||
|
||||
func (x *proxyServiceStreamClient) Send(m *ProxyMessage) error {
|
||||
return x.ClientStream.SendMsg(m)
|
||||
}
|
||||
|
||||
func (x *proxyServiceStreamClient) Recv() (*ManagementMessage, error) {
|
||||
m := new(ManagementMessage)
|
||||
func (x *proxyServiceGetMappingUpdateClient) Recv() (*GetMappingUpdateResponse, error) {
|
||||
m := new(GetMappingUpdateResponse)
|
||||
if err := x.ClientStream.RecvMsg(m); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func (c *proxyServiceClient) SendAccessLog(ctx context.Context, in *SendAccessLogRequest, opts ...grpc.CallOption) (*SendAccessLogResponse, error) {
|
||||
out := new(SendAccessLogResponse)
|
||||
err := c.cc.Invoke(ctx, "/management.ProxyService/SendAccessLog", in, out, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// ProxyServiceServer is the server API for ProxyService service.
|
||||
// All implementations must embed UnimplementedProxyServiceServer
|
||||
// for forward compatibility
|
||||
type ProxyServiceServer interface {
|
||||
// Bidirectional stream for proxy-management communication
|
||||
Stream(ProxyService_StreamServer) error
|
||||
GetMappingUpdate(*GetMappingUpdateRequest, ProxyService_GetMappingUpdateServer) error
|
||||
SendAccessLog(context.Context, *SendAccessLogRequest) (*SendAccessLogResponse, error)
|
||||
mustEmbedUnimplementedProxyServiceServer()
|
||||
}
|
||||
|
||||
@@ -74,8 +84,11 @@ type ProxyServiceServer interface {
|
||||
type UnimplementedProxyServiceServer struct {
|
||||
}
|
||||
|
||||
func (UnimplementedProxyServiceServer) Stream(ProxyService_StreamServer) error {
|
||||
return status.Errorf(codes.Unimplemented, "method Stream not implemented")
|
||||
func (UnimplementedProxyServiceServer) GetMappingUpdate(*GetMappingUpdateRequest, ProxyService_GetMappingUpdateServer) error {
|
||||
return status.Errorf(codes.Unimplemented, "method GetMappingUpdate not implemented")
|
||||
}
|
||||
func (UnimplementedProxyServiceServer) SendAccessLog(context.Context, *SendAccessLogRequest) (*SendAccessLogResponse, error) {
|
||||
return nil, status.Errorf(codes.Unimplemented, "method SendAccessLog not implemented")
|
||||
}
|
||||
func (UnimplementedProxyServiceServer) mustEmbedUnimplementedProxyServiceServer() {}
|
||||
|
||||
@@ -90,30 +103,43 @@ func RegisterProxyServiceServer(s grpc.ServiceRegistrar, srv ProxyServiceServer)
|
||||
s.RegisterService(&ProxyService_ServiceDesc, srv)
|
||||
}
|
||||
|
||||
func _ProxyService_Stream_Handler(srv interface{}, stream grpc.ServerStream) error {
|
||||
return srv.(ProxyServiceServer).Stream(&proxyServiceStreamServer{stream})
|
||||
func _ProxyService_GetMappingUpdate_Handler(srv interface{}, stream grpc.ServerStream) error {
|
||||
m := new(GetMappingUpdateRequest)
|
||||
if err := stream.RecvMsg(m); err != nil {
|
||||
return err
|
||||
}
|
||||
return srv.(ProxyServiceServer).GetMappingUpdate(m, &proxyServiceGetMappingUpdateServer{stream})
|
||||
}
|
||||
|
||||
type ProxyService_StreamServer interface {
|
||||
Send(*ManagementMessage) error
|
||||
Recv() (*ProxyMessage, error)
|
||||
type ProxyService_GetMappingUpdateServer interface {
|
||||
Send(*GetMappingUpdateResponse) error
|
||||
grpc.ServerStream
|
||||
}
|
||||
|
||||
type proxyServiceStreamServer struct {
|
||||
type proxyServiceGetMappingUpdateServer struct {
|
||||
grpc.ServerStream
|
||||
}
|
||||
|
||||
func (x *proxyServiceStreamServer) Send(m *ManagementMessage) error {
|
||||
func (x *proxyServiceGetMappingUpdateServer) Send(m *GetMappingUpdateResponse) error {
|
||||
return x.ServerStream.SendMsg(m)
|
||||
}
|
||||
|
||||
func (x *proxyServiceStreamServer) Recv() (*ProxyMessage, error) {
|
||||
m := new(ProxyMessage)
|
||||
if err := x.ServerStream.RecvMsg(m); err != nil {
|
||||
func _ProxyService_SendAccessLog_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
|
||||
in := new(SendAccessLogRequest)
|
||||
if err := dec(in); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return m, nil
|
||||
if interceptor == nil {
|
||||
return srv.(ProxyServiceServer).SendAccessLog(ctx, in)
|
||||
}
|
||||
info := &grpc.UnaryServerInfo{
|
||||
Server: srv,
|
||||
FullMethod: "/management.ProxyService/SendAccessLog",
|
||||
}
|
||||
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return srv.(ProxyServiceServer).SendAccessLog(ctx, req.(*SendAccessLogRequest))
|
||||
}
|
||||
return interceptor(ctx, in, info, handler)
|
||||
}
|
||||
|
||||
// ProxyService_ServiceDesc is the grpc.ServiceDesc for ProxyService service.
|
||||
@@ -122,13 +148,17 @@ func (x *proxyServiceStreamServer) Recv() (*ProxyMessage, error) {
|
||||
var ProxyService_ServiceDesc = grpc.ServiceDesc{
|
||||
ServiceName: "management.ProxyService",
|
||||
HandlerType: (*ProxyServiceServer)(nil),
|
||||
Methods: []grpc.MethodDesc{},
|
||||
Methods: []grpc.MethodDesc{
|
||||
{
|
||||
MethodName: "SendAccessLog",
|
||||
Handler: _ProxyService_SendAccessLog_Handler,
|
||||
},
|
||||
},
|
||||
Streams: []grpc.StreamDesc{
|
||||
{
|
||||
StreamName: "Stream",
|
||||
Handler: _ProxyService_Stream_Handler,
|
||||
StreamName: "GetMappingUpdate",
|
||||
Handler: _ProxyService_GetMappingUpdate_Handler,
|
||||
ServerStreams: true,
|
||||
ClientStreams: true,
|
||||
},
|
||||
},
|
||||
Metadata: "proxy_service.proto",
|
||||
|
||||
Reference in New Issue
Block a user