From 8fc5677dfcfcc68a3aa23ef449181d9d47d9856a Mon Sep 17 00:00:00 2001 From: bolkedebruin Date: Thu, 30 Apr 2026 14:38:36 +0200 Subject: [PATCH] Harden three small panic / race paths (#186) Three independent corrections to functions that crash the worker on malformed or concurrent input: * protocol/common.readHeader: validate the on-wire size field before slicing data[8:size]. Introduce a named headerLen constant and reject declared sizes outside [headerLen, headerLen+maxFragmentSize], so a size of 0..7 (which previously panicked with slice bounds out of range) and an oversized size both surface as an error instead. * protocol/track: serialize access to the global Connections map with a sync.RWMutex. Concurrent RegisterTunnel/RemoveTunnel calls would otherwise be caught by the runtime as a fatal `concurrent map writes`. Also correct the inverted condition in Disconnect (the previous code dereferenced a nil Monitor when the id was missing and returned "does not exist" when the id was present). * web/ntlm.getAuthPayload: switch from authorisationEncoded[0:5] / [0:10] to strings.HasPrefix so an Authorization header shorter than the prefixes returns an error instead of a slice-bounds panic. Adds: - TestReadHeaderRejectsUndersizedSize (sizes 0/1/2/7). - TestTunnelTrackerConcurrent (200 goroutine pairs). - TestDisconnectKnownConnection / TestDisconnectMissingConnectionDoesNotPanic. - TestGetAuthPayloadShortHeader (missing/empty/3/4/5/9 character values). --- cmd/rdpgw/protocol/common.go | 19 ++++--- cmd/rdpgw/protocol/common_test.go | 25 +++++++++ cmd/rdpgw/protocol/track.go | 55 +++++++------------ cmd/rdpgw/protocol/track_test.go | 87 +++++++++++++++++++++++++++++++ cmd/rdpgw/web/ntlm.go | 26 +++++---- cmd/rdpgw/web/ntlm_test.go | 45 ++++++++++++++++ 6 files changed, 205 insertions(+), 52 deletions(-) create mode 100644 cmd/rdpgw/protocol/track_test.go create mode 100644 cmd/rdpgw/web/ntlm_test.go diff --git a/cmd/rdpgw/protocol/common.go b/cmd/rdpgw/protocol/common.go index 4509f5c..d080baa 100644 --- a/cmd/rdpgw/protocol/common.go +++ b/cmd/rdpgw/protocol/common.go @@ -15,6 +15,7 @@ import ( ) const ( + headerLen = 8 maxFragmentSize = 65536 ) @@ -92,20 +93,26 @@ func createPacket(pktType uint16, data []byte) (packet []byte) { return buf.Bytes() } -// readHeader parses a packet and verifies its reported size +// readHeader parses a packet and verifies its reported size. The wire +// layout is [type:2][reserved:2][size:4][body...] where size counts the +// 8-byte header plus the body. handleMsgFrame caps the reassembled body +// at maxFragmentSize, so any size outside [headerLen, headerLen+maxFragmentSize] +// is invalid and rejected before slicing. func readHeader(data []byte) (packetType uint16, size uint32, packet []byte, err error) { - // header needs to be 8 min - if len(data) < 8 { + if len(data) < headerLen { return 0, 0, nil, errors.New("header too short, fragment likely") } r := bytes.NewReader(data) binary.Read(r, binary.LittleEndian, &packetType) r.Seek(4, io.SeekStart) binary.Read(r, binary.LittleEndian, &size) - if len(data) < int(size) { - return packetType, size, data[8:], errors.New("data incomplete, fragment received") + if size < headerLen || size-headerLen > maxFragmentSize { + return packetType, size, nil, fmt.Errorf("invalid declared size %d", size) } - return packetType, size, data[8:size], nil + if len(data) < int(size) { + return packetType, size, data[headerLen:], errors.New("data incomplete, fragment received") + } + return packetType, size, data[headerLen:size], nil } // forwards data from a Connection to Transport and wraps it in the rdpgw protocol diff --git a/cmd/rdpgw/protocol/common_test.go b/cmd/rdpgw/protocol/common_test.go index da97275..49e7e9d 100644 --- a/cmd/rdpgw/protocol/common_test.go +++ b/cmd/rdpgw/protocol/common_test.go @@ -1,6 +1,7 @@ package protocol import ( + "encoding/binary" "fmt" "math/rand" "sync" @@ -263,6 +264,30 @@ func TestFragmentTooLarge(t *testing.T) { assert.Contains(t, "fragment exceeded max fragment size", messages[0].err.Error()) } +// TestReadHeaderRejectsUndersizedSize checks that readHeader rejects a +// declared size smaller than the 8-byte header itself. The reported size is +// taken from the wire, so an attacker-supplied frame with size in [0,7] must +// surface as an error rather than slicing data[8:size] with high < low. +func TestReadHeaderRejectsUndersizedSize(t *testing.T) { + for _, size := range []uint32{0, 1, 2, 7} { + t.Run(fmt.Sprintf("size_%d", size), func(t *testing.T) { + buf := make([]byte, 8) + binary.LittleEndian.PutUint16(buf[0:2], 1) + binary.LittleEndian.PutUint32(buf[4:8], size) + + defer func() { + if r := recover(); r != nil { + t.Errorf("readHeader panicked on size=%d: %v", size, r) + } + }() + _, _, _, err := readHeader(buf) + if err == nil { + t.Errorf("readHeader returned no error for invalid size=%d", size) + } + }) + } +} + // TestFragmentWithMultiMessage the first message is fragmented, // while the second message is found whole in the final packet func TestFragmentWithMultiMessage(t *testing.T) { diff --git a/cmd/rdpgw/protocol/track.go b/cmd/rdpgw/protocol/track.go index 250a35b..0d38d43 100644 --- a/cmd/rdpgw/protocol/track.go +++ b/cmd/rdpgw/protocol/track.go @@ -1,8 +1,14 @@ package protocol -import "fmt" +import ( + "fmt" + "sync" +) -var Connections map[string]*Monitor +var ( + Connections = map[string]*Monitor{} + connectionsMu sync.RWMutex +) type Monitor struct { Processor *Processor @@ -14,9 +20,8 @@ const ( ) func RegisterTunnel(t *Tunnel, p *Processor) { - if Connections == nil { - Connections = make(map[string]*Monitor) - } + connectionsMu.Lock() + defer connectionsMu.Unlock() Connections[t.Id] = &Monitor{ Processor: p, @@ -25,40 +30,20 @@ func RegisterTunnel(t *Tunnel, p *Processor) { } func RemoveTunnel(t *Tunnel) { + connectionsMu.Lock() + defer connectionsMu.Unlock() + delete(Connections, t.Id) } func Disconnect(id string) error { - if Connections == nil { + connectionsMu.RLock() + m, ok := Connections[id] + connectionsMu.RUnlock() + + if !ok { return fmt.Errorf("%s connection does not exist", id) } - - if m, ok := Connections[id]; !ok { - m.Processor.ctl <- ctlDisconnect - return nil - } - - return fmt.Errorf("%s connection does not exist", id) + m.Processor.ctl <- ctlDisconnect + return nil } - -// CalculateSpeedPerSecond calculate moving average. -/* -func CalculateSpeedPerSecond(connId string) (in int, out int) { - now := time.Now().UnixMilli() - - c := Connections[connId] - total := int64(0) - for _, v := range c.Tunnel.BytesReceived { - total += v - } - in = int(total / (now - c.TimeStamp) * 1000) - - total = int64(0) - for _, v := range c.BytesSent { - total += v - } - out = int(total / (now - c.TimeStamp)) - - return in, out -} -*/ diff --git a/cmd/rdpgw/protocol/track_test.go b/cmd/rdpgw/protocol/track_test.go new file mode 100644 index 0000000..6ac1d47 --- /dev/null +++ b/cmd/rdpgw/protocol/track_test.go @@ -0,0 +1,87 @@ +package protocol + +import ( + "fmt" + "sync" + "testing" + "time" +) + +// resetConnections clears the global registry under its lock. Tests must +// not reassign the package-level Connections variable, since the lock +// protects the map's contents but not the variable's storage; concurrent +// goroutines from other tests can otherwise read a stale reference. +func resetConnections() { + connectionsMu.Lock() + clear(Connections) + connectionsMu.Unlock() +} + +// TestTunnelTrackerConcurrent hammers the global tunnel registry from many +// goroutines. The package's RegisterTunnel/RemoveTunnel write to a shared +// map, so the registry must serialize access. A concurrent write would be +// caught by the Go runtime as a fatal `concurrent map writes` (or by the +// race detector under -race). +func TestTunnelTrackerConcurrent(t *testing.T) { + resetConnections() + t.Cleanup(resetConnections) + + const goroutines = 200 + var wg sync.WaitGroup + wg.Add(goroutines * 2) + for i := 0; i < goroutines; i++ { + i := i + go func() { + defer wg.Done() + id := fmt.Sprintf("tun-%d", i) + RegisterTunnel(&Tunnel{Id: id}, &Processor{ctl: make(chan int)}) + }() + go func() { + defer wg.Done() + id := fmt.Sprintf("tun-%d", i) + RemoveTunnel(&Tunnel{Id: id}) + }() + } + wg.Wait() +} + +// TestDisconnectKnownConnection verifies that Disconnect signals the +// processor for a connection that exists in the registry. +func TestDisconnectKnownConnection(t *testing.T) { + resetConnections() + t.Cleanup(resetConnections) + + p := &Processor{ctl: make(chan int, 1)} + connectionsMu.Lock() + Connections["known"] = &Monitor{Processor: p} + connectionsMu.Unlock() + + if err := Disconnect("known"); err != nil { + t.Fatalf("Disconnect on known id returned err: %v", err) + } + select { + case v := <-p.ctl: + if v != ctlDisconnect { + t.Errorf("ctl received %d, want ctlDisconnect=%d", v, ctlDisconnect) + } + case <-time.After(100 * time.Millisecond): + t.Error("Disconnect did not signal ctlDisconnect on the processor channel") + } +} + +// TestDisconnectMissingConnectionDoesNotPanic verifies that Disconnect on +// an id that is not in the registry returns an error rather than +// dereferencing a nil Monitor. +func TestDisconnectMissingConnectionDoesNotPanic(t *testing.T) { + resetConnections() + t.Cleanup(resetConnections) + + defer func() { + if r := recover(); r != nil { + t.Errorf("Disconnect panicked on missing id: %v", r) + } + }() + if err := Disconnect("nonexistent"); err == nil { + t.Error("Disconnect on missing id returned no error") + } +} diff --git a/cmd/rdpgw/web/ntlm.go b/cmd/rdpgw/web/ntlm.go index b4fc89d..3b7b8ac 100644 --- a/cmd/rdpgw/web/ntlm.go +++ b/cmd/rdpgw/web/ntlm.go @@ -2,15 +2,17 @@ package web import ( "context" - "errors" - "github.com/bolkedebruin/rdpgw/cmd/rdpgw/identity" + "errors" + "log" + "net" + "net/http" + "strings" + "time" + + "github.com/bolkedebruin/rdpgw/cmd/rdpgw/identity" "github.com/bolkedebruin/rdpgw/shared/auth" "google.golang.org/grpc" "google.golang.org/grpc/credentials/insecure" - "log" - "net" - "net/http" - "time" ) type ntlmAuthMode uint32 @@ -47,13 +49,15 @@ func (h *NTLMAuthHandler) NTLMAuth(next http.HandlerFunc) http.HandlerFunc { } } -func (h *NTLMAuthHandler) getAuthPayload (r *http.Request) (payload string, authMode ntlmAuthMode, err error) { +func (h *NTLMAuthHandler) getAuthPayload(r *http.Request) (payload string, authMode ntlmAuthMode, err error) { authorisationEncoded := r.Header.Get("Authorization") - if authorisationEncoded[0:5] == "NTLM " { - return authorisationEncoded[5:], authNTLM, nil + const ntlmPrefix = "NTLM " + const negotiatePrefix = "Negotiate " + if strings.HasPrefix(authorisationEncoded, ntlmPrefix) { + return authorisationEncoded[len(ntlmPrefix):], authNTLM, nil } - if authorisationEncoded[0:10] == "Negotiate " { - return authorisationEncoded[10:], authNegotiate, nil + if strings.HasPrefix(authorisationEncoded, negotiatePrefix) { + return authorisationEncoded[len(negotiatePrefix):], authNegotiate, nil } return "", authNone, errors.New("Invalid NTLM Authorisation header") } diff --git a/cmd/rdpgw/web/ntlm_test.go b/cmd/rdpgw/web/ntlm_test.go new file mode 100644 index 0000000..ed1f4b6 --- /dev/null +++ b/cmd/rdpgw/web/ntlm_test.go @@ -0,0 +1,45 @@ +package web + +import ( + "net/http/httptest" + "testing" +) + +// TestGetAuthPayloadShortHeader asserts that getAuthPayload tolerates +// Authorization header values shorter than the prefixes it tests for. The +// header value is attacker-controlled, so a sub-5-byte value must surface +// as an error rather than a slice-bounds panic that crashes the worker. +func TestGetAuthPayloadShortHeader(t *testing.T) { + cases := []struct { + name string + header string + set bool + }{ + {"missing", "", false}, + {"empty", "", true}, + {"three chars", "abc", true}, + {"four chars", "NTLM", true}, + {"five chars not NTLM prefix", "abcde", true}, + {"nine chars Negotiate without trailing space", "Negotiate", true}, + } + + h := &NTLMAuthHandler{} + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + req := httptest.NewRequest("GET", "/", nil) + if tc.set { + req.Header.Set("Authorization", tc.header) + } + + defer func() { + if r := recover(); r != nil { + t.Errorf("getAuthPayload panicked on Authorization=%q: %v", tc.header, r) + } + }() + _, _, err := h.getAuthPayload(req) + if err == nil { + t.Errorf("expected error for Authorization=%q, got nil", tc.header) + } + }) + } +}