mirror of
https://github.com/bolkedebruin/rdpgw.git
synced 2026-05-17 13:50:00 +00:00
Harden three small panic / race paths (#186)
Three independent corrections to functions that crash the worker on malformed or concurrent input: * protocol/common.readHeader: validate the on-wire size field before slicing data[8:size]. Introduce a named headerLen constant and reject declared sizes outside [headerLen, headerLen+maxFragmentSize], so a size of 0..7 (which previously panicked with slice bounds out of range) and an oversized size both surface as an error instead. * protocol/track: serialize access to the global Connections map with a sync.RWMutex. Concurrent RegisterTunnel/RemoveTunnel calls would otherwise be caught by the runtime as a fatal `concurrent map writes`. Also correct the inverted condition in Disconnect (the previous code dereferenced a nil Monitor when the id was missing and returned "does not exist" when the id was present). * web/ntlm.getAuthPayload: switch from authorisationEncoded[0:5] / [0:10] to strings.HasPrefix so an Authorization header shorter than the prefixes returns an error instead of a slice-bounds panic. Adds: - TestReadHeaderRejectsUndersizedSize (sizes 0/1/2/7). - TestTunnelTrackerConcurrent (200 goroutine pairs). - TestDisconnectKnownConnection / TestDisconnectMissingConnectionDoesNotPanic. - TestGetAuthPayloadShortHeader (missing/empty/3/4/5/9 character values).
This commit is contained in:
@@ -15,6 +15,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
headerLen = 8
|
||||||
maxFragmentSize = 65536
|
maxFragmentSize = 65536
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -92,20 +93,26 @@ func createPacket(pktType uint16, data []byte) (packet []byte) {
|
|||||||
return buf.Bytes()
|
return buf.Bytes()
|
||||||
}
|
}
|
||||||
|
|
||||||
// readHeader parses a packet and verifies its reported size
|
// readHeader parses a packet and verifies its reported size. The wire
|
||||||
|
// layout is [type:2][reserved:2][size:4][body...] where size counts the
|
||||||
|
// 8-byte header plus the body. handleMsgFrame caps the reassembled body
|
||||||
|
// at maxFragmentSize, so any size outside [headerLen, headerLen+maxFragmentSize]
|
||||||
|
// is invalid and rejected before slicing.
|
||||||
func readHeader(data []byte) (packetType uint16, size uint32, packet []byte, err error) {
|
func readHeader(data []byte) (packetType uint16, size uint32, packet []byte, err error) {
|
||||||
// header needs to be 8 min
|
if len(data) < headerLen {
|
||||||
if len(data) < 8 {
|
|
||||||
return 0, 0, nil, errors.New("header too short, fragment likely")
|
return 0, 0, nil, errors.New("header too short, fragment likely")
|
||||||
}
|
}
|
||||||
r := bytes.NewReader(data)
|
r := bytes.NewReader(data)
|
||||||
binary.Read(r, binary.LittleEndian, &packetType)
|
binary.Read(r, binary.LittleEndian, &packetType)
|
||||||
r.Seek(4, io.SeekStart)
|
r.Seek(4, io.SeekStart)
|
||||||
binary.Read(r, binary.LittleEndian, &size)
|
binary.Read(r, binary.LittleEndian, &size)
|
||||||
if len(data) < int(size) {
|
if size < headerLen || size-headerLen > maxFragmentSize {
|
||||||
return packetType, size, data[8:], errors.New("data incomplete, fragment received")
|
return packetType, size, nil, fmt.Errorf("invalid declared size %d", size)
|
||||||
}
|
}
|
||||||
return packetType, size, data[8:size], nil
|
if len(data) < int(size) {
|
||||||
|
return packetType, size, data[headerLen:], errors.New("data incomplete, fragment received")
|
||||||
|
}
|
||||||
|
return packetType, size, data[headerLen:size], nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// forwards data from a Connection to Transport and wraps it in the rdpgw protocol
|
// forwards data from a Connection to Transport and wraps it in the rdpgw protocol
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package protocol
|
package protocol
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"encoding/binary"
|
||||||
"fmt"
|
"fmt"
|
||||||
"math/rand"
|
"math/rand"
|
||||||
"sync"
|
"sync"
|
||||||
@@ -263,6 +264,30 @@ func TestFragmentTooLarge(t *testing.T) {
|
|||||||
assert.Contains(t, "fragment exceeded max fragment size", messages[0].err.Error())
|
assert.Contains(t, "fragment exceeded max fragment size", messages[0].err.Error())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestReadHeaderRejectsUndersizedSize checks that readHeader rejects a
|
||||||
|
// declared size smaller than the 8-byte header itself. The reported size is
|
||||||
|
// taken from the wire, so an attacker-supplied frame with size in [0,7] must
|
||||||
|
// surface as an error rather than slicing data[8:size] with high < low.
|
||||||
|
func TestReadHeaderRejectsUndersizedSize(t *testing.T) {
|
||||||
|
for _, size := range []uint32{0, 1, 2, 7} {
|
||||||
|
t.Run(fmt.Sprintf("size_%d", size), func(t *testing.T) {
|
||||||
|
buf := make([]byte, 8)
|
||||||
|
binary.LittleEndian.PutUint16(buf[0:2], 1)
|
||||||
|
binary.LittleEndian.PutUint32(buf[4:8], size)
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
t.Errorf("readHeader panicked on size=%d: %v", size, r)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
_, _, _, err := readHeader(buf)
|
||||||
|
if err == nil {
|
||||||
|
t.Errorf("readHeader returned no error for invalid size=%d", size)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// TestFragmentWithMultiMessage the first message is fragmented,
|
// TestFragmentWithMultiMessage the first message is fragmented,
|
||||||
// while the second message is found whole in the final packet
|
// while the second message is found whole in the final packet
|
||||||
func TestFragmentWithMultiMessage(t *testing.T) {
|
func TestFragmentWithMultiMessage(t *testing.T) {
|
||||||
|
|||||||
@@ -1,8 +1,14 @@
|
|||||||
package protocol
|
package protocol
|
||||||
|
|
||||||
import "fmt"
|
import (
|
||||||
|
"fmt"
|
||||||
|
"sync"
|
||||||
|
)
|
||||||
|
|
||||||
var Connections map[string]*Monitor
|
var (
|
||||||
|
Connections = map[string]*Monitor{}
|
||||||
|
connectionsMu sync.RWMutex
|
||||||
|
)
|
||||||
|
|
||||||
type Monitor struct {
|
type Monitor struct {
|
||||||
Processor *Processor
|
Processor *Processor
|
||||||
@@ -14,9 +20,8 @@ const (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func RegisterTunnel(t *Tunnel, p *Processor) {
|
func RegisterTunnel(t *Tunnel, p *Processor) {
|
||||||
if Connections == nil {
|
connectionsMu.Lock()
|
||||||
Connections = make(map[string]*Monitor)
|
defer connectionsMu.Unlock()
|
||||||
}
|
|
||||||
|
|
||||||
Connections[t.Id] = &Monitor{
|
Connections[t.Id] = &Monitor{
|
||||||
Processor: p,
|
Processor: p,
|
||||||
@@ -25,40 +30,20 @@ func RegisterTunnel(t *Tunnel, p *Processor) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func RemoveTunnel(t *Tunnel) {
|
func RemoveTunnel(t *Tunnel) {
|
||||||
|
connectionsMu.Lock()
|
||||||
|
defer connectionsMu.Unlock()
|
||||||
|
|
||||||
delete(Connections, t.Id)
|
delete(Connections, t.Id)
|
||||||
}
|
}
|
||||||
|
|
||||||
func Disconnect(id string) error {
|
func Disconnect(id string) error {
|
||||||
if Connections == nil {
|
connectionsMu.RLock()
|
||||||
|
m, ok := Connections[id]
|
||||||
|
connectionsMu.RUnlock()
|
||||||
|
|
||||||
|
if !ok {
|
||||||
return fmt.Errorf("%s connection does not exist", id)
|
return fmt.Errorf("%s connection does not exist", id)
|
||||||
}
|
}
|
||||||
|
m.Processor.ctl <- ctlDisconnect
|
||||||
if m, ok := Connections[id]; !ok {
|
return nil
|
||||||
m.Processor.ctl <- ctlDisconnect
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return fmt.Errorf("%s connection does not exist", id)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// CalculateSpeedPerSecond calculate moving average.
|
|
||||||
/*
|
|
||||||
func CalculateSpeedPerSecond(connId string) (in int, out int) {
|
|
||||||
now := time.Now().UnixMilli()
|
|
||||||
|
|
||||||
c := Connections[connId]
|
|
||||||
total := int64(0)
|
|
||||||
for _, v := range c.Tunnel.BytesReceived {
|
|
||||||
total += v
|
|
||||||
}
|
|
||||||
in = int(total / (now - c.TimeStamp) * 1000)
|
|
||||||
|
|
||||||
total = int64(0)
|
|
||||||
for _, v := range c.BytesSent {
|
|
||||||
total += v
|
|
||||||
}
|
|
||||||
out = int(total / (now - c.TimeStamp))
|
|
||||||
|
|
||||||
return in, out
|
|
||||||
}
|
|
||||||
*/
|
|
||||||
|
|||||||
87
cmd/rdpgw/protocol/track_test.go
Normal file
87
cmd/rdpgw/protocol/track_test.go
Normal file
@@ -0,0 +1,87 @@
|
|||||||
|
package protocol
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// resetConnections clears the global registry under its lock. Tests must
|
||||||
|
// not reassign the package-level Connections variable, since the lock
|
||||||
|
// protects the map's contents but not the variable's storage; concurrent
|
||||||
|
// goroutines from other tests can otherwise read a stale reference.
|
||||||
|
func resetConnections() {
|
||||||
|
connectionsMu.Lock()
|
||||||
|
clear(Connections)
|
||||||
|
connectionsMu.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestTunnelTrackerConcurrent hammers the global tunnel registry from many
|
||||||
|
// goroutines. The package's RegisterTunnel/RemoveTunnel write to a shared
|
||||||
|
// map, so the registry must serialize access. A concurrent write would be
|
||||||
|
// caught by the Go runtime as a fatal `concurrent map writes` (or by the
|
||||||
|
// race detector under -race).
|
||||||
|
func TestTunnelTrackerConcurrent(t *testing.T) {
|
||||||
|
resetConnections()
|
||||||
|
t.Cleanup(resetConnections)
|
||||||
|
|
||||||
|
const goroutines = 200
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
wg.Add(goroutines * 2)
|
||||||
|
for i := 0; i < goroutines; i++ {
|
||||||
|
i := i
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
id := fmt.Sprintf("tun-%d", i)
|
||||||
|
RegisterTunnel(&Tunnel{Id: id}, &Processor{ctl: make(chan int)})
|
||||||
|
}()
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
id := fmt.Sprintf("tun-%d", i)
|
||||||
|
RemoveTunnel(&Tunnel{Id: id})
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
wg.Wait()
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestDisconnectKnownConnection verifies that Disconnect signals the
|
||||||
|
// processor for a connection that exists in the registry.
|
||||||
|
func TestDisconnectKnownConnection(t *testing.T) {
|
||||||
|
resetConnections()
|
||||||
|
t.Cleanup(resetConnections)
|
||||||
|
|
||||||
|
p := &Processor{ctl: make(chan int, 1)}
|
||||||
|
connectionsMu.Lock()
|
||||||
|
Connections["known"] = &Monitor{Processor: p}
|
||||||
|
connectionsMu.Unlock()
|
||||||
|
|
||||||
|
if err := Disconnect("known"); err != nil {
|
||||||
|
t.Fatalf("Disconnect on known id returned err: %v", err)
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case v := <-p.ctl:
|
||||||
|
if v != ctlDisconnect {
|
||||||
|
t.Errorf("ctl received %d, want ctlDisconnect=%d", v, ctlDisconnect)
|
||||||
|
}
|
||||||
|
case <-time.After(100 * time.Millisecond):
|
||||||
|
t.Error("Disconnect did not signal ctlDisconnect on the processor channel")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestDisconnectMissingConnectionDoesNotPanic verifies that Disconnect on
|
||||||
|
// an id that is not in the registry returns an error rather than
|
||||||
|
// dereferencing a nil Monitor.
|
||||||
|
func TestDisconnectMissingConnectionDoesNotPanic(t *testing.T) {
|
||||||
|
resetConnections()
|
||||||
|
t.Cleanup(resetConnections)
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
t.Errorf("Disconnect panicked on missing id: %v", r)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
if err := Disconnect("nonexistent"); err == nil {
|
||||||
|
t.Error("Disconnect on missing id returned no error")
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -2,15 +2,17 @@ package web
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/identity"
|
"log"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/identity"
|
||||||
"github.com/bolkedebruin/rdpgw/shared/auth"
|
"github.com/bolkedebruin/rdpgw/shared/auth"
|
||||||
"google.golang.org/grpc"
|
"google.golang.org/grpc"
|
||||||
"google.golang.org/grpc/credentials/insecure"
|
"google.golang.org/grpc/credentials/insecure"
|
||||||
"log"
|
|
||||||
"net"
|
|
||||||
"net/http"
|
|
||||||
"time"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type ntlmAuthMode uint32
|
type ntlmAuthMode uint32
|
||||||
@@ -47,13 +49,15 @@ func (h *NTLMAuthHandler) NTLMAuth(next http.HandlerFunc) http.HandlerFunc {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *NTLMAuthHandler) getAuthPayload (r *http.Request) (payload string, authMode ntlmAuthMode, err error) {
|
func (h *NTLMAuthHandler) getAuthPayload(r *http.Request) (payload string, authMode ntlmAuthMode, err error) {
|
||||||
authorisationEncoded := r.Header.Get("Authorization")
|
authorisationEncoded := r.Header.Get("Authorization")
|
||||||
if authorisationEncoded[0:5] == "NTLM " {
|
const ntlmPrefix = "NTLM "
|
||||||
return authorisationEncoded[5:], authNTLM, nil
|
const negotiatePrefix = "Negotiate "
|
||||||
|
if strings.HasPrefix(authorisationEncoded, ntlmPrefix) {
|
||||||
|
return authorisationEncoded[len(ntlmPrefix):], authNTLM, nil
|
||||||
}
|
}
|
||||||
if authorisationEncoded[0:10] == "Negotiate " {
|
if strings.HasPrefix(authorisationEncoded, negotiatePrefix) {
|
||||||
return authorisationEncoded[10:], authNegotiate, nil
|
return authorisationEncoded[len(negotiatePrefix):], authNegotiate, nil
|
||||||
}
|
}
|
||||||
return "", authNone, errors.New("Invalid NTLM Authorisation header")
|
return "", authNone, errors.New("Invalid NTLM Authorisation header")
|
||||||
}
|
}
|
||||||
|
|||||||
45
cmd/rdpgw/web/ntlm_test.go
Normal file
45
cmd/rdpgw/web/ntlm_test.go
Normal file
@@ -0,0 +1,45 @@
|
|||||||
|
package web
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestGetAuthPayloadShortHeader asserts that getAuthPayload tolerates
|
||||||
|
// Authorization header values shorter than the prefixes it tests for. The
|
||||||
|
// header value is attacker-controlled, so a sub-5-byte value must surface
|
||||||
|
// as an error rather than a slice-bounds panic that crashes the worker.
|
||||||
|
func TestGetAuthPayloadShortHeader(t *testing.T) {
|
||||||
|
cases := []struct {
|
||||||
|
name string
|
||||||
|
header string
|
||||||
|
set bool
|
||||||
|
}{
|
||||||
|
{"missing", "", false},
|
||||||
|
{"empty", "", true},
|
||||||
|
{"three chars", "abc", true},
|
||||||
|
{"four chars", "NTLM", true},
|
||||||
|
{"five chars not NTLM prefix", "abcde", true},
|
||||||
|
{"nine chars Negotiate without trailing space", "Negotiate", true},
|
||||||
|
}
|
||||||
|
|
||||||
|
h := &NTLMAuthHandler{}
|
||||||
|
for _, tc := range cases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
req := httptest.NewRequest("GET", "/", nil)
|
||||||
|
if tc.set {
|
||||||
|
req.Header.Set("Authorization", tc.header)
|
||||||
|
}
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
t.Errorf("getAuthPayload panicked on Authorization=%q: %v", tc.header, r)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
_, _, err := h.getAuthPayload(req)
|
||||||
|
if err == nil {
|
||||||
|
t.Errorf("expected error for Authorization=%q, got nil", tc.header)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user