mirror of
https://github.com/bolkedebruin/rdpgw.git
synced 2026-05-13 03:40:01 +00:00
Harden three small panic / race paths (#186)
Three independent corrections to functions that crash the worker on malformed or concurrent input: * protocol/common.readHeader: validate the on-wire size field before slicing data[8:size]. Introduce a named headerLen constant and reject declared sizes outside [headerLen, headerLen+maxFragmentSize], so a size of 0..7 (which previously panicked with slice bounds out of range) and an oversized size both surface as an error instead. * protocol/track: serialize access to the global Connections map with a sync.RWMutex. Concurrent RegisterTunnel/RemoveTunnel calls would otherwise be caught by the runtime as a fatal `concurrent map writes`. Also correct the inverted condition in Disconnect (the previous code dereferenced a nil Monitor when the id was missing and returned "does not exist" when the id was present). * web/ntlm.getAuthPayload: switch from authorisationEncoded[0:5] / [0:10] to strings.HasPrefix so an Authorization header shorter than the prefixes returns an error instead of a slice-bounds panic. Adds: - TestReadHeaderRejectsUndersizedSize (sizes 0/1/2/7). - TestTunnelTrackerConcurrent (200 goroutine pairs). - TestDisconnectKnownConnection / TestDisconnectMissingConnectionDoesNotPanic. - TestGetAuthPayloadShortHeader (missing/empty/3/4/5/9 character values).
This commit is contained in:
@@ -15,6 +15,7 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
headerLen = 8
|
||||
maxFragmentSize = 65536
|
||||
)
|
||||
|
||||
@@ -92,20 +93,26 @@ func createPacket(pktType uint16, data []byte) (packet []byte) {
|
||||
return buf.Bytes()
|
||||
}
|
||||
|
||||
// readHeader parses a packet and verifies its reported size
|
||||
// readHeader parses a packet and verifies its reported size. The wire
|
||||
// layout is [type:2][reserved:2][size:4][body...] where size counts the
|
||||
// 8-byte header plus the body. handleMsgFrame caps the reassembled body
|
||||
// at maxFragmentSize, so any size outside [headerLen, headerLen+maxFragmentSize]
|
||||
// is invalid and rejected before slicing.
|
||||
func readHeader(data []byte) (packetType uint16, size uint32, packet []byte, err error) {
|
||||
// header needs to be 8 min
|
||||
if len(data) < 8 {
|
||||
if len(data) < headerLen {
|
||||
return 0, 0, nil, errors.New("header too short, fragment likely")
|
||||
}
|
||||
r := bytes.NewReader(data)
|
||||
binary.Read(r, binary.LittleEndian, &packetType)
|
||||
r.Seek(4, io.SeekStart)
|
||||
binary.Read(r, binary.LittleEndian, &size)
|
||||
if len(data) < int(size) {
|
||||
return packetType, size, data[8:], errors.New("data incomplete, fragment received")
|
||||
if size < headerLen || size-headerLen > maxFragmentSize {
|
||||
return packetType, size, nil, fmt.Errorf("invalid declared size %d", size)
|
||||
}
|
||||
return packetType, size, data[8:size], nil
|
||||
if len(data) < int(size) {
|
||||
return packetType, size, data[headerLen:], errors.New("data incomplete, fragment received")
|
||||
}
|
||||
return packetType, size, data[headerLen:size], nil
|
||||
}
|
||||
|
||||
// forwards data from a Connection to Transport and wraps it in the rdpgw protocol
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package protocol
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"sync"
|
||||
@@ -263,6 +264,30 @@ func TestFragmentTooLarge(t *testing.T) {
|
||||
assert.Contains(t, "fragment exceeded max fragment size", messages[0].err.Error())
|
||||
}
|
||||
|
||||
// TestReadHeaderRejectsUndersizedSize checks that readHeader rejects a
|
||||
// declared size smaller than the 8-byte header itself. The reported size is
|
||||
// taken from the wire, so an attacker-supplied frame with size in [0,7] must
|
||||
// surface as an error rather than slicing data[8:size] with high < low.
|
||||
func TestReadHeaderRejectsUndersizedSize(t *testing.T) {
|
||||
for _, size := range []uint32{0, 1, 2, 7} {
|
||||
t.Run(fmt.Sprintf("size_%d", size), func(t *testing.T) {
|
||||
buf := make([]byte, 8)
|
||||
binary.LittleEndian.PutUint16(buf[0:2], 1)
|
||||
binary.LittleEndian.PutUint32(buf[4:8], size)
|
||||
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
t.Errorf("readHeader panicked on size=%d: %v", size, r)
|
||||
}
|
||||
}()
|
||||
_, _, _, err := readHeader(buf)
|
||||
if err == nil {
|
||||
t.Errorf("readHeader returned no error for invalid size=%d", size)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestFragmentWithMultiMessage the first message is fragmented,
|
||||
// while the second message is found whole in the final packet
|
||||
func TestFragmentWithMultiMessage(t *testing.T) {
|
||||
|
||||
@@ -1,8 +1,14 @@
|
||||
package protocol
|
||||
|
||||
import "fmt"
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
)
|
||||
|
||||
var Connections map[string]*Monitor
|
||||
var (
|
||||
Connections = map[string]*Monitor{}
|
||||
connectionsMu sync.RWMutex
|
||||
)
|
||||
|
||||
type Monitor struct {
|
||||
Processor *Processor
|
||||
@@ -14,9 +20,8 @@ const (
|
||||
)
|
||||
|
||||
func RegisterTunnel(t *Tunnel, p *Processor) {
|
||||
if Connections == nil {
|
||||
Connections = make(map[string]*Monitor)
|
||||
}
|
||||
connectionsMu.Lock()
|
||||
defer connectionsMu.Unlock()
|
||||
|
||||
Connections[t.Id] = &Monitor{
|
||||
Processor: p,
|
||||
@@ -25,40 +30,20 @@ func RegisterTunnel(t *Tunnel, p *Processor) {
|
||||
}
|
||||
|
||||
func RemoveTunnel(t *Tunnel) {
|
||||
connectionsMu.Lock()
|
||||
defer connectionsMu.Unlock()
|
||||
|
||||
delete(Connections, t.Id)
|
||||
}
|
||||
|
||||
func Disconnect(id string) error {
|
||||
if Connections == nil {
|
||||
connectionsMu.RLock()
|
||||
m, ok := Connections[id]
|
||||
connectionsMu.RUnlock()
|
||||
|
||||
if !ok {
|
||||
return fmt.Errorf("%s connection does not exist", id)
|
||||
}
|
||||
|
||||
if m, ok := Connections[id]; !ok {
|
||||
m.Processor.ctl <- ctlDisconnect
|
||||
return nil
|
||||
}
|
||||
|
||||
return fmt.Errorf("%s connection does not exist", id)
|
||||
m.Processor.ctl <- ctlDisconnect
|
||||
return nil
|
||||
}
|
||||
|
||||
// CalculateSpeedPerSecond calculate moving average.
|
||||
/*
|
||||
func CalculateSpeedPerSecond(connId string) (in int, out int) {
|
||||
now := time.Now().UnixMilli()
|
||||
|
||||
c := Connections[connId]
|
||||
total := int64(0)
|
||||
for _, v := range c.Tunnel.BytesReceived {
|
||||
total += v
|
||||
}
|
||||
in = int(total / (now - c.TimeStamp) * 1000)
|
||||
|
||||
total = int64(0)
|
||||
for _, v := range c.BytesSent {
|
||||
total += v
|
||||
}
|
||||
out = int(total / (now - c.TimeStamp))
|
||||
|
||||
return in, out
|
||||
}
|
||||
*/
|
||||
|
||||
87
cmd/rdpgw/protocol/track_test.go
Normal file
87
cmd/rdpgw/protocol/track_test.go
Normal file
@@ -0,0 +1,87 @@
|
||||
package protocol
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// resetConnections clears the global registry under its lock. Tests must
|
||||
// not reassign the package-level Connections variable, since the lock
|
||||
// protects the map's contents but not the variable's storage; concurrent
|
||||
// goroutines from other tests can otherwise read a stale reference.
|
||||
func resetConnections() {
|
||||
connectionsMu.Lock()
|
||||
clear(Connections)
|
||||
connectionsMu.Unlock()
|
||||
}
|
||||
|
||||
// TestTunnelTrackerConcurrent hammers the global tunnel registry from many
|
||||
// goroutines. The package's RegisterTunnel/RemoveTunnel write to a shared
|
||||
// map, so the registry must serialize access. A concurrent write would be
|
||||
// caught by the Go runtime as a fatal `concurrent map writes` (or by the
|
||||
// race detector under -race).
|
||||
func TestTunnelTrackerConcurrent(t *testing.T) {
|
||||
resetConnections()
|
||||
t.Cleanup(resetConnections)
|
||||
|
||||
const goroutines = 200
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(goroutines * 2)
|
||||
for i := 0; i < goroutines; i++ {
|
||||
i := i
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
id := fmt.Sprintf("tun-%d", i)
|
||||
RegisterTunnel(&Tunnel{Id: id}, &Processor{ctl: make(chan int)})
|
||||
}()
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
id := fmt.Sprintf("tun-%d", i)
|
||||
RemoveTunnel(&Tunnel{Id: id})
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
// TestDisconnectKnownConnection verifies that Disconnect signals the
|
||||
// processor for a connection that exists in the registry.
|
||||
func TestDisconnectKnownConnection(t *testing.T) {
|
||||
resetConnections()
|
||||
t.Cleanup(resetConnections)
|
||||
|
||||
p := &Processor{ctl: make(chan int, 1)}
|
||||
connectionsMu.Lock()
|
||||
Connections["known"] = &Monitor{Processor: p}
|
||||
connectionsMu.Unlock()
|
||||
|
||||
if err := Disconnect("known"); err != nil {
|
||||
t.Fatalf("Disconnect on known id returned err: %v", err)
|
||||
}
|
||||
select {
|
||||
case v := <-p.ctl:
|
||||
if v != ctlDisconnect {
|
||||
t.Errorf("ctl received %d, want ctlDisconnect=%d", v, ctlDisconnect)
|
||||
}
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
t.Error("Disconnect did not signal ctlDisconnect on the processor channel")
|
||||
}
|
||||
}
|
||||
|
||||
// TestDisconnectMissingConnectionDoesNotPanic verifies that Disconnect on
|
||||
// an id that is not in the registry returns an error rather than
|
||||
// dereferencing a nil Monitor.
|
||||
func TestDisconnectMissingConnectionDoesNotPanic(t *testing.T) {
|
||||
resetConnections()
|
||||
t.Cleanup(resetConnections)
|
||||
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
t.Errorf("Disconnect panicked on missing id: %v", r)
|
||||
}
|
||||
}()
|
||||
if err := Disconnect("nonexistent"); err == nil {
|
||||
t.Error("Disconnect on missing id returned no error")
|
||||
}
|
||||
}
|
||||
@@ -2,15 +2,17 @@ package web
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/identity"
|
||||
"errors"
|
||||
"log"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/identity"
|
||||
"github.com/bolkedebruin/rdpgw/shared/auth"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/credentials/insecure"
|
||||
"log"
|
||||
"net"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
type ntlmAuthMode uint32
|
||||
@@ -47,13 +49,15 @@ func (h *NTLMAuthHandler) NTLMAuth(next http.HandlerFunc) http.HandlerFunc {
|
||||
}
|
||||
}
|
||||
|
||||
func (h *NTLMAuthHandler) getAuthPayload (r *http.Request) (payload string, authMode ntlmAuthMode, err error) {
|
||||
func (h *NTLMAuthHandler) getAuthPayload(r *http.Request) (payload string, authMode ntlmAuthMode, err error) {
|
||||
authorisationEncoded := r.Header.Get("Authorization")
|
||||
if authorisationEncoded[0:5] == "NTLM " {
|
||||
return authorisationEncoded[5:], authNTLM, nil
|
||||
const ntlmPrefix = "NTLM "
|
||||
const negotiatePrefix = "Negotiate "
|
||||
if strings.HasPrefix(authorisationEncoded, ntlmPrefix) {
|
||||
return authorisationEncoded[len(ntlmPrefix):], authNTLM, nil
|
||||
}
|
||||
if authorisationEncoded[0:10] == "Negotiate " {
|
||||
return authorisationEncoded[10:], authNegotiate, nil
|
||||
if strings.HasPrefix(authorisationEncoded, negotiatePrefix) {
|
||||
return authorisationEncoded[len(negotiatePrefix):], authNegotiate, nil
|
||||
}
|
||||
return "", authNone, errors.New("Invalid NTLM Authorisation header")
|
||||
}
|
||||
|
||||
45
cmd/rdpgw/web/ntlm_test.go
Normal file
45
cmd/rdpgw/web/ntlm_test.go
Normal file
@@ -0,0 +1,45 @@
|
||||
package web
|
||||
|
||||
import (
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestGetAuthPayloadShortHeader asserts that getAuthPayload tolerates
|
||||
// Authorization header values shorter than the prefixes it tests for. The
|
||||
// header value is attacker-controlled, so a sub-5-byte value must surface
|
||||
// as an error rather than a slice-bounds panic that crashes the worker.
|
||||
func TestGetAuthPayloadShortHeader(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
header string
|
||||
set bool
|
||||
}{
|
||||
{"missing", "", false},
|
||||
{"empty", "", true},
|
||||
{"three chars", "abc", true},
|
||||
{"four chars", "NTLM", true},
|
||||
{"five chars not NTLM prefix", "abcde", true},
|
||||
{"nine chars Negotiate without trailing space", "Negotiate", true},
|
||||
}
|
||||
|
||||
h := &NTLMAuthHandler{}
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
if tc.set {
|
||||
req.Header.Set("Authorization", tc.header)
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
t.Errorf("getAuthPayload panicked on Authorization=%q: %v", tc.header, r)
|
||||
}
|
||||
}()
|
||||
_, _, err := h.getAuthPayload(req)
|
||||
if err == nil {
|
||||
t.Errorf("expected error for Authorization=%q, got nil", tc.header)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user