Merge branch 'docker-network-checking' of github.com:JonnyBooker/newt into JonnyBooker-docker-network-checking

This commit is contained in:
Owen
2025-06-22 11:34:30 -04:00
3 changed files with 214 additions and 51 deletions

View File

@@ -39,6 +39,7 @@ When Newt receives WireGuard control messages, it will use the information encod
- `updown` (optional): A script to be called when targets are added or removed. - `updown` (optional): A script to be called when targets are added or removed.
- `tls-client-cert` (optional): Client certificate (p12 or pfx) for mTLS. See [mTLS](#mtls) - `tls-client-cert` (optional): Client certificate (p12 or pfx) for mTLS. See [mTLS](#mtls)
- `docker-socket` (optional): Set the Docker socket to use the container discovery integration - `docker-socket` (optional): Set the Docker socket to use the container discovery integration
- `docker-enforce-network-validation` (optional): Validate the container target is on the same network as the newt process
- Example: - Example:
@@ -100,6 +101,26 @@ services:
- DOCKER_SOCKET=/var/run/docker.sock - DOCKER_SOCKET=/var/run/docker.sock
``` ```
#### Hostnames vs IPs
When the Docker Socket Integration is used, depending on the network which Newt is run with, either the hostname (generally considered the container name) or the IP address of the container will be sent to Pangolin. Here are some of the scenarios where IPs or hostname of the container will be utilised:
- **Running in Network Mode 'host'**: IP addresses will be used
- **Running in Network Mode 'bridge'**: IP addresses will be used
- **Running in docker-compose without a network specification**: Docker compose creates a network for the compose by default, hostnames will be used
- **Running on docker-compose with defined network**: Hostnames will be used
### Docker Enforce Network Validation
When run as a Docker container, Newt can validate that the target being provided is on the same network as the Newt container and only return containers directly accessible by Newt. Validation will be carried out against either the hostname/IP Address and the Port number to ensure the running container is exposing the ports to Newt.
It is important to note that if the Newt container is run with a network mode of `host` that this feature will not work. Running in `host` mode causes the container to share its resources with the host machine, therefore making it so the specific host container information for Newt cannot be retrieved to be able to carry out network validation.
**Configuration:**
Validation is `false` by default. It can be enabled via setting the `--docker-enforce-network-validation` CLI argument or by setting the `DOCKER_ENFORCE_NETWORK_VALIDATION` environment variable.
If validation is enforced and the Docker socket is available, Newt will **not** add the target as it cannot be verified. A warning will be presented in the Newt logs.
### Updown ### Updown
You can pass in a updown script for Newt to call when it is adding or removing a target: You can pass in a updown script for Newt to call when it is adding or removing a target:

View File

@@ -4,10 +4,13 @@ import (
"context" "context"
"fmt" "fmt"
"net" "net"
"os"
"strconv"
"strings" "strings"
"time" "time"
"github.com/docker/docker/api/types/container" "github.com/docker/docker/api/types/container"
"github.com/docker/docker/api/types/filters"
"github.com/docker/docker/client" "github.com/docker/docker/client"
"github.com/fosrl/newt/logger" "github.com/fosrl/newt/logger"
) )
@@ -67,13 +70,61 @@ func CheckSocket(socketPath string) bool {
return true return true
} }
// IsWithinHostNetwork checks if a provided target is within the host container network
func IsWithinHostNetwork(socketPath string, targetAddress string, targetPort int) (bool, error) {
// Always enforce network validation
containers, err := ListContainers(socketPath, true)
if err != nil {
return false, err
}
// Determine if given an IP address
var parsedTargetAddressIp = net.ParseIP(targetAddress)
// If we can find the passed hostname/IP address in the networks or as the container name, it is valid and can add it
for _, c := range containers {
for _, network := range c.Networks {
// If the target address is not an IP address, use the container name
if parsedTargetAddressIp == nil {
if c.Name == targetAddress {
for _, port := range c.Ports {
if port.PublicPort == targetPort || port.PrivatePort == targetPort {
return true, nil
}
}
}
} else {
//If the IP address matches, check the ports being mapped too
if network.IPAddress == targetAddress {
for _, port := range c.Ports {
if port.PublicPort == targetPort || port.PrivatePort == targetPort {
return true, nil
}
}
}
}
}
}
combinedTargetAddress := targetAddress + ":" + strconv.Itoa(targetPort)
return false, fmt.Errorf("target address not within host container network: %s", combinedTargetAddress)
}
// ListContainers lists all Docker containers with their network information // ListContainers lists all Docker containers with their network information
func ListContainers(socketPath string) ([]Container, error) { func ListContainers(socketPath string, enforceNetworkValidation bool) ([]Container, error) {
// Use the provided socket path or default to standard location // Use the provided socket path or default to standard location
if socketPath == "" { if socketPath == "" {
socketPath = "/var/run/docker.sock" socketPath = "/var/run/docker.sock"
} }
// Used to filter down containers returned to Pangolin
containerFilters := filters.NewArgs()
// Used to determine if we will send IP addresses or hostnames to Pangolin
useContainerIpAddresses := true
hostContainerId := ""
// Create a new Docker client // Create a new Docker client
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel() defer cancel()
@@ -86,16 +137,54 @@ func ListContainers(socketPath string) ([]Container, error) {
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to create Docker client: %v", err) return nil, fmt.Errorf("failed to create Docker client: %v", err)
} }
defer cli.Close() defer cli.Close()
hostContainer, err := getHostContainer(ctx, cli)
if enforceNetworkValidation && err != nil {
return nil, fmt.Errorf("network validation enforced, cannot validate due to: %v", err)
}
// We may not be able to get back host container in scenarios like running the container in network mode 'host'
if hostContainer != nil {
// We can use the host container to filter out the list of returned containers
hostContainerId = hostContainer.ID
for hostContainerNetworkName := range hostContainer.NetworkSettings.Networks {
// If we're enforcing network validation, we'll filter on the host containers networks
if enforceNetworkValidation {
containerFilters.Add("network", hostContainerNetworkName)
}
// If the container is on the docker bridge network, we will use IP addresses over hostnames
if useContainerIpAddresses && hostContainerNetworkName != "bridge" {
useContainerIpAddresses = false
}
}
}
// List containers // List containers
containers, err := cli.ContainerList(ctx, container.ListOptions{All: true}) containers, err := cli.ContainerList(ctx, container.ListOptions{All: true, Filters: containerFilters})
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to list containers: %v", err) return nil, fmt.Errorf("failed to list containers: %v", err)
} }
var dockerContainers []Container var dockerContainers []Container
for _, c := range containers { for _, c := range containers {
// Short ID like docker ps
shortId := c.ID[:12]
// Skip host container if set
if hostContainerId != "" && c.ID == hostContainerId {
continue
}
// Get container name (remove leading slash)
name := ""
if len(c.Names) > 0 {
name = strings.TrimPrefix(c.Names[0], "/")
}
// Convert ports // Convert ports
var ports []Port var ports []Port
for _, port := range c.Ports { for _, port := range c.Ports {
@@ -112,44 +201,36 @@ func ListContainers(socketPath string) ([]Container, error) {
ports = append(ports, dockerPort) ports = append(ports, dockerPort)
} }
// Get container name (remove leading slash)
name := ""
if len(c.Names) > 0 {
name = strings.TrimPrefix(c.Names[0], "/")
}
// Get network information by inspecting the container // Get network information by inspecting the container
networks := make(map[string]Network) networks := make(map[string]Network)
// Inspect container to get detailed network information // Extract network information from inspection
containerInfo, err := cli.ContainerInspect(ctx, c.ID) if c.NetworkSettings != nil && c.NetworkSettings.Networks != nil {
if err != nil { for networkName, endpoint := range c.NetworkSettings.Networks {
logger.Debug("Failed to inspect container %s for network info: %v", c.ID[:12], err) dockerNetwork := Network{
// Continue without network info if inspection fails NetworkID: endpoint.NetworkID,
} else { EndpointID: endpoint.EndpointID,
// Extract network information from inspection Gateway: endpoint.Gateway,
if containerInfo.NetworkSettings != nil && containerInfo.NetworkSettings.Networks != nil { IPPrefixLen: endpoint.IPPrefixLen,
for networkName, endpoint := range containerInfo.NetworkSettings.Networks { IPv6Gateway: endpoint.IPv6Gateway,
dockerNetwork := Network{ GlobalIPv6Address: endpoint.GlobalIPv6Address,
NetworkID: endpoint.NetworkID, GlobalIPv6PrefixLen: endpoint.GlobalIPv6PrefixLen,
EndpointID: endpoint.EndpointID, MacAddress: endpoint.MacAddress,
Gateway: endpoint.Gateway, Aliases: endpoint.Aliases,
IPAddress: endpoint.IPAddress, DNSNames: endpoint.DNSNames,
IPPrefixLen: endpoint.IPPrefixLen,
IPv6Gateway: endpoint.IPv6Gateway,
GlobalIPv6Address: endpoint.GlobalIPv6Address,
GlobalIPv6PrefixLen: endpoint.GlobalIPv6PrefixLen,
MacAddress: endpoint.MacAddress,
Aliases: endpoint.Aliases,
DNSNames: endpoint.DNSNames,
}
networks[networkName] = dockerNetwork
} }
// Use IPs over hostnames/containers as we're on the bridge network
if useContainerIpAddresses {
dockerNetwork.IPAddress = endpoint.IPAddress
}
networks[networkName] = dockerNetwork
} }
} }
dockerContainer := Container{ dockerContainer := Container{
ID: c.ID[:12], // Show short ID like docker ps ID: shortId,
Name: name, Name: name,
Image: c.Image, Image: c.Image,
State: c.State, State: c.State,
@@ -159,8 +240,26 @@ func ListContainers(socketPath string) ([]Container, error) {
Created: c.Created, Created: c.Created,
Networks: networks, Networks: networks,
} }
dockerContainers = append(dockerContainers, dockerContainer) dockerContainers = append(dockerContainers, dockerContainer)
} }
return dockerContainers, nil return dockerContainers, nil
} }
// getHostContainer gets the current container for the current host if possible
func getHostContainer(dockerContext context.Context, dockerClient *client.Client) (*container.InspectResponse, error) {
// Get hostname from the os
hostContainerName, err := os.Hostname()
if err != nil {
return nil, fmt.Errorf("failed to find hostname for container")
}
// Get host container from the docker socket
hostContainer, err := dockerClient.ContainerInspect(dockerContext, hostContainerName)
if err != nil {
return nil, fmt.Errorf("failed to find host container")
}
return &hostContainer, nil
}

