Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
146 changes: 139 additions & 7 deletions service/entityresolution/claims/v2/entity_resolution.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,15 @@ package claims

import (
"context"
"errors"
"fmt"
"log"
"log/slog"
"strconv"
"strings"

"connectrpc.com/connect"
"github.com/go-viper/mapstructure/v2"
"github.com/lestrrat-go/jwx/v2/jwt"
"github.com/opentdf/platform/protocol/go/entity"
entityresolutionV2 "github.com/opentdf/platform/protocol/go/entityresolution/v2"
Expand All @@ -22,17 +26,30 @@ import (

type EntityResolutionServiceV2 struct {
entityresolutionV2.UnimplementedEntityResolutionServiceServer
logger *logger.Logger
logger *logger.Logger
allowDirectEntitlements bool
trace.Tracer
}

func RegisterClaimsERS(_ config.ServiceConfig, logger *logger.Logger) (EntityResolutionServiceV2, serviceregistry.HandlerServer) {
claimsSVC := EntityResolutionServiceV2{logger: logger}
type Config struct {
AllowDirectEntitlements bool `mapstructure:"allow_direct_entitlements" json:"allow_direct_entitlements" default:"false"`
}

func RegisterClaimsERS(cfg config.ServiceConfig, logger *logger.Logger) (EntityResolutionServiceV2, serviceregistry.HandlerServer) {
var inputConfig Config
if err := mapstructure.Decode(cfg, &inputConfig); err != nil {
logger.Error("failed to decode claims entity resolution configuration", slog.Any("error", err))
log.Fatalf("Failed to decode claims entity resolution configuration: %v", err)
}
claimsSVC := EntityResolutionServiceV2{
logger: logger,
allowDirectEntitlements: inputConfig.AllowDirectEntitlements,
}
return claimsSVC, nil
}

func (s EntityResolutionServiceV2) ResolveEntities(ctx context.Context, req *connect.Request[entityresolutionV2.ResolveEntitiesRequest]) (*connect.Response[entityresolutionV2.ResolveEntitiesResponse], error) {
resp, err := EntityResolution(ctx, req.Msg, s.logger)
resp, err := EntityResolution(ctx, req.Msg, s.logger, s.allowDirectEntitlements)
return connect.NewResponse(&resp), err
}

Expand Down Expand Up @@ -63,13 +80,14 @@ func CreateEntityChainsFromTokens(
}

func EntityResolution(_ context.Context,
req *entityresolutionV2.ResolveEntitiesRequest, logger *logger.Logger,
req *entityresolutionV2.ResolveEntitiesRequest, logger *logger.Logger, allowDirectEntitlements bool,
) (entityresolutionV2.ResolveEntitiesResponse, error) {
payload := req.GetEntities()
var resolvedEntities []*entityresolutionV2.EntityRepresentation

for idx, ident := range payload {
entityStruct := &structpb.Struct{}
var directEntitlements []*entityresolutionV2.DirectEntitlement
switch ident.GetEntityType().(type) {
case *entity.Entity_Claims:
claims := ident.GetClaims()
Expand All @@ -79,6 +97,13 @@ func EntityResolution(_ context.Context,
return entityresolutionV2.ResolveEntitiesResponse{}, connect.NewError(connect.CodeInvalidArgument, fmt.Errorf("error unpacking anypb.Any to structpb.Struct: %w", err))
}
}
if allowDirectEntitlements {
var err error
directEntitlements, err = parseDirectEntitlementsFromClaims(entityStruct)
if err != nil {
return entityresolutionV2.ResolveEntitiesResponse{}, connect.NewError(connect.CodeInvalidArgument, err)
}
}
default:
retrievedStruct, err := entityToStructPb(ident)
if err != nil {
Expand All @@ -95,8 +120,9 @@ func EntityResolution(_ context.Context,
resolvedEntities = append(
resolvedEntities,
&entityresolutionV2.EntityRepresentation{
OriginalId: originialID,
AdditionalProps: []*structpb.Struct{entityStruct},
OriginalId: originialID,
AdditionalProps: []*structpb.Struct{entityStruct},
DirectEntitlements: directEntitlements,
},
)
}
Expand Down Expand Up @@ -164,3 +190,109 @@ func entityToStructPb(ident *entity.Entity) (*structpb.Struct, error) {
}
return &entityStruct, nil
}

func parseDirectEntitlementsFromClaims(entityStruct *structpb.Struct) ([]*entityresolutionV2.DirectEntitlement, error) {
if entityStruct == nil {
return nil, nil
}
claims := entityStruct.AsMap()
rawEntitlements, ok := claims["direct_entitlements"]
if !ok {
rawEntitlements, ok = claims["directEntitlements"]
}
if !ok {
return nil, nil
}

entitlementList, entitlementsOK := rawEntitlements.([]interface{})
if !entitlementsOK {
return nil, errors.New("direct_entitlements must be an array")
}

out := make([]*entityresolutionV2.DirectEntitlement, 0, len(entitlementList))
for idx, entry := range entitlementList {
entryMap, entryOK := entry.(map[string]interface{})
if !entryOK {
return nil, fmt.Errorf("direct_entitlements[%d] must be an object", idx)
}

fqn, err := parseDirectEntitlementFQN(entryMap)
if err != nil {
return nil, fmt.Errorf("direct_entitlements[%d] %w", idx, err)
}

rawActions, actionsOK := entryMap["actions"]
if !actionsOK {
return nil, fmt.Errorf("direct_entitlements[%d] missing actions", idx)
}
actions, err := parseDirectEntitlementActions(rawActions)
if err != nil {
return nil, fmt.Errorf("direct_entitlements[%d] invalid actions: %w", idx, err)
}

out = append(out, &entityresolutionV2.DirectEntitlement{
AttributeValueFqn: fqn,
Actions: actions,
})
}

return out, nil
}

func parseDirectEntitlementFQN(entry map[string]interface{}) (string, error) {
if raw, ok := entry["attribute_value_fqn"]; ok {
if fqn, fqnOK := raw.(string); fqnOK {
fqn = strings.TrimSpace(fqn)
if fqn != "" {
return fqn, nil
}
}
}
if raw, ok := entry["attributeValueFqn"]; ok {
if fqn, fqnOK := raw.(string); fqnOK {
fqn = strings.TrimSpace(fqn)
if fqn != "" {
return fqn, nil
}
}
}
return "", errors.New("missing attribute_value_fqn")
}

func parseDirectEntitlementActions(raw interface{}) ([]string, error) {
actions := make([]string, 0)
switch typed := raw.(type) {
case []interface{}:
for _, action := range typed {
actionStr, ok := action.(string)
if !ok {
return nil, errors.New("action must be a string")
}
actionStr = strings.TrimSpace(strings.ToLower(actionStr))
if actionStr != "" {
actions = append(actions, actionStr)
}
}
case []string:
for _, action := range typed {
action = strings.TrimSpace(strings.ToLower(action))
if action != "" {
actions = append(actions, action)
}
}
case string:
for _, action := range strings.Split(typed, ",") {
action = strings.TrimSpace(strings.ToLower(action))
if action != "" {
actions = append(actions, action)
}
}
default:
return nil, errors.New("actions must be an array or string")
}

if len(actions) == 0 {
return nil, errors.New("no actions provided")
}
return actions, nil
}
65 changes: 62 additions & 3 deletions service/entityresolution/claims/v2/entity_resolution_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ func Test_ClientResolveEntity(t *testing.T) {
req := entityresolutionV2.ResolveEntitiesRequest{}
req.Entities = validBody

resp, reserr := claims.EntityResolution(t.Context(), &req, logger.CreateTestLogger())
resp, reserr := claims.EntityResolution(t.Context(), &req, logger.CreateTestLogger(), false)

require.NoError(t, reserr)

Expand All @@ -44,7 +44,7 @@ func Test_EmailResolveEntity(t *testing.T) {
req := entityresolutionV2.ResolveEntitiesRequest{}
req.Entities = validBody

resp, reserr := claims.EntityResolution(t.Context(), &req, logger.CreateTestLogger())
resp, reserr := claims.EntityResolution(t.Context(), &req, logger.CreateTestLogger(), false)

require.NoError(t, reserr)

Expand Down Expand Up @@ -78,7 +78,7 @@ func Test_ClaimsResolveEntity(t *testing.T) {
req := entityresolutionV2.ResolveEntitiesRequest{}
req.Entities = validBody

resp, reserr := claims.EntityResolution(t.Context(), &req, logger.CreateTestLogger())
resp, reserr := claims.EntityResolution(t.Context(), &req, logger.CreateTestLogger(), false)

require.NoError(t, reserr)

Expand All @@ -93,6 +93,65 @@ func Test_ClaimsResolveEntity(t *testing.T) {
assert.EqualValues(t, 42, propMap["baz"])
}

func Test_ClaimsResolveEntityDirectEntitlements(t *testing.T) {
customclaims := map[string]interface{}{
"direct_entitlements": []interface{}{
map[string]interface{}{
"attribute_value_fqn": "https://example.com/attr/department/value/eng",
"actions": []interface{}{"read", "update"},
},
},
}
structClaims, err := structpb.NewStruct(customclaims)
require.NoError(t, err)

anyClaims, err := anypb.New(structClaims)
require.NoError(t, err)

var validBody []*entity.Entity
validBody = append(validBody, &entity.Entity{EphemeralId: "1234", EntityType: &entity.Entity_Claims{Claims: anyClaims}})

req := entityresolutionV2.ResolveEntitiesRequest{Entities: validBody}

resp, reserr := claims.EntityResolution(t.Context(), &req, logger.CreateTestLogger(), true)
require.NoError(t, reserr)

entityRepresentations := resp.GetEntityRepresentations()
require.Len(t, entityRepresentations, 1)

entitlements := entityRepresentations[0].GetDirectEntitlements()
require.Len(t, entitlements, 1)
assert.Equal(t, "https://example.com/attr/department/value/eng", entitlements[0].GetAttributeValueFqn())
assert.ElementsMatch(t, []string{"read", "update"}, entitlements[0].GetActions())
}

func Test_ClaimsResolveEntityDirectEntitlementsDisabled(t *testing.T) {
customclaims := map[string]interface{}{
"direct_entitlements": []interface{}{
map[string]interface{}{
"attribute_value_fqn": "https://example.com/attr/department/value/eng",
"actions": []interface{}{"read"},
},
},
}
structClaims, err := structpb.NewStruct(customclaims)
require.NoError(t, err)

anyClaims, err := anypb.New(structClaims)
require.NoError(t, err)

req := entityresolutionV2.ResolveEntitiesRequest{Entities: []*entity.Entity{
{EphemeralId: "1234", EntityType: &entity.Entity_Claims{Claims: anyClaims}},
}}

resp, reserr := claims.EntityResolution(t.Context(), &req, logger.CreateTestLogger(), false)
require.NoError(t, reserr)

entityRepresentations := resp.GetEntityRepresentations()
require.Len(t, entityRepresentations, 1)
assert.Empty(t, entityRepresentations[0].GetDirectEntitlements())
}

func Test_JWTToEntityChainClaims(t *testing.T) {
validBody := []*entity.Token{{Jwt: samplejwt}}

Expand Down
51 changes: 51 additions & 0 deletions tests-bdd/cukes/resources/platform.direct_entitlements.template
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
authEndpoint: &authEndpoint http://{{ .hostname }}:{{.kcPort }}/auth
issuerEndpoint: &issuerEndpoint http://{{ .hostname }}:{{.kcPort }}/auth/realms/{{.authRealm}}
tokenEndpoint: &tokenEndpoint http://{{ .hostname }}:{{.kcPort }}/auth/realms/{{.authRealm}}/protocol/openid-connect/token
entityResolutionServiceUrl: &entityResolutionServiceUrl https://{{ .hostname }}:{{.platformPort }}/entityresolution/resolve
platformEndpoint: &platformEndpoint https://{{.hostname }}:{{.platformPort }}
authRealm: &authRealm {{.authRealm}}
mode: all
logger:
level: debug
type: text
output: stdout
server:
port: {{.platformPort}}
auth:
enabled: true
enforceDPoP: false
audience: *platformEndpoint
issuer: *issuerEndpoint
policy:
extension: |
g, opentdf-admin, role:admin
g, opentdf-standard, role:standard
db:
host: {{ .pgHost }}
port: {{ .pgPort }}
database: {{ .pgDatabase }}
user: postgres
password: changeme
schema: otdf
services:
authorization:
allow_direct_entitlements: true
kas:
keyring:
- kid: e1
alg: ec:secp256r1
- kid: r1
alg: rsa:2048
entityresolution:
mode: claims
allow_direct_entitlements: true
shared:
clientId: otdf-shared
clientSecret: secret
authClientId: otdf-shared-auth
serviceHostName: shared
platformEndpoint: *platformEndpoint
platformAuthEndpoint: *authEndpoint
platformAuthRealm: *authRealm
tokenEndpoint: *tokenEndpoint
# ...other service configs as needed...
Loading
Loading