diff --git a/go.mod b/go.mod index 4eb5e8170..53a62ef1c 100644 --- a/go.mod +++ b/go.mod @@ -7,6 +7,7 @@ require ( github.com/golang-jwt/jwt v3.2.2+incompatible github.com/golang/protobuf v1.5.2 github.com/google/uuid v1.2.0 + github.com/gorilla/mux v1.8.0 github.com/kardianos/service v1.2.0 github.com/onsi/ginkgo v1.16.4 github.com/onsi/gomega v1.13.0 diff --git a/go.sum b/go.sum index cf57fefbe..8cd41ec23 100644 --- a/go.sum +++ b/go.sum @@ -100,6 +100,8 @@ github.com/google/uuid v1.2.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+ github.com/googleapis/gax-go/v2 v2.0.4/go.mod h1:0Wqv26UfaUD9n4G6kQubkQ+KchISgw+vpHVxEJEs9eg= github.com/googleapis/gax-go/v2 v2.0.5/go.mod h1:DWXyrwAJ9X0FpwwEdw+IPEYBICEFu5mhpdKc/us6bOk= github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY= +github.com/gorilla/mux v1.8.0 h1:i40aqfkR1h2SlN9hojwV5ZA91wcXFOvkdNIeFDP5koI= +github.com/gorilla/mux v1.8.0/go.mod h1:DVbg23sWSpFRCP0SfiEN6jmj59UnW/n46BH5rLB71So= github.com/gorilla/websocket v1.4.2/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/grpc-ecosystem/go-grpc-middleware v1.0.0/go.mod h1:FiyG127CGDf3tlThmgyCl78X/SZQqEOJBCDaAfeWzPs= github.com/grpc-ecosystem/go-grpc-prometheus v1.2.0/go.mod h1:8NvIoxWQoOIhqOTXgfV/d3M/q6VIi02HzZEHgUlZvzk= diff --git a/management/server/account.go b/management/server/account.go index 894a9832f..c7d32ba86 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -8,6 +8,7 @@ import ( "net" "strings" "sync" + "time" ) type AccountManager struct { @@ -77,6 +78,79 @@ func (manager *AccountManager) GetPeersForAPeer(peerKey string) ([]*Peer, error) return res, nil } +//AddSetupKey generates a new setup key with a given name and type, and adds it to the specified account +func (manager *AccountManager) AddSetupKey(accountId string, keyName string, keyType SetupKeyType, expiresIn time.Duration) (*SetupKey, error) { + manager.mux.Lock() + defer manager.mux.Unlock() + + account, err := manager.Store.GetAccount(accountId) + if err != nil { + return nil, status.Errorf(codes.NotFound, "account not found") + } + + setupKey := GenerateSetupKey(keyName, keyType, expiresIn) + account.SetupKeys[setupKey.Key] = setupKey + + err = manager.Store.SaveAccount(account) + if err != nil { + return nil, status.Errorf(codes.Internal, "failed adding account key") + } + + return setupKey, nil +} + +//RevokeSetupKey marks SetupKey as revoked - becomes not valid anymore +func (manager *AccountManager) RevokeSetupKey(accountId string, keyId string) (*SetupKey, error) { + manager.mux.Lock() + defer manager.mux.Unlock() + + account, err := manager.Store.GetAccount(accountId) + if err != nil { + return nil, status.Errorf(codes.NotFound, "account not found") + } + + setupKey := getAccountSetupKeyById(account, keyId) + if setupKey == nil { + return nil, status.Errorf(codes.NotFound, "unknown setupKey %s", keyId) + } + + keyCopy := setupKey.Copy() + keyCopy.Revoked = true + account.SetupKeys[keyCopy.Key] = keyCopy + err = manager.Store.SaveAccount(account) + if err != nil { + return nil, status.Errorf(codes.Internal, "failed adding account key") + } + + return keyCopy, nil +} + +//RenameSetupKey renames existing setup key of the specified account. +func (manager *AccountManager) RenameSetupKey(accountId string, keyId string, newName string) (*SetupKey, error) { + manager.mux.Lock() + defer manager.mux.Unlock() + + account, err := manager.Store.GetAccount(accountId) + if err != nil { + return nil, status.Errorf(codes.NotFound, "account not found") + } + + setupKey := getAccountSetupKeyById(account, keyId) + if setupKey == nil { + return nil, status.Errorf(codes.NotFound, "unknown setupKey %s", keyId) + } + + keyCopy := setupKey.Copy() + keyCopy.Name = newName + account.SetupKeys[keyCopy.Key] = keyCopy + err = manager.Store.SaveAccount(account) + if err != nil { + return nil, status.Errorf(codes.Internal, "failed adding account key") + } + + return keyCopy, nil +} + //GetAccount returns an existing account or error (NotFound) if doesn't exist func (manager *AccountManager) GetAccount(accountId string) (*Account, error) { manager.mux.Lock() @@ -177,13 +251,7 @@ func (manager *AccountManager) AddPeer(setupKey string, peerKey string) (*Peer, return nil, status.Errorf(codes.NotFound, "unknown setupKey %s", upperKey) } - for _, key := range account.SetupKeys { - if upperKey == key.Key { - sk = key - break - } - } - + sk = getAccountSetupKeyByKey(account, setupKey) if sk == nil { // shouldn't happen actually return nil, status.Errorf(codes.NotFound, "unknown setupKey %s", upperKey) @@ -242,3 +310,21 @@ func newAccount() (*Account, *SetupKey) { accountId := uuid.New().String() return newAccountWithId(accountId) } + +func getAccountSetupKeyById(acc *Account, keyId string) *SetupKey { + for _, k := range acc.SetupKeys { + if keyId == k.Id { + return k + } + } + return nil +} + +func getAccountSetupKeyByKey(acc *Account, key string) *SetupKey { + for _, k := range acc.SetupKeys { + if key == k.Key { + return k + } + } + return nil +} diff --git a/management/server/account_test.go b/management/server/account_test.go index 0bae94d11..a783bb88d 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -29,7 +29,7 @@ func TestAccountManager_AddAccount(t *testing.T) { } if account.Id != expectedId { - t.Errorf("expected account to have ID = %s, got %s", expectedId, account.Id) + t.Errorf("expected account to have Id = %s, got %s", expectedId, account.Id) } if len(account.Peers) != expectedPeersSize { @@ -130,7 +130,7 @@ func TestAccountManager_GetAccount(t *testing.T) { } if account.Id != getAccount.Id { - t.Errorf("expected account.ID %s, got %s", account.Id, getAccount.Id) + t.Errorf("expected account.Id %s, got %s", account.Id, getAccount.Id) } for _, peer := range account.Peers { diff --git a/management/server/http/handler/peers.go b/management/server/http/handler/peers.go index be8beecd3..b461ce4b6 100644 --- a/management/server/http/handler/peers.go +++ b/management/server/http/handler/peers.go @@ -28,7 +28,7 @@ func NewPeers(accountManager *server.AccountManager) *Peers { } } -func (h *Peers) ServeHTTP(w http.ResponseWriter, r *http.Request) { +func (h *Peers) GetPeers(w http.ResponseWriter, r *http.Request) { switch r.Method { case http.MethodGet: accountId := extractAccountIdFromRequestContext(r) diff --git a/management/server/http/handler/setupkeys.go b/management/server/http/handler/setupkeys.go index cf59c643d..0667cca1c 100644 --- a/management/server/http/handler/setupkeys.go +++ b/management/server/http/handler/setupkeys.go @@ -2,8 +2,11 @@ package handler import ( "encoding/json" + "github.com/gorilla/mux" log "github.com/sirupsen/logrus" "github.com/wiretrustee/wiretrustee/management/server" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" "net/http" "time" ) @@ -15,11 +18,21 @@ type SetupKeys struct { // SetupKeyResponse is a response sent to the client type SetupKeyResponse struct { + Id string Key string Name string Expires time.Time - Type string + Type server.SetupKeyType Valid bool + Revoked bool +} + +// SetupKeyRequest is a request sent by client. This object contains fields that can be modified +type SetupKeyRequest struct { + Name string + Type server.SetupKeyType + ExpiresIn Duration + Revoked bool } func NewSetupKeysHandler(accountManager *server.AccountManager) *SetupKeys { @@ -28,7 +41,90 @@ func NewSetupKeysHandler(accountManager *server.AccountManager) *SetupKeys { } } -func (h *SetupKeys) ServeHTTP(w http.ResponseWriter, r *http.Request) { +func (h *SetupKeys) CreateKey(w http.ResponseWriter, r *http.Request) { + accountId := extractAccountIdFromRequestContext(r) + req := &SetupKeyRequest{} + err := json.NewDecoder(r.Body).Decode(&req) + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + + setupKey, err := h.accountManager.AddSetupKey(accountId, req.Name, req.Type, req.ExpiresIn.Duration) + if err != nil { + errStatus, ok := status.FromError(err) + if ok && errStatus.Code() == codes.NotFound { + http.Error(w, "account not found", http.StatusNotFound) + return + } + http.Error(w, "failed adding setup key", http.StatusInternalServerError) + return + } + + writeSuccess(w, setupKey) +} + +func (h *SetupKeys) HandleKey(w http.ResponseWriter, r *http.Request) { + accountId := extractAccountIdFromRequestContext(r) + vars := mux.Vars(r) + keyId := vars["id"] + if len(keyId) == 0 { + http.Error(w, "invalid key Id", http.StatusBadRequest) + return + } + + switch r.Method { + case http.MethodPost: + req := &SetupKeyRequest{} + err := json.NewDecoder(r.Body).Decode(&req) + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + + var key *server.SetupKey + if req.Revoked { + //handle only if being revoked, don't allow to enable key again for now + key, err = h.accountManager.RevokeSetupKey(accountId, keyId) + if err != nil { + http.Error(w, "failed revoking key", http.StatusInternalServerError) + return + } + } + if len(req.Name) != 0 { + key, err = h.accountManager.RenameSetupKey(accountId, keyId, req.Name) + if err != nil { + http.Error(w, "failed renaming key", http.StatusInternalServerError) + return + } + } + + if key != nil { + writeSuccess(w, key) + } + + return + + case http.MethodGet: + account, err := h.accountManager.GetAccount(accountId) + if err != nil { + http.Error(w, "account doesn't exist", http.StatusInternalServerError) + return + } + for _, key := range account.SetupKeys { + if key.Id == keyId { + writeSuccess(w, key) + return + } + } + http.Error(w, "setup key not found", http.StatusNotFound) + return + default: + http.Error(w, "", http.StatusNotFound) + } +} + +func (h *SetupKeys) GetKeys(w http.ResponseWriter, r *http.Request) { switch r.Method { case http.MethodGet: accountId := extractAccountIdFromRequestContext(r) @@ -44,13 +140,7 @@ func (h *SetupKeys) ServeHTTP(w http.ResponseWriter, r *http.Request) { respBody := []*SetupKeyResponse{} for _, key := range account.SetupKeys { - respBody = append(respBody, &SetupKeyResponse{ - Key: key.Key, - Name: key.Name, - Expires: key.ExpiresAt, - Type: string(key.Type), - Valid: key.IsValid(), - }) + respBody = append(respBody, toResponseBody(key)) } err = json.NewEncoder(w).Encode(respBody) @@ -63,3 +153,25 @@ func (h *SetupKeys) ServeHTTP(w http.ResponseWriter, r *http.Request) { http.Error(w, "", http.StatusNotFound) } } + +func writeSuccess(w http.ResponseWriter, key *server.SetupKey) { + w.WriteHeader(200) + w.Header().Set("Content-Type", "application/json") + err := json.NewEncoder(w).Encode(toResponseBody(key)) + if err != nil { + http.Error(w, "failed handling request", http.StatusInternalServerError) + return + } +} + +func toResponseBody(key *server.SetupKey) *SetupKeyResponse { + return &SetupKeyResponse{ + Id: key.Id, + Key: key.Key, + Name: key.Name, + Expires: key.ExpiresAt, + Type: key.Type, + Valid: key.IsValid(), + Revoked: key.Revoked, + } +} diff --git a/management/server/http/handler/util.go b/management/server/http/handler/util.go index 97c49b78a..b479cd8b8 100644 --- a/management/server/http/handler/util.go +++ b/management/server/http/handler/util.go @@ -1,8 +1,11 @@ package handler import ( + "encoding/json" + "errors" "github.com/golang-jwt/jwt" "net/http" + "time" ) // extractAccountIdFromRequestContext extracts accountId from the request context previously filled by the JWT token (after auth) @@ -13,3 +16,33 @@ func extractAccountIdFromRequestContext(r *http.Request) string { //actually a user id but for now we have a 1 to 1 mapping. return claims["sub"].(string) } + +//Duration is used strictly for JSON requests/responses due to duration marshalling issues +type Duration struct { + time.Duration +} + +func (d Duration) MarshalJSON() ([]byte, error) { + return json.Marshal(d.String()) +} + +func (d *Duration) UnmarshalJSON(b []byte) error { + var v interface{} + if err := json.Unmarshal(b, &v); err != nil { + return err + } + switch value := v.(type) { + case float64: + d.Duration = time.Duration(value) + return nil + case string: + var err error + d.Duration, err = time.ParseDuration(value) + if err != nil { + return err + } + return nil + default: + return errors.New("invalid duration") + } +} diff --git a/management/server/http/server.go b/management/server/http/server.go index fa45a2b07..3de780182 100644 --- a/management/server/http/server.go +++ b/management/server/http/server.go @@ -2,6 +2,7 @@ package http import ( "context" + "github.com/gorilla/mux" "github.com/rs/cors" log "github.com/sirupsen/logrus" s "github.com/wiretrustee/wiretrustee/management/server" @@ -54,20 +55,24 @@ func (s *Server) Start() error { } corsMiddleware := cors.AllowAll() - h := http.NewServeMux() - s.server.Handler = h + + r := mux.NewRouter() + r.Use(jwtMiddleware.Handler, corsMiddleware.Handler) peersHandler := handler.NewPeers(s.accountManager) keysHandler := handler.NewSetupKeysHandler(s.accountManager) - h.Handle("/api/peers", corsMiddleware.Handler(jwtMiddleware.Handler(peersHandler))) - h.Handle("/api/setup-keys", corsMiddleware.Handler(jwtMiddleware.Handler(keysHandler))) - http.Handle("/", h) + r.HandleFunc("/api/peers", peersHandler.GetPeers).Methods("GET", "OPTIONS") + + r.HandleFunc("/api/setup-keys", keysHandler.GetKeys).Methods("GET", "OPTIONS") + r.HandleFunc("/api/setup-keys", keysHandler.CreateKey).Methods("PUT") + r.HandleFunc("/api/setup-keys/{id}", keysHandler.HandleKey).Methods("GET", "POST", "OPTIONS") + http.Handle("/", r) if s.certManager != nil { // if HTTPS is enabled we reuse the listener from the cert manager listener := s.certManager.Listener() log.Infof("http server listening on %s", listener.Addr()) - if err = http.Serve(listener, s.certManager.HTTPHandler(h)); err != nil { + if err = http.Serve(listener, s.certManager.HTTPHandler(r)); err != nil { log.Errorf("failed to serve https server: %v", err) return err } diff --git a/management/server/setupkey.go b/management/server/setupkey.go index 3d9393b82..573999c2e 100644 --- a/management/server/setupkey.go +++ b/management/server/setupkey.go @@ -2,6 +2,8 @@ package server import ( "github.com/google/uuid" + "hash/fnv" + "strconv" "strings" "time" ) @@ -23,6 +25,7 @@ type SetupKeyType string // SetupKey represents a pre-authorized key used to register machines (peers) type SetupKey struct { + Id string Key string Name string Type SetupKeyType @@ -34,6 +37,20 @@ type SetupKey struct { UsedTimes int } +//Copy copies SetupKey to a new object +func (key *SetupKey) Copy() *SetupKey { + return &SetupKey{ + Id: key.Id, + Key: key.Key, + Name: key.Name, + Type: key.Type, + CreatedAt: key.CreatedAt, + ExpiresAt: key.ExpiresAt, + Revoked: key.Revoked, + UsedTimes: key.UsedTimes, + } +} + // IsValid is true if the key was not revoked, is not expired and used not more than it was supposed to func (key *SetupKey) IsValid() bool { expired := time.Now().After(key.ExpiresAt) @@ -43,9 +60,11 @@ func (key *SetupKey) IsValid() bool { // GenerateSetupKey generates a new setup key func GenerateSetupKey(name string, t SetupKeyType, validFor time.Duration) *SetupKey { + key := strings.ToUpper(uuid.New().String()) createdAt := time.Now() return &SetupKey{ - Key: strings.ToUpper(uuid.New().String()), + Id: strconv.Itoa(int(Hash(key))), + Key: key, Name: name, Type: t, CreatedAt: createdAt, @@ -59,3 +78,12 @@ func GenerateSetupKey(name string, t SetupKeyType, validFor time.Duration) *Setu func GenerateDefaultSetupKey() *SetupKey { return GenerateSetupKey(DefaultSetupKeyName, SetupKeyReusable, DefaultSetupKeyDuration) } + +func Hash(s string) uint32 { + h := fnv.New32a() + _, err := h.Write([]byte(s)) + if err != nil { + panic(err) + } + return h.Sum32() +} diff --git a/management/server/setupkey_test.go b/management/server/setupkey_test.go index 03c7e5af5..e4cad8fd7 100644 --- a/management/server/setupkey_test.go +++ b/management/server/setupkey_test.go @@ -2,6 +2,7 @@ package server import ( "github.com/google/uuid" + "strconv" "testing" "time" ) @@ -16,7 +17,8 @@ func TestGenerateDefaultSetupKey(t *testing.T) { key := GenerateDefaultSetupKey() - assertKey(t, key, expectedName, expectedRevoke, expectedType, expectedUsedTimes, expectedCreatedAt, expectedExpiresAt) + assertKey(t, key, expectedName, expectedRevoke, expectedType, expectedUsedTimes, expectedCreatedAt, + expectedExpiresAt, strconv.Itoa(int(Hash(key.Key)))) } @@ -30,7 +32,7 @@ func TestGenerateSetupKey(t *testing.T) { key := GenerateSetupKey(expectedName, SetupKeyOneOff, time.Hour) - assertKey(t, key, expectedName, expectedRevoke, expectedType, expectedUsedTimes, expectedCreatedAt, expectedExpiresAt) + assertKey(t, key, expectedName, expectedRevoke, expectedType, expectedUsedTimes, expectedCreatedAt, expectedExpiresAt, strconv.Itoa(int(Hash(key.Key)))) } @@ -68,7 +70,8 @@ func TestSetupKey_IsValid(t *testing.T) { } } -func assertKey(t *testing.T, key *SetupKey, expectedName string, expectedRevoke bool, expectedType string, expectedUsedTimes int, expectedCreatedAt time.Time, expectedExpiresAt time.Time) { +func assertKey(t *testing.T, key *SetupKey, expectedName string, expectedRevoke bool, expectedType string, + expectedUsedTimes int, expectedCreatedAt time.Time, expectedExpiresAt time.Time, expectedID string) { if key.Name != expectedName { t.Errorf("expected setup key to have Name %v, got %v", expectedName, key.Name) } @@ -97,4 +100,17 @@ func assertKey(t *testing.T, key *SetupKey, expectedName string, expectedRevoke if err != nil { t.Errorf("expected key to be a valid UUID, got %v, %v", key.Key, err) } + + if key.Id != strconv.Itoa(int(Hash(key.Key))) { + t.Errorf("expected key Id t= %v, got %v", expectedID, key.Id) + } +} + +func TestSetupKey_Copy(t *testing.T) { + + key := GenerateSetupKey("key name", SetupKeyOneOff, time.Hour) + keyCopy := key.Copy() + + assertKey(t, keyCopy, key.Name, key.Revoked, string(key.Type), key.UsedTimes, key.CreatedAt, key.ExpiresAt, key.Id) + } diff --git a/signal/client/client_test.go b/signal/client/client_test.go index 03bc6e3dd..2ac5b03ee 100644 --- a/signal/client/client_test.go +++ b/signal/client/client_test.go @@ -109,7 +109,7 @@ var _ = Describe("Client", func() { }) }) - Context("with a raw client and no ID header", func() { + Context("with a raw client and no Id header", func() { It("should fail", func() { client := createRawSignalClient(addr) @@ -125,7 +125,7 @@ var _ = Describe("Client", func() { }) }) - Context("with a raw client and with an ID header", func() { + Context("with a raw client and with an Id header", func() { It("should be successful", func() { md := metadata.New(map[string]string{sigProto.HeaderId: "peer"}) diff --git a/signal/server/signal.go b/signal/server/signal.go index 164eaf6b6..fc102c27c 100644 --- a/signal/server/signal.go +++ b/signal/server/signal.go @@ -88,7 +88,7 @@ func (s *Server) ConnectStream(stream proto.SignalExchange_ConnectStreamServer) } // Handles initial Peer connection. -// Each connection must provide an ID header. +// Each connection must provide an Id header. // At this moment the connecting Peer will be registered in the peer.Registry func (s Server) connectPeer(stream proto.SignalExchange_ConnectStreamServer) (*peer.Peer, error) { if meta, hasMeta := metadata.FromIncomingContext(stream.Context()); hasMeta {