diff --git a/README.md b/README.md index a31e9a1..9d88096 100644 --- a/README.md +++ b/README.md @@ -39,6 +39,7 @@ When Newt receives WireGuard control messages, it will use the information encod - `updown` (optional): A script to be called when targets are added or removed. - `tls-client-cert` (optional): Client certificate (p12 or pfx) for mTLS. See [mTLS](#mtls) - `docker-socket` (optional): Set the Docker socket to use the container discovery integration +- `docker-enforce-network-validation` (optional): Validate the container target is on the same network as the newt process - Example: @@ -100,6 +101,26 @@ services: - DOCKER_SOCKET=/var/run/docker.sock ``` +#### Hostnames vs IPs + +When the Docker Socket Integration is used, depending on the network which Newt is run with, either the hostname (generally considered the container name) or the IP address of the container will be sent to Pangolin. Here are some of the scenarios where IPs or hostname of the container will be utilised: +- **Running in Network Mode 'host'**: IP addresses will be used +- **Running in Network Mode 'bridge'**: IP addresses will be used +- **Running in docker-compose without a network specification**: Docker compose creates a network for the compose by default, hostnames will be used +- **Running on docker-compose with defined network**: Hostnames will be used + +### Docker Enforce Network Validation + +When run as a Docker container, Newt can validate that the target being provided is on the same network as the Newt container and only return containers directly accessible by Newt. Validation will be carried out against either the hostname/IP Address and the Port number to ensure the running container is exposing the ports to Newt. + +It is important to note that if the Newt container is run with a network mode of `host` that this feature will not work. Running in `host` mode causes the container to share its resources with the host machine, therefore making it so the specific host container information for Newt cannot be retrieved to be able to carry out network validation. + +**Configuration:** + +Validation is `false` by default. It can be enabled via setting the `--docker-enforce-network-validation` CLI argument or by setting the `DOCKER_ENFORCE_NETWORK_VALIDATION` environment variable. + +If validation is enforced and the Docker socket is available, Newt will **not** add the target as it cannot be verified. A warning will be presented in the Newt logs. + ### Updown You can pass in a updown script for Newt to call when it is adding or removing a target: diff --git a/docker/client.go b/docker/client.go index 98936fe..9fedf52 100644 --- a/docker/client.go +++ b/docker/client.go @@ -4,10 +4,13 @@ import ( "context" "fmt" "net" + "os" + "strconv" "strings" "time" "github.com/docker/docker/api/types/container" + "github.com/docker/docker/api/types/filters" "github.com/docker/docker/client" "github.com/fosrl/newt/logger" ) @@ -67,13 +70,61 @@ func CheckSocket(socketPath string) bool { return true } +// IsWithinHostNetwork checks if a provided target is within the host container network +func IsWithinHostNetwork(socketPath string, targetAddress string, targetPort int) (bool, error) { + // Always enforce network validation + containers, err := ListContainers(socketPath, true) + if err != nil { + + return false, err + } + + // Determine if given an IP address + var parsedTargetAddressIp = net.ParseIP(targetAddress) + + // If we can find the passed hostname/IP address in the networks or as the container name, it is valid and can add it + for _, c := range containers { + for _, network := range c.Networks { + // If the target address is not an IP address, use the container name + if parsedTargetAddressIp == nil { + if c.Name == targetAddress { + for _, port := range c.Ports { + if port.PublicPort == targetPort || port.PrivatePort == targetPort { + return true, nil + } + } + } + } else { + //If the IP address matches, check the ports being mapped too + if network.IPAddress == targetAddress { + for _, port := range c.Ports { + if port.PublicPort == targetPort || port.PrivatePort == targetPort { + return true, nil + } + } + } + } + } + } + + combinedTargetAddress := targetAddress + ":" + strconv.Itoa(targetPort) + return false, fmt.Errorf("target address not within host container network: %s", combinedTargetAddress) +} + // ListContainers lists all Docker containers with their network information -func ListContainers(socketPath string) ([]Container, error) { +func ListContainers(socketPath string, enforceNetworkValidation bool) ([]Container, error) { // Use the provided socket path or default to standard location if socketPath == "" { socketPath = "/var/run/docker.sock" } + // Used to filter down containers returned to Pangolin + containerFilters := filters.NewArgs() + + // Used to determine if we will send IP addresses or hostnames to Pangolin + useContainerIpAddresses := true + hostContainerId := "" + // Create a new Docker client ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() @@ -86,16 +137,54 @@ func ListContainers(socketPath string) ([]Container, error) { if err != nil { return nil, fmt.Errorf("failed to create Docker client: %v", err) } + defer cli.Close() + hostContainer, err := getHostContainer(ctx, cli) + if enforceNetworkValidation && err != nil { + return nil, fmt.Errorf("network validation enforced, cannot validate due to: %v", err) + } + + // We may not be able to get back host container in scenarios like running the container in network mode 'host' + if hostContainer != nil { + // We can use the host container to filter out the list of returned containers + hostContainerId = hostContainer.ID + + for hostContainerNetworkName := range hostContainer.NetworkSettings.Networks { + // If we're enforcing network validation, we'll filter on the host containers networks + if enforceNetworkValidation { + containerFilters.Add("network", hostContainerNetworkName) + } + + // If the container is on the docker bridge network, we will use IP addresses over hostnames + if useContainerIpAddresses && hostContainerNetworkName != "bridge" { + useContainerIpAddresses = false + } + } + } + // List containers - containers, err := cli.ContainerList(ctx, container.ListOptions{All: true}) + containers, err := cli.ContainerList(ctx, container.ListOptions{All: true, Filters: containerFilters}) if err != nil { return nil, fmt.Errorf("failed to list containers: %v", err) } var dockerContainers []Container for _, c := range containers { + // Short ID like docker ps + shortId := c.ID[:12] + + // Skip host container if set + if hostContainerId != "" && c.ID == hostContainerId { + continue + } + + // Get container name (remove leading slash) + name := "" + if len(c.Names) > 0 { + name = strings.TrimPrefix(c.Names[0], "/") + } + // Convert ports var ports []Port for _, port := range c.Ports { @@ -112,44 +201,36 @@ func ListContainers(socketPath string) ([]Container, error) { ports = append(ports, dockerPort) } - // Get container name (remove leading slash) - name := "" - if len(c.Names) > 0 { - name = strings.TrimPrefix(c.Names[0], "/") - } - // Get network information by inspecting the container networks := make(map[string]Network) - // Inspect container to get detailed network information - containerInfo, err := cli.ContainerInspect(ctx, c.ID) - if err != nil { - logger.Debug("Failed to inspect container %s for network info: %v", c.ID[:12], err) - // Continue without network info if inspection fails - } else { - // Extract network information from inspection - if containerInfo.NetworkSettings != nil && containerInfo.NetworkSettings.Networks != nil { - for networkName, endpoint := range containerInfo.NetworkSettings.Networks { - dockerNetwork := Network{ - NetworkID: endpoint.NetworkID, - EndpointID: endpoint.EndpointID, - Gateway: endpoint.Gateway, - IPAddress: endpoint.IPAddress, - IPPrefixLen: endpoint.IPPrefixLen, - IPv6Gateway: endpoint.IPv6Gateway, - GlobalIPv6Address: endpoint.GlobalIPv6Address, - GlobalIPv6PrefixLen: endpoint.GlobalIPv6PrefixLen, - MacAddress: endpoint.MacAddress, - Aliases: endpoint.Aliases, - DNSNames: endpoint.DNSNames, - } - networks[networkName] = dockerNetwork + // Extract network information from inspection + if c.NetworkSettings != nil && c.NetworkSettings.Networks != nil { + for networkName, endpoint := range c.NetworkSettings.Networks { + dockerNetwork := Network{ + NetworkID: endpoint.NetworkID, + EndpointID: endpoint.EndpointID, + Gateway: endpoint.Gateway, + IPPrefixLen: endpoint.IPPrefixLen, + IPv6Gateway: endpoint.IPv6Gateway, + GlobalIPv6Address: endpoint.GlobalIPv6Address, + GlobalIPv6PrefixLen: endpoint.GlobalIPv6PrefixLen, + MacAddress: endpoint.MacAddress, + Aliases: endpoint.Aliases, + DNSNames: endpoint.DNSNames, } + + // Use IPs over hostnames/containers as we're on the bridge network + if useContainerIpAddresses { + dockerNetwork.IPAddress = endpoint.IPAddress + } + + networks[networkName] = dockerNetwork } } dockerContainer := Container{ - ID: c.ID[:12], // Show short ID like docker ps + ID: shortId, Name: name, Image: c.Image, State: c.State, @@ -159,8 +240,26 @@ func ListContainers(socketPath string) ([]Container, error) { Created: c.Created, Networks: networks, } + dockerContainers = append(dockerContainers, dockerContainer) } return dockerContainers, nil } + +// getHostContainer gets the current container for the current host if possible +func getHostContainer(dockerContext context.Context, dockerClient *client.Client) (*container.InspectResponse, error) { + // Get hostname from the os + hostContainerName, err := os.Hostname() + if err != nil { + return nil, fmt.Errorf("failed to find hostname for container") + } + + // Get host container from the docker socket + hostContainer, err := dockerClient.ContainerInspect(dockerContext, hostContainerName) + if err != nil { + return nil, fmt.Errorf("failed to find host container") + } + + return &hostContainer, nil +} \ No newline at end of file diff --git a/main.go b/main.go index fdece97..6622999 100644 --- a/main.go +++ b/main.go @@ -341,18 +341,20 @@ func resolveDomain(domain string) (string, error) { } var ( - endpoint string - id string - secret string - mtu string - mtuInt int - dns string - privateKey wgtypes.Key - err error - logLevel string - updownScript string - tlsPrivateKey string - dockerSocket string + endpoint string + id string + secret string + mtu string + mtuInt int + dns string + privateKey wgtypes.Key + err error + logLevel string + updownScript string + tlsPrivateKey string + dockerSocket string + dockerEnforceNetworkValidation string + dockerEnforceNetworkValidationBool bool ) func main() { @@ -366,6 +368,7 @@ func main() { updownScript = os.Getenv("UPDOWN_SCRIPT") tlsPrivateKey = os.Getenv("TLS_CLIENT_CERT") dockerSocket = os.Getenv("DOCKER_SOCKET") + dockerEnforceNetworkValidation = os.Getenv("DOCKER_ENFORCE_NETWORK_VALIDATION") if endpoint == "" { flag.StringVar(&endpoint, "endpoint", "", "Endpoint of your pangolin server") @@ -394,6 +397,9 @@ func main() { if dockerSocket == "" { flag.StringVar(&dockerSocket, "docker-socket", "", "Path to Docker socket (typically /var/run/docker.sock)") } + if dockerEnforceNetworkValidation == "" { + flag.StringVar(&dockerEnforceNetworkValidation, "docker-enforce-network-validation", "false", "Enforce validation of container on newt network (true or false)") + } // do a --version check version := flag.Bool("version", false, "Print the version") @@ -418,6 +424,13 @@ func main() { logger.Fatal("Failed to parse MTU: %v", err) } + // parse if we want to enforce container network validation + dockerEnforceNetworkValidationBool, err = strconv.ParseBool(dockerEnforceNetworkValidation) + if err != nil { + logger.Info("Docker enforce network validation cannot be parsed. Defaulting to 'false'") + dockerEnforceNetworkValidationBool = false + } + privateKey, err = wgtypes.GeneratePrivateKey() if err != nil { logger.Fatal("Failed to generate private key: %v", err) @@ -437,6 +450,24 @@ func main() { logger.Fatal("Failed to create client: %v", err) } + // output env var values if set + logger.Debug("Endpoint: %v", endpoint) + logger.Debug("Log Level: %v", logLevel) + logger.Debug("Docker Network Validation Enabled: %v", dockerEnforceNetworkValidationBool) + logger.Debug("TLS Private Key Set: %v", tlsPrivateKey != "") + if dns != "" { + logger.Debug("Dns: %v", dns) + } + if dockerSocket != "" { + logger.Debug("Docker Socket: %v", dockerSocket) + } + if mtu != "" { + logger.Debug("MTU: %v", mtu) + } + if updownScript != "" { + logger.Debug("Up Down Script: %v", updownScript) + } + // Create TUN device and network stack var tun tun.Device var tnet *netstack.Net @@ -676,7 +707,7 @@ persistent_keepalive_interval=5`, fixKey(privateKey.String()), fixKey(wgData.Pub } // List Docker containers - containers, err := docker.ListContainers(dockerSocket) + containers, err := docker.ListContainers(dockerSocket, dockerEnforceNetworkValidationBool) if err != nil { logger.Error("Failed to list Docker containers: %v", err) return @@ -760,12 +791,14 @@ func updateTargets(pm *proxy.ProxyManager, action string, tunnelIP string, proto } if action == "add" { - target := parts[1] + ":" + parts[2] + targetAddress := parts[1] + targetPort, _ := strconv.Atoi(parts[2]) + combinedAddress := targetAddress + ":" + parts[2] // Call updown script if provided - processedTarget := target + processedTarget := combinedAddress if updownScript != "" { - newTarget, err := executeUpdownScript(action, proto, target) + newTarget, err := executeUpdownScript(action, proto, combinedAddress) if err != nil { logger.Warn("Updown script error: %v", err) } else if newTarget != "" { @@ -782,9 +815,19 @@ func updateTargets(pm *proxy.ProxyManager, action string, tunnelIP string, proto } } - // Add the new target - pm.AddTarget(proto, tunnelIP, port, processedTarget) - + // If docker network validation is enabled + if dockerEnforceNetworkValidationBool { + // If the target address is within the host container network, the target will be added + isWithinHostContainerNetwork, err := docker.IsWithinHostNetwork(dockerSocket, targetAddress, targetPort) + if !isWithinHostContainerNetwork { + logger.Warn("Not adding target address: %v", err) + } else { + pm.AddTarget(proto, tunnelIP, port, processedTarget) + } + } else { + // If we're not enforcing network validation, just proceed with adding the target + pm.AddTarget(proto, tunnelIP, port, processedTarget) + } } else if action == "remove" { logger.Info("Removing target with port %d", port)