diff --git a/v2/account_claims.go b/v2/account_claims.go index a62d488..e135ab9 100644 --- a/v2/account_claims.go +++ b/v2/account_claims.go @@ -357,6 +357,9 @@ func (a *AccountClaims) ExpectedPrefixes() []nkeys.PrefixByte { func (a *AccountClaims) Claims() *ClaimsData { return &a.ClaimsData } +func (a *AccountClaims) GetTags() TagList { + return a.Account.Tags +} // DidSign checks the claims against the account's public key and its signing keys func (a *AccountClaims) DidSign(c Claims) bool { diff --git a/v2/account_claims_test.go b/v2/account_claims_test.go index 899d65d..6d5fc14 100644 --- a/v2/account_claims_test.go +++ b/v2/account_claims_test.go @@ -839,3 +839,40 @@ func TestAccountClaims_DidSign(t *testing.T) { t.Fatal("this is not issued by account A") } } + +func TestAccountClaims_GetTags(t *testing.T) { + akp := createAccountNKey(t) + apk := publicKey(akp, t) + + ac := NewAccountClaims(apk) + ac.Account.Tags.Add("foo", "bar") + tags := ac.GetTags() + if len(tags) != 2 { + t.Fatal("expected 2 tags") + } + if tags[0] != "foo" { + t.Fatal("expected tag foo") + } + if tags[1] != "bar" { + t.Fatal("expected tag bar") + } + + token, err := ac.Encode(akp) + if err != nil { + t.Fatal("error encoding") + } + ac, err = DecodeAccountClaims(token) + if err != nil { + t.Fatal("error decoding") + } + tags = ac.GetTags() + if len(tags) != 2 { + t.Fatal("expected 2 tags") + } + if tags[0] != "foo" { + t.Fatal("expected tag foo") + } + if tags[1] != "bar" { + t.Fatal("expected tag bar") + } +} diff --git a/v2/operator_claims.go b/v2/operator_claims.go index f806002..3835b97 100644 --- a/v2/operator_claims.go +++ b/v2/operator_claims.go @@ -243,3 +243,7 @@ func (oc *OperatorClaims) Claims() *ClaimsData { func (oc *OperatorClaims) updateVersion() { oc.GenericFields.Version = libVersion } + +func (oc *OperatorClaims) GetTags() TagList { + return oc.Operator.Tags +} diff --git a/v2/operator_claims_test.go b/v2/operator_claims_test.go index 6117c81..dcc4d34 100644 --- a/v2/operator_claims_test.go +++ b/v2/operator_claims_test.go @@ -465,3 +465,40 @@ func TestTags(t *testing.T) { AssertTrue(oc.GenericFields.Tags.Contains("two"), t) AssertTrue(oc.GenericFields.Tags.Contains("three"), t) } + +func TestOperatorClaims_GetTags(t *testing.T) { + okp := createOperatorNKey(t) + opk := publicKey(okp, t) + + oc := NewOperatorClaims(opk) + oc.Operator.Tags.Add("foo", "bar") + tags := oc.GetTags() + if len(tags) != 2 { + t.Fatal("expected 2 tags") + } + if tags[0] != "foo" { + t.Fatal("expected tag foo") + } + if tags[1] != "bar" { + t.Fatal("expected tag bar") + } + + token, err := oc.Encode(okp) + if err != nil { + t.Fatal("error encoding") + } + oc, err = DecodeOperatorClaims(token) + if err != nil { + t.Fatal("error decoding") + } + tags = oc.GetTags() + if len(tags) != 2 { + t.Fatal("expected 2 tags") + } + if tags[0] != "foo" { + t.Fatal("expected tag foo") + } + if tags[1] != "bar" { + t.Fatal("expected tag bar") + } +} diff --git a/v2/user_claims.go b/v2/user_claims.go index d8a3a6c..0b38af6 100644 --- a/v2/user_claims.go +++ b/v2/user_claims.go @@ -151,3 +151,7 @@ func (u *UserClaims) updateVersion() { func (u *UserClaims) IsBearerToken() bool { return u.BearerToken } + +func (u *UserClaims) GetTags() TagList { + return u.User.Tags +} diff --git a/v2/user_claims_test.go b/v2/user_claims_test.go index dea6193..fed0660 100644 --- a/v2/user_claims_test.go +++ b/v2/user_claims_test.go @@ -405,3 +405,41 @@ func TestUserClaimRevocation(t *testing.T) { t.Fatal("account validation shouldn't have failed") } } + +func TestUserClaims_GetTags(t *testing.T) { + akp := createAccountNKey(t) + ukp := createUserNKey(t) + upk := publicKey(ukp, t) + + uc := NewUserClaims(upk) + uc.User.Tags.Add("foo", "bar") + tags := uc.GetTags() + if len(tags) != 2 { + t.Fatal("expected 2 tags") + } + if tags[0] != "foo" { + t.Fatal("expected tag foo") + } + if tags[1] != "bar" { + t.Fatal("expected tag bar") + } + + token, err := uc.Encode(akp) + if err != nil { + t.Fatal("error encoding") + } + uc, err = DecodeUserClaims(token) + if err != nil { + t.Fatal("error decoding") + } + tags = uc.GetTags() + if len(tags) != 2 { + t.Fatal("expected 2 tags") + } + if tags[0] != "foo" { + t.Fatal("expected tag foo") + } + if tags[1] != "bar" { + t.Fatal("expected tag bar") + } +}