[client] fix extra DNS labels parameter to Register method in client (#3371)

[client] fix extra DNS labels parameter to Register method in client (#3371)
This commit is contained in:
hakansa
2025-02-24 14:53:59 +03:00
committed by GitHub
parent cc48594b0b
commit dabdef4d67
6 changed files with 15 additions and 14 deletions

View File

@@ -15,7 +15,7 @@ type Client interface {
io.Closer
Sync(ctx context.Context, sysInfo *system.Info, msgHandler func(msg *proto.SyncResponse) error) error
GetServerPublicKey() (*wgtypes.Key, error)
Register(serverKey wgtypes.Key, setupKey string, jwtToken string, sysInfo *system.Info, sshKey []byte) (*proto.LoginResponse, error)
Register(serverKey wgtypes.Key, setupKey string, jwtToken string, sysInfo *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error)
Login(serverKey wgtypes.Key, sysInfo *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error)
GetDeviceAuthorizationFlow(serverKey wgtypes.Key) (*proto.DeviceAuthorizationFlow, error)
GetPKCEAuthorizationFlow(serverKey wgtypes.Key) (*proto.PKCEAuthorizationFlow, error)

View File

@@ -205,7 +205,7 @@ func TestClient_LoginRegistered(t *testing.T) {
t.Error(err)
}
info := system.GetInfo(context.TODO())
resp, err := client.Register(*key, ValidKey, "", info, nil)
resp, err := client.Register(*key, ValidKey, "", info, nil, nil)
if err != nil {
t.Error(err)
}
@@ -235,7 +235,7 @@ func TestClient_Sync(t *testing.T) {
}
info := system.GetInfo(context.TODO())
_, err = client.Register(*serverKey, ValidKey, "", info, nil)
_, err = client.Register(*serverKey, ValidKey, "", info, nil, nil)
if err != nil {
t.Error(err)
}
@@ -251,7 +251,7 @@ func TestClient_Sync(t *testing.T) {
}
info = system.GetInfo(context.TODO())
_, err = remoteClient.Register(*serverKey, ValidKey, "", info, nil)
_, err = remoteClient.Register(*serverKey, ValidKey, "", info, nil, nil)
if err != nil {
t.Fatal(err)
}
@@ -352,7 +352,7 @@ func Test_SystemMetaDataFromClient(t *testing.T) {
}
info := system.GetInfo(context.TODO())
_, err = testClient.Register(*key, ValidKey, "", info, nil)
_, err = testClient.Register(*key, ValidKey, "", info, nil, nil)
if err != nil {
t.Errorf("error while trying to register client: %v", err)
}

View File

@@ -365,12 +365,12 @@ func (c *GrpcClient) login(serverKey wgtypes.Key, req *proto.LoginRequest) (*pro
// Register registers peer on Management Server. It actually calls a Login endpoint with a provided setup key
// Takes care of encrypting and decrypting messages.
// This method will also collect system info and send it with the request (e.g. hostname, os, etc)
func (c *GrpcClient) Register(serverKey wgtypes.Key, setupKey string, jwtToken string, sysInfo *system.Info, pubSSHKey []byte) (*proto.LoginResponse, error) {
func (c *GrpcClient) Register(serverKey wgtypes.Key, setupKey string, jwtToken string, sysInfo *system.Info, pubSSHKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error) {
keys := &proto.PeerKeys{
SshPubKey: pubSSHKey,
WgPubKey: []byte(c.key.PublicKey().String()),
}
return c.login(serverKey, &proto.LoginRequest{SetupKey: setupKey, Meta: infoToMetaData(sysInfo), JwtToken: jwtToken, PeerKeys: keys})
return c.login(serverKey, &proto.LoginRequest{SetupKey: setupKey, Meta: infoToMetaData(sysInfo), JwtToken: jwtToken, PeerKeys: keys, DnsLabels: dnsLabels.ToPunycodeList()})
}
// Login attempts login to Management Server. Takes care of encrypting and decrypting messages.

View File

@@ -14,7 +14,7 @@ type MockClient struct {
CloseFunc func() error
SyncFunc func(ctx context.Context, sysInfo *system.Info, msgHandler func(msg *proto.SyncResponse) error) error
GetServerPublicKeyFunc func() (*wgtypes.Key, error)
RegisterFunc func(serverKey wgtypes.Key, setupKey string, jwtToken string, info *system.Info, sshKey []byte) (*proto.LoginResponse, error)
RegisterFunc func(serverKey wgtypes.Key, setupKey string, jwtToken string, info *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error)
LoginFunc func(serverKey wgtypes.Key, info *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error)
GetDeviceAuthorizationFlowFunc func(serverKey wgtypes.Key) (*proto.DeviceAuthorizationFlow, error)
GetPKCEAuthorizationFlowFunc func(serverKey wgtypes.Key) (*proto.PKCEAuthorizationFlow, error)
@@ -46,11 +46,11 @@ func (m *MockClient) GetServerPublicKey() (*wgtypes.Key, error) {
return m.GetServerPublicKeyFunc()
}
func (m *MockClient) Register(serverKey wgtypes.Key, setupKey string, jwtToken string, info *system.Info, sshKey []byte) (*proto.LoginResponse, error) {
func (m *MockClient) Register(serverKey wgtypes.Key, setupKey string, jwtToken string, info *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error) {
if m.RegisterFunc == nil {
return nil, nil
}
return m.RegisterFunc(serverKey, setupKey, jwtToken, info, sshKey)
return m.RegisterFunc(serverKey, setupKey, jwtToken, info, sshKey, dnsLabels)
}
func (m *MockClient) Login(serverKey wgtypes.Key, info *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error) {