mirror of
https://github.com/netbirdio/netbird.git
synced 2026-05-14 04:39:54 +00:00
Merge remote-tracking branch 'origin/main' into improve-usp-fw
# Conflicts: # client/firewall/uspfilter/conntrack/common.go # client/firewall/uspfilter/filter.go # client/firewall/uspfilter/forwarder/icmp.go # client/firewall/uspfilter/forwarder/tcp.go # client/firewall/uspfilter/nat.go
This commit is contained in:
@@ -433,6 +433,7 @@ func setSessionCookie(w http.ResponseWriter, token string, expiration time.Durat
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: auth.SessionCookieName,
|
||||
Value: token,
|
||||
Path: "/",
|
||||
HttpOnly: true,
|
||||
Secure: true,
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
|
||||
@@ -391,6 +391,15 @@ func TestProtect_SchemeAuthRedirectsWithCookie(t *testing.T) {
|
||||
assert.Equal(t, http.SameSiteLaxMode, sessionCookie.SameSite)
|
||||
}
|
||||
|
||||
func TestSetSessionCookieHasRootPath(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
setSessionCookie(w, "test-token", time.Hour)
|
||||
|
||||
cookies := w.Result().Cookies()
|
||||
require.Len(t, cookies, 1)
|
||||
assert.Equal(t, "/", cookies[0].Path, "session cookie must be scoped to root so it applies to all paths")
|
||||
}
|
||||
|
||||
func TestProtect_FailedAuthDoesNotSetCookie(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger(), nil, nil)
|
||||
kp := generateTestKeyPair(t)
|
||||
|
||||
@@ -6,10 +6,12 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
)
|
||||
|
||||
// StatusFilters contains filter options for status queries.
|
||||
@@ -230,12 +232,16 @@ func (c *Client) ClientSyncResponse(ctx context.Context, accountID string) error
|
||||
}
|
||||
|
||||
// PingTCP performs a TCP ping through a client.
|
||||
func (c *Client) PingTCP(ctx context.Context, accountID, host string, port int, timeout string) error {
|
||||
// ipVersion may be "4", "6", or "" for automatic.
|
||||
func (c *Client) PingTCP(ctx context.Context, accountID, host string, port int, timeout time.Duration, ipVersion string) error {
|
||||
params := url.Values{}
|
||||
params.Set("host", host)
|
||||
params.Set("port", fmt.Sprintf("%d", port))
|
||||
if timeout != "" {
|
||||
params.Set("timeout", timeout)
|
||||
if timeout > 0 {
|
||||
params.Set("timeout", timeout.String())
|
||||
}
|
||||
if ipVersion != "" {
|
||||
params.Set("ip_version", ipVersion)
|
||||
}
|
||||
|
||||
path := fmt.Sprintf("/debug/clients/%s/pingtcp?%s", url.PathEscape(accountID), params.Encode())
|
||||
@@ -244,11 +250,17 @@ func (c *Client) PingTCP(ctx context.Context, accountID, host string, port int,
|
||||
|
||||
func (c *Client) printPingResult(data map[string]any) {
|
||||
success, _ := data["success"].(bool)
|
||||
host := net.JoinHostPort(fmt.Sprint(data["host"]), fmt.Sprint(data["port"]))
|
||||
if success {
|
||||
_, _ = fmt.Fprintf(c.out, "Success: %v:%v\n", data["host"], data["port"])
|
||||
remote, _ := data["remote"].(string)
|
||||
if remote != "" && remote != host {
|
||||
_, _ = fmt.Fprintf(c.out, "Success: %s (via %s)\n", host, remote)
|
||||
} else {
|
||||
_, _ = fmt.Fprintf(c.out, "Success: %s\n", host)
|
||||
}
|
||||
_, _ = fmt.Fprintf(c.out, "Latency: %v\n", data["latency"])
|
||||
} else {
|
||||
_, _ = fmt.Fprintf(c.out, "Failed: %v:%v\n", data["host"], data["port"])
|
||||
_, _ = fmt.Fprintf(c.out, "Failed: %s\n", host)
|
||||
c.printError(data)
|
||||
}
|
||||
}
|
||||
@@ -310,6 +322,76 @@ func (c *Client) printError(data map[string]any) {
|
||||
}
|
||||
}
|
||||
|
||||
// CaptureOptions configures a capture request.
|
||||
type CaptureOptions struct {
|
||||
AccountID string
|
||||
Duration string
|
||||
FilterExpr string
|
||||
Text bool
|
||||
Verbose bool
|
||||
ASCII bool
|
||||
Output io.Writer
|
||||
}
|
||||
|
||||
// Capture streams a packet capture from the debug endpoint. The response body
|
||||
// (pcap or text) is written directly to opts.Output until the server closes the
|
||||
// connection or the context is cancelled.
|
||||
func (c *Client) Capture(ctx context.Context, opts CaptureOptions) error {
|
||||
if opts.AccountID == "" {
|
||||
return fmt.Errorf("account ID is required")
|
||||
}
|
||||
if opts.Output == nil {
|
||||
return fmt.Errorf("output writer is required")
|
||||
}
|
||||
|
||||
params := url.Values{}
|
||||
if opts.Duration != "" {
|
||||
params.Set("duration", opts.Duration)
|
||||
}
|
||||
if opts.FilterExpr != "" {
|
||||
params.Set("filter", opts.FilterExpr)
|
||||
}
|
||||
if opts.Text {
|
||||
params.Set("format", "text")
|
||||
}
|
||||
if opts.Verbose {
|
||||
params.Set("verbose", "true")
|
||||
}
|
||||
if opts.ASCII {
|
||||
params.Set("ascii", "true")
|
||||
}
|
||||
|
||||
path := fmt.Sprintf("/debug/clients/%s/capture", url.PathEscape(opts.AccountID))
|
||||
if len(params) > 0 {
|
||||
path += "?" + params.Encode()
|
||||
}
|
||||
|
||||
fullURL := c.baseURL + path
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, fullURL, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create request: %w", err)
|
||||
}
|
||||
|
||||
// Use a separate client without timeout since captures stream for their full duration.
|
||||
httpClient := &http.Client{}
|
||||
resp, err := httpClient.Do(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("request failed: %w", err)
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
if resp.StatusCode >= 400 {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
return fmt.Errorf("server error (%d): %s", resp.StatusCode, strings.TrimSpace(string(body)))
|
||||
}
|
||||
|
||||
_, err = io.Copy(opts.Output, resp.Body)
|
||||
if err != nil && ctx.Err() != nil {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *Client) fetchAndPrint(ctx context.Context, path string, printer func(map[string]any)) error {
|
||||
data, raw, err := c.fetch(ctx, path)
|
||||
if err != nil {
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"fmt"
|
||||
"html/template"
|
||||
"maps"
|
||||
"net"
|
||||
"net/http"
|
||||
"slices"
|
||||
"strconv"
|
||||
@@ -174,6 +175,8 @@ func (h *Handler) handleClientRoutes(w http.ResponseWriter, r *http.Request, pat
|
||||
h.handleClientStart(w, r, accountID)
|
||||
case "stop":
|
||||
h.handleClientStop(w, r, accountID)
|
||||
case "capture":
|
||||
h.handleCapture(w, r, accountID)
|
||||
default:
|
||||
return false
|
||||
}
|
||||
@@ -525,13 +528,18 @@ func (h *Handler) handlePingTCP(w http.ResponseWriter, r *http.Request, accountI
|
||||
}
|
||||
}
|
||||
|
||||
network := "tcp"
|
||||
if v := r.URL.Query().Get("ip_version"); v == "4" || v == "6" {
|
||||
network += v
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(r.Context(), timeout)
|
||||
defer cancel()
|
||||
|
||||
address := fmt.Sprintf("%s:%d", host, port)
|
||||
address := net.JoinHostPort(host, strconv.Itoa(port))
|
||||
start := time.Now()
|
||||
|
||||
conn, err := client.Dial(ctx, "tcp", address)
|
||||
conn, err := client.Dial(ctx, network, address)
|
||||
if err != nil {
|
||||
h.writeJSON(w, map[string]interface{}{
|
||||
"success": false,
|
||||
@@ -541,18 +549,22 @@ func (h *Handler) handlePingTCP(w http.ResponseWriter, r *http.Request, accountI
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
remote := conn.RemoteAddr().String()
|
||||
if err := conn.Close(); err != nil {
|
||||
h.logger.Debugf("close tcp ping connection: %v", err)
|
||||
}
|
||||
|
||||
latency := time.Since(start)
|
||||
h.writeJSON(w, map[string]interface{}{
|
||||
resp := map[string]interface{}{
|
||||
"success": true,
|
||||
"host": host,
|
||||
"port": port,
|
||||
"remote": remote,
|
||||
"latency_ms": latency.Milliseconds(),
|
||||
"latency": formatDuration(latency),
|
||||
})
|
||||
}
|
||||
h.writeJSON(w, resp)
|
||||
}
|
||||
|
||||
func (h *Handler) handleLogLevel(w http.ResponseWriter, r *http.Request, accountID types.AccountID) {
|
||||
@@ -632,6 +644,81 @@ func (h *Handler) handleClientStop(w http.ResponseWriter, r *http.Request, accou
|
||||
})
|
||||
}
|
||||
|
||||
const maxCaptureDuration = 30 * time.Minute
|
||||
|
||||
// handleCapture streams a pcap or text packet capture for the given client.
|
||||
//
|
||||
// Query params:
|
||||
//
|
||||
// duration: capture duration (0 or absent = max, capped at 30m)
|
||||
// format: "text" for human-readable output (default: pcap)
|
||||
// filter: BPF-like filter expression (e.g. "host 10.0.0.1 and tcp port 443")
|
||||
func (h *Handler) handleCapture(w http.ResponseWriter, r *http.Request, accountID types.AccountID) {
|
||||
client, ok := h.provider.GetClient(accountID)
|
||||
if !ok {
|
||||
http.Error(w, "client not found", http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
duration := maxCaptureDuration
|
||||
if durationStr := r.URL.Query().Get("duration"); durationStr != "" {
|
||||
d, err := time.ParseDuration(durationStr)
|
||||
if err != nil {
|
||||
http.Error(w, "invalid duration: "+err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if d < 0 {
|
||||
http.Error(w, "duration must not be negative", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if d > 0 {
|
||||
duration = min(d, maxCaptureDuration)
|
||||
}
|
||||
}
|
||||
|
||||
filter := r.URL.Query().Get("filter")
|
||||
wantText := r.URL.Query().Get("format") == "text"
|
||||
verbose := r.URL.Query().Get("verbose") == "true"
|
||||
ascii := r.URL.Query().Get("ascii") == "true"
|
||||
|
||||
opts := nbembed.CaptureOptions{Filter: filter, Verbose: verbose, ASCII: ascii}
|
||||
if wantText {
|
||||
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
||||
opts.TextOutput = w
|
||||
} else {
|
||||
w.Header().Set("Content-Type", "application/vnd.tcpdump.pcap")
|
||||
w.Header().Set("Content-Disposition",
|
||||
fmt.Sprintf("attachment; filename=capture-%s.pcap", accountID))
|
||||
opts.Output = w
|
||||
}
|
||||
|
||||
cs, err := client.StartCapture(opts)
|
||||
if err != nil {
|
||||
http.Error(w, "start capture: "+err.Error(), http.StatusServiceUnavailable)
|
||||
return
|
||||
}
|
||||
defer cs.Stop()
|
||||
|
||||
// Flush headers after setup succeeds so errors above can still set status codes.
|
||||
if f, ok := w.(http.Flusher); ok {
|
||||
f.Flush()
|
||||
}
|
||||
|
||||
timer := time.NewTimer(duration)
|
||||
defer timer.Stop()
|
||||
|
||||
select {
|
||||
case <-r.Context().Done():
|
||||
case <-timer.C:
|
||||
}
|
||||
|
||||
cs.Stop()
|
||||
|
||||
stats := cs.Stats()
|
||||
h.logger.Infof("capture for %s finished: %d packets, %d bytes, %d dropped",
|
||||
accountID, stats.Packets, stats.Bytes, stats.Dropped)
|
||||
}
|
||||
|
||||
func (h *Handler) handleHealth(w http.ResponseWriter, r *http.Request, wantJSON bool) {
|
||||
if !wantJSON {
|
||||
http.Redirect(w, r, "/debug", http.StatusSeeOther)
|
||||
|
||||
Reference in New Issue
Block a user