81
main.go
View File

@@ -341,18 +341,20 @@ func resolveDomain(domain string) (string, error) {
} }
var ( var (
endpoint string endpoint string
id string id string
secret string secret string
mtu string mtu string
mtuInt int mtuInt int
dns string dns string
privateKey wgtypes.Key privateKey wgtypes.Key
err error err error
logLevel string logLevel string
updownScript string updownScript string
tlsPrivateKey string tlsPrivateKey string
dockerSocket string dockerSocket string
dockerEnforceNetworkValidation string
dockerEnforceNetworkValidationBool bool
) )
func main() { func main() {
@@ -366,6 +368,7 @@ func main() {
updownScript = os.Getenv("UPDOWN_SCRIPT") updownScript = os.Getenv("UPDOWN_SCRIPT")
tlsPrivateKey = os.Getenv("TLS_CLIENT_CERT") tlsPrivateKey = os.Getenv("TLS_CLIENT_CERT")
dockerSocket = os.Getenv("DOCKER_SOCKET") dockerSocket = os.Getenv("DOCKER_SOCKET")
dockerEnforceNetworkValidation = os.Getenv("DOCKER_ENFORCE_NETWORK_VALIDATION")
if endpoint == "" { if endpoint == "" {
flag.StringVar(&endpoint, "endpoint", "", "Endpoint of your pangolin server") flag.StringVar(&endpoint, "endpoint", "", "Endpoint of your pangolin server")
@@ -394,6 +397,9 @@ func main() {
if dockerSocket == "" { if dockerSocket == "" {
flag.StringVar(&dockerSocket, "docker-socket", "", "Path to Docker socket (typically /var/run/docker.sock)") flag.StringVar(&dockerSocket, "docker-socket", "", "Path to Docker socket (typically /var/run/docker.sock)")
} }
if dockerEnforceNetworkValidation == "" {
flag.StringVar(&dockerEnforceNetworkValidation, "docker-enforce-network-validation", "false", "Enforce validation of container on newt network (true or false)")
}
// do a --version check // do a --version check
version := flag.Bool("version", false, "Print the version") version := flag.Bool("version", false, "Print the version")
@@ -418,6 +424,13 @@ func main() {
logger.Fatal("Failed to parse MTU: %v", err) logger.Fatal("Failed to parse MTU: %v", err)
} }
// parse if we want to enforce container network validation
dockerEnforceNetworkValidationBool, err = strconv.ParseBool(dockerEnforceNetworkValidation)
if err != nil {
logger.Info("Docker enforce network validation cannot be parsed. Defaulting to 'false'")
dockerEnforceNetworkValidationBool = false
}
privateKey, err = wgtypes.GeneratePrivateKey() privateKey, err = wgtypes.GeneratePrivateKey()
if err != nil { if err != nil {
logger.Fatal("Failed to generate private key: %v", err) logger.Fatal("Failed to generate private key: %v", err)
@@ -437,6 +450,24 @@ func main() {
logger.Fatal("Failed to create client: %v", err) logger.Fatal("Failed to create client: %v", err)
} }
// output env var values if set
logger.Debug("Endpoint: %v", endpoint)
logger.Debug("Log Level: %v", logLevel)
logger.Debug("Docker Network Validation Enabled: %v", dockerEnforceNetworkValidationBool)
logger.Debug("TLS Private Key Set: %v", tlsPrivateKey != "")
if dns != "" {
logger.Debug("Dns: %v", dns)
}
if dockerSocket != "" {
logger.Debug("Docker Socket: %v", dockerSocket)
}
if mtu != "" {
logger.Debug("MTU: %v", mtu)
}
if updownScript != "" {
logger.Debug("Up Down Script: %v", updownScript)
}
// Create TUN device and network stack // Create TUN device and network stack
var tun tun.Device var tun tun.Device
var tnet *netstack.Net var tnet *netstack.Net
@@ -676,7 +707,7 @@ persistent_keepalive_interval=5`, fixKey(privateKey.String()), fixKey(wgData.Pub
} }
// List Docker containers // List Docker containers
containers, err := docker.ListContainers(dockerSocket) containers, err := docker.ListContainers(dockerSocket, dockerEnforceNetworkValidationBool)
if err != nil { if err != nil {
logger.Error("Failed to list Docker containers: %v", err) logger.Error("Failed to list Docker containers: %v", err)
return return
@@ -760,12 +791,14 @@ func updateTargets(pm *proxy.ProxyManager, action string, tunnelIP string, proto
} }
if action == "add" { if action == "add" {
target := parts[1] + ":" + parts[2] targetAddress := parts[1]
targetPort, _ := strconv.Atoi(parts[2])
combinedAddress := targetAddress + ":" + parts[2]
// Call updown script if provided // Call updown script if provided
processedTarget := target processedTarget := combinedAddress
if updownScript != "" { if updownScript != "" {
newTarget, err := executeUpdownScript(action, proto, target) newTarget, err := executeUpdownScript(action, proto, combinedAddress)
if err != nil { if err != nil {
logger.Warn("Updown script error: %v", err) logger.Warn("Updown script error: %v", err)
} else if newTarget != "" { } else if newTarget != "" {
@@ -782,9 +815,19 @@ func updateTargets(pm *proxy.ProxyManager, action string, tunnelIP string, proto
} }
} }
// Add the new target // If docker network validation is enabled
pm.AddTarget(proto, tunnelIP, port, processedTarget) if dockerEnforceNetworkValidationBool {
// If the target address is within the host container network, the target will be added
isWithinHostContainerNetwork, err := docker.IsWithinHostNetwork(dockerSocket, targetAddress, targetPort)
if !isWithinHostContainerNetwork {
logger.Warn("Not adding target address: %v", err)
} else {
pm.AddTarget(proto, tunnelIP, port, processedTarget)
}
} else {
// If we're not enforcing network validation, just proceed with adding the target
pm.AddTarget(proto, tunnelIP, port, processedTarget)
}
} else if action == "remove" { } else if action == "remove" {
logger.Info("Removing target with port %d", port) logger.Info("Removing target with port %d", port)