Allow chaining of checks

This commit is contained in:
Bolke de Bruin
2022-08-25 12:12:21 +02:00
parent 9d2dc57e90
commit 768ee45974
3 changed files with 34 additions and 31 deletions

View File

@@ -156,9 +156,9 @@ func main() {
} }
if conf.Caps.TokenAuth { if conf.Caps.TokenAuth {
gwConfig.VerifyTunnelAuthFunc = security.VerifyPAAToken gwConfig.VerifyTunnelAuthFunc = security.VerifyPAAToken
gwConfig.VerifyServerFunc = security.VerifyServerFunc gwConfig.VerifyServerFunc = security.CheckSession(security.CheckHost)
} else { } else {
gwConfig.VerifyServerFunc = security.BasicVerifyServer gwConfig.VerifyServerFunc = security.CheckHost
} }
gw := protocol.Gateway{ gw := protocol.Gateway{
ServerConf: &gwConfig, ServerConf: &gwConfig,

View File

@@ -5,6 +5,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"log" "log"
"strings"
) )
var ( var (
@@ -12,19 +13,20 @@ var (
HostSelection string HostSelection string
) )
func BasicVerifyServer(ctx context.Context, host string) (bool, error) { func CheckHost(ctx context.Context, host string) (bool, error) {
if HostSelection == "any" { switch HostSelection {
case "any":
return true, nil return true, nil
} case "signed":
// todo get from context?
if HostSelection == "signed" {
// todo get from context
return false, errors.New("cannot verify host in 'signed' mode as token data is missing") return false, errors.New("cannot verify host in 'signed' mode as token data is missing")
} case "roundrobin", "unsigned":
if HostSelection == "roundrobin" || HostSelection == "unsigned" {
log.Printf("Checking host") log.Printf("Checking host")
username := ctx.Value("preferred_username").(string)
for _, h := range Hosts { for _, h := range Hosts {
if username != "" {
h = strings.Replace(h, "{{ preferred_username }}", username, 1)
}
if h == host { if h == host {
return true, nil return true, nil
} }

View File

@@ -33,6 +33,27 @@ type customClaims struct {
AccessToken string `json:"accessToken"` AccessToken string `json:"accessToken"`
} }
func CheckSession(next protocol.VerifyServerFunc) protocol.VerifyServerFunc {
return func(ctx context.Context, host string) (bool, error) {
s := getSessionInfo(ctx)
if s == nil {
return false, errors.New("no valid session info found in context")
}
if s.RemoteServer != host {
log.Printf("Client specified host %s does not match token host %s", host, s.RemoteServer)
return false, nil
}
if VerifyClientIP && s.ClientIp != common.GetClientIp(ctx) {
log.Printf("Current client ip address %s does not match token client ip %s",
common.GetClientIp(ctx), s.ClientIp)
return false, nil
}
return next(ctx, host)
}
}
func VerifyPAAToken(ctx context.Context, tokenString string) (bool, error) { func VerifyPAAToken(ctx context.Context, tokenString string) (bool, error) {
if tokenString == "" { if tokenString == "" {
log.Printf("no token to parse") log.Printf("no token to parse")
@@ -91,26 +112,6 @@ func VerifyPAAToken(ctx context.Context, tokenString string) (bool, error) {
return true, nil return true, nil
} }
func VerifyServerFunc(ctx context.Context, host string) (bool, error) {
s := getSessionInfo(ctx)
if s == nil {
return false, errors.New("no valid session info found in context")
}
if s.RemoteServer != host {
log.Printf("Client specified host %s does not match token host %s", host, s.RemoteServer)
return false, nil
}
if VerifyClientIP && s.ClientIp != common.GetClientIp(ctx) {
log.Printf("Current client ip address %s does not match token client ip %s",
common.GetClientIp(ctx), s.ClientIp)
return false, nil
}
return true, nil
}
func GeneratePAAToken(ctx context.Context, username string, server string) (string, error) { func GeneratePAAToken(ctx context.Context, username string, server string) (string, error) {
if len(SigningKey) < 32 { if len(SigningKey) < 32 {
return "", errors.New("token signing key not long enough or not specified") return "", errors.New("token signing key not long enough or not specified")