Harden three small panic / race paths (#186)

Three independent corrections to functions that crash the worker on
malformed or concurrent input:

* protocol/common.readHeader: validate the on-wire size field before
  slicing data[8:size]. Introduce a named headerLen constant and reject
  declared sizes outside [headerLen, headerLen+maxFragmentSize], so a
  size of 0..7 (which previously panicked with slice bounds out of
  range) and an oversized size both surface as an error instead.

* protocol/track: serialize access to the global Connections map with
  a sync.RWMutex. Concurrent RegisterTunnel/RemoveTunnel calls would
  otherwise be caught by the runtime as a fatal `concurrent map
  writes`. Also correct the inverted condition in Disconnect (the
  previous code dereferenced a nil Monitor when the id was missing and
  returned "does not exist" when the id was present).

* web/ntlm.getAuthPayload: switch from authorisationEncoded[0:5] /
  [0:10] to strings.HasPrefix so an Authorization header shorter than
  the prefixes returns an error instead of a slice-bounds panic.

Adds:
- TestReadHeaderRejectsUndersizedSize (sizes 0/1/2/7).
- TestTunnelTrackerConcurrent (200 goroutine pairs).
- TestDisconnectKnownConnection / TestDisconnectMissingConnectionDoesNotPanic.
- TestGetAuthPayloadShortHeader (missing/empty/3/4/5/9 character values).
This commit is contained in:
bolkedebruin
2026-04-30 14:38:36 +02:00
committed by GitHub
parent 628046b34c
commit 8fc5677dfc
6 changed files with 205 additions and 52 deletions

View File

@@ -15,6 +15,7 @@ import (
)
const (
headerLen = 8
maxFragmentSize = 65536
)
@@ -92,20 +93,26 @@ func createPacket(pktType uint16, data []byte) (packet []byte) {
return buf.Bytes()
}
// readHeader parses a packet and verifies its reported size
// readHeader parses a packet and verifies its reported size. The wire
// layout is [type:2][reserved:2][size:4][body...] where size counts the
// 8-byte header plus the body. handleMsgFrame caps the reassembled body
// at maxFragmentSize, so any size outside [headerLen, headerLen+maxFragmentSize]
// is invalid and rejected before slicing.
func readHeader(data []byte) (packetType uint16, size uint32, packet []byte, err error) {
// header needs to be 8 min
if len(data) < 8 {
if len(data) < headerLen {
return 0, 0, nil, errors.New("header too short, fragment likely")
}
r := bytes.NewReader(data)
binary.Read(r, binary.LittleEndian, &packetType)
r.Seek(4, io.SeekStart)
binary.Read(r, binary.LittleEndian, &size)
if len(data) < int(size) {
return packetType, size, data[8:], errors.New("data incomplete, fragment received")
if size < headerLen || size-headerLen > maxFragmentSize {
return packetType, size, nil, fmt.Errorf("invalid declared size %d", size)
}
return packetType, size, data[8:size], nil
if len(data) < int(size) {
return packetType, size, data[headerLen:], errors.New("data incomplete, fragment received")
}
return packetType, size, data[headerLen:size], nil
}
// forwards data from a Connection to Transport and wraps it in the rdpgw protocol

View File

@@ -1,6 +1,7 @@
package protocol
import (
"encoding/binary"
"fmt"
"math/rand"
"sync"
@@ -263,6 +264,30 @@ func TestFragmentTooLarge(t *testing.T) {
assert.Contains(t, "fragment exceeded max fragment size", messages[0].err.Error())
}
// TestReadHeaderRejectsUndersizedSize checks that readHeader rejects a
// declared size smaller than the 8-byte header itself. The reported size is
// taken from the wire, so an attacker-supplied frame with size in [0,7] must
// surface as an error rather than slicing data[8:size] with high < low.
func TestReadHeaderRejectsUndersizedSize(t *testing.T) {
for _, size := range []uint32{0, 1, 2, 7} {
t.Run(fmt.Sprintf("size_%d", size), func(t *testing.T) {
buf := make([]byte, 8)
binary.LittleEndian.PutUint16(buf[0:2], 1)
binary.LittleEndian.PutUint32(buf[4:8], size)
defer func() {
if r := recover(); r != nil {
t.Errorf("readHeader panicked on size=%d: %v", size, r)
}
}()
_, _, _, err := readHeader(buf)
if err == nil {
t.Errorf("readHeader returned no error for invalid size=%d", size)
}
})
}
}
// TestFragmentWithMultiMessage the first message is fragmented,
// while the second message is found whole in the final packet
func TestFragmentWithMultiMessage(t *testing.T) {

View File

@@ -1,8 +1,14 @@
package protocol
import "fmt"
import (
"fmt"
"sync"
)
var Connections map[string]*Monitor
var (
Connections = map[string]*Monitor{}
connectionsMu sync.RWMutex
)
type Monitor struct {
Processor *Processor
@@ -14,9 +20,8 @@ const (
)
func RegisterTunnel(t *Tunnel, p *Processor) {
if Connections == nil {
Connections = make(map[string]*Monitor)
}
connectionsMu.Lock()
defer connectionsMu.Unlock()
Connections[t.Id] = &Monitor{
Processor: p,
@@ -25,40 +30,20 @@ func RegisterTunnel(t *Tunnel, p *Processor) {
}
func RemoveTunnel(t *Tunnel) {
connectionsMu.Lock()
defer connectionsMu.Unlock()
delete(Connections, t.Id)
}
func Disconnect(id string) error {
if Connections == nil {
connectionsMu.RLock()
m, ok := Connections[id]
connectionsMu.RUnlock()
if !ok {
return fmt.Errorf("%s connection does not exist", id)
}
if m, ok := Connections[id]; !ok {
m.Processor.ctl <- ctlDisconnect
return nil
}
return fmt.Errorf("%s connection does not exist", id)
m.Processor.ctl <- ctlDisconnect
return nil
}
// CalculateSpeedPerSecond calculate moving average.
/*
func CalculateSpeedPerSecond(connId string) (in int, out int) {
now := time.Now().UnixMilli()
c := Connections[connId]
total := int64(0)
for _, v := range c.Tunnel.BytesReceived {
total += v
}
in = int(total / (now - c.TimeStamp) * 1000)
total = int64(0)
for _, v := range c.BytesSent {
total += v
}
out = int(total / (now - c.TimeStamp))
return in, out
}
*/

View File

@@ -0,0 +1,87 @@
package protocol
import (
"fmt"
"sync"
"testing"
"time"
)
// resetConnections clears the global registry under its lock. Tests must
// not reassign the package-level Connections variable, since the lock
// protects the map's contents but not the variable's storage; concurrent
// goroutines from other tests can otherwise read a stale reference.
func resetConnections() {
connectionsMu.Lock()
clear(Connections)
connectionsMu.Unlock()
}
// TestTunnelTrackerConcurrent hammers the global tunnel registry from many
// goroutines. The package's RegisterTunnel/RemoveTunnel write to a shared
// map, so the registry must serialize access. A concurrent write would be
// caught by the Go runtime as a fatal `concurrent map writes` (or by the
// race detector under -race).
func TestTunnelTrackerConcurrent(t *testing.T) {
resetConnections()
t.Cleanup(resetConnections)
const goroutines = 200
var wg sync.WaitGroup
wg.Add(goroutines * 2)
for i := 0; i < goroutines; i++ {
i := i
go func() {
defer wg.Done()
id := fmt.Sprintf("tun-%d", i)
RegisterTunnel(&Tunnel{Id: id}, &Processor{ctl: make(chan int)})
}()
go func() {
defer wg.Done()
id := fmt.Sprintf("tun-%d", i)
RemoveTunnel(&Tunnel{Id: id})
}()
}
wg.Wait()
}
// TestDisconnectKnownConnection verifies that Disconnect signals the
// processor for a connection that exists in the registry.
func TestDisconnectKnownConnection(t *testing.T) {
resetConnections()
t.Cleanup(resetConnections)
p := &Processor{ctl: make(chan int, 1)}
connectionsMu.Lock()
Connections["known"] = &Monitor{Processor: p}
connectionsMu.Unlock()
if err := Disconnect("known"); err != nil {
t.Fatalf("Disconnect on known id returned err: %v", err)
}
select {
case v := <-p.ctl:
if v != ctlDisconnect {
t.Errorf("ctl received %d, want ctlDisconnect=%d", v, ctlDisconnect)
}
case <-time.After(100 * time.Millisecond):
t.Error("Disconnect did not signal ctlDisconnect on the processor channel")
}
}
// TestDisconnectMissingConnectionDoesNotPanic verifies that Disconnect on
// an id that is not in the registry returns an error rather than
// dereferencing a nil Monitor.
func TestDisconnectMissingConnectionDoesNotPanic(t *testing.T) {
resetConnections()
t.Cleanup(resetConnections)
defer func() {
if r := recover(); r != nil {
t.Errorf("Disconnect panicked on missing id: %v", r)
}
}()
if err := Disconnect("nonexistent"); err == nil {
t.Error("Disconnect on missing id returned no error")
}
}

View File

@@ -2,15 +2,17 @@ package web
import (
"context"
"errors"
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/identity"
"errors"
"log"
"net"
"net/http"
"strings"
"time"
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/identity"
"github.com/bolkedebruin/rdpgw/shared/auth"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
"log"
"net"
"net/http"
"time"
)
type ntlmAuthMode uint32
@@ -47,13 +49,15 @@ func (h *NTLMAuthHandler) NTLMAuth(next http.HandlerFunc) http.HandlerFunc {
}
}
func (h *NTLMAuthHandler) getAuthPayload (r *http.Request) (payload string, authMode ntlmAuthMode, err error) {
func (h *NTLMAuthHandler) getAuthPayload(r *http.Request) (payload string, authMode ntlmAuthMode, err error) {
authorisationEncoded := r.Header.Get("Authorization")
if authorisationEncoded[0:5] == "NTLM " {
return authorisationEncoded[5:], authNTLM, nil
const ntlmPrefix = "NTLM "
const negotiatePrefix = "Negotiate "
if strings.HasPrefix(authorisationEncoded, ntlmPrefix) {
return authorisationEncoded[len(ntlmPrefix):], authNTLM, nil
}
if authorisationEncoded[0:10] == "Negotiate " {
return authorisationEncoded[10:], authNegotiate, nil
if strings.HasPrefix(authorisationEncoded, negotiatePrefix) {
return authorisationEncoded[len(negotiatePrefix):], authNegotiate, nil
}
return "", authNone, errors.New("Invalid NTLM Authorisation header")
}

View File

@@ -0,0 +1,45 @@
package web
import (
"net/http/httptest"
"testing"
)
// TestGetAuthPayloadShortHeader asserts that getAuthPayload tolerates
// Authorization header values shorter than the prefixes it tests for. The
// header value is attacker-controlled, so a sub-5-byte value must surface
// as an error rather than a slice-bounds panic that crashes the worker.
func TestGetAuthPayloadShortHeader(t *testing.T) {
cases := []struct {
name string
header string
set bool
}{
{"missing", "", false},
{"empty", "", true},
{"three chars", "abc", true},
{"four chars", "NTLM", true},
{"five chars not NTLM prefix", "abcde", true},
{"nine chars Negotiate without trailing space", "Negotiate", true},
}
h := &NTLMAuthHandler{}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
req := httptest.NewRequest("GET", "/", nil)
if tc.set {
req.Header.Set("Authorization", tc.header)
}
defer func() {
if r := recover(); r != nil {
t.Errorf("getAuthPayload panicked on Authorization=%q: %v", tc.header, r)
}
}()
_, _, err := h.getAuthPayload(req)
if err == nil {
t.Errorf("expected error for Authorization=%q, got nil", tc.header)
}
})
}
}