pkg/connector/groups.go
224 linesgo
package connector

import (
	"context"
	"fmt"

	"example/baton-junction/pkg/client"

	v2 "github.com/conductorone/baton-sdk/pb/c1/connector/v2"
	"github.com/conductorone/baton-sdk/pkg/annotations"
	"github.com/conductorone/baton-sdk/pkg/types/entitlement"
	"github.com/conductorone/baton-sdk/pkg/types/grant"
	"github.com/conductorone/baton-sdk/pkg/types/resource"
	"github.com/grpc-ecosystem/go-grpc-middleware/logging/zap/ctxzap"
	"go.uber.org/zap"
)

// groupBuilder implements the group resource type for this connector: it
// lists and fetches groups, exposes a single member entitlement per group,
// syncs membership grants, and grants/revokes membership.
type groupBuilder struct {
	// client performs every group and membership API call made by this builder.
	client *client.Client
}

// newGroupBuilder returns a groupBuilder that issues its API calls through c.
func newGroupBuilder(c *client.Client) *groupBuilder {
	b := new(groupBuilder)
	b.client = c
	return b
}

// ResourceType reports the resource type this builder syncs
// (groupResourceType). The context is not consulted.
func (o *groupBuilder) ResourceType(context.Context) *v2.ResourceType {
	rt := groupResourceType
	return rt
}

// List returns one page of group resources.
//
// The incoming page token is decoded into a pagination bag (via
// parsePageToken) to obtain the cursor passed to client.GetGroups. The cursor
// the client hands back is pushed into the bag, and the marshaled bag is
// returned as the NextPageToken of the result.
//
// Errors from token parsing and resource construction are returned as-is;
// client and pagination errors are wrapped with a "baton-junction:" prefix.
func (o *groupBuilder) List(
	ctx context.Context,
	_ *v2.ResourceId,
	attrs resource.SyncOpAttrs,
) ([]*v2.Resource, *resource.SyncOpResults, error) {
	l := ctxzap.Extract(ctx)

	bag, cursor, err := parsePageToken(attrs.PageToken.Token, &v2.ResourceId{ResourceType: groupResourceType.Id})
	if err != nil {
		return nil, nil, err
	}

	groups, nextCursor, err := o.client.GetGroups(ctx, cursor)
	if err != nil {
		return nil, nil, fmt.Errorf("baton-junction: failed to list groups: %w", err)
	}

	l.Debug("listed groups", zap.Int("count", len(groups)), zap.String("next_cursor", nextCursor))

	// Exactly one resource per group, so size the slice up front. Indexing
	// into groups hands groupResource a pointer to the element itself instead
	// of to a per-iteration copy.
	resources := make([]*v2.Resource, 0, len(groups))
	for i := range groups {
		r, err := groupResource(&groups[i])
		if err != nil {
			return nil, nil, err
		}
		resources = append(resources, r)
	}

	if err := bag.Next(nextCursor); err != nil {
		return nil, nil, fmt.Errorf("baton-junction: failed to advance pagination: %w", err)
	}

	nextPageToken, err := bag.Marshal()
	if err != nil {
		return nil, nil, fmt.Errorf("baton-junction: failed to marshal pagination bag: %w", err)
	}

	return resources, &resource.SyncOpResults{NextPageToken: nextPageToken}, nil
}

// Entitlements exposes a single assignment entitlement, memberEntitlement,
// on the given group. It is grantable only to users, and its display name
// and description are derived from the group's display name.
// No pagination is performed, so the sync results are always nil.
func (o *groupBuilder) Entitlements(
	_ context.Context,
	res *v2.Resource,
	_ resource.SyncOpAttrs,
) ([]*v2.Entitlement, *resource.SyncOpResults, error) {
	displayName := fmt.Sprintf("%s Member", res.DisplayName)
	description := fmt.Sprintf("Member of the %s group", res.DisplayName)

	member := entitlement.NewAssignmentEntitlement(
		res,
		memberEntitlement,
		entitlement.WithGrantableTo(userResourceType),
		entitlement.WithDisplayName(displayName),
		entitlement.WithDescription(description),
	)

	out := []*v2.Entitlement{member}
	return out, nil, nil
}

// Grants returns one page of membership grants for the given group.
//
// Each member returned by client.GetGroupMembers becomes a grant of
// memberEntitlement on res to the user resource identified by the member's
// UserID. Pagination works like List: the page token is decoded into a bag
// scoped to this group, the client's next cursor is pushed into the bag, and
// the marshaled bag is returned as NextPageToken.
func (o *groupBuilder) Grants(
	ctx context.Context,
	res *v2.Resource,
	attrs resource.SyncOpAttrs,
) ([]*v2.Grant, *resource.SyncOpResults, error) {
	l := ctxzap.Extract(ctx)

	bag, cursor, err := parsePageToken(attrs.PageToken.Token, &v2.ResourceId{
		ResourceType: groupResourceType.Id,
		Resource:     res.Id.Resource,
	})
	if err != nil {
		return nil, nil, err
	}

	members, nextCursor, err := o.client.GetGroupMembers(ctx, res.Id.Resource, cursor)
	if err != nil {
		return nil, nil, fmt.Errorf("baton-junction: failed to list group members: %w", err)
	}

	// Log the cursor as List does, so paging through large groups is traceable.
	l.Debug("listed group members",
		zap.String("group_id", res.Id.Resource),
		zap.Int("count", len(members)),
		zap.String("next_cursor", nextCursor),
	)

	// Exactly one grant per member; size the slice up front.
	grants := make([]*v2.Grant, 0, len(members))
	for _, m := range members {
		g := grant.NewGrant(
			res,
			memberEntitlement,
			&v2.ResourceId{
				ResourceType: userResourceType.Id,
				Resource:     m.UserID,
			},
		)
		grants = append(grants, g)
	}

	if err := bag.Next(nextCursor); err != nil {
		return nil, nil, fmt.Errorf("baton-junction: failed to advance pagination: %w", err)
	}

	nextPageToken, err := bag.Marshal()
	if err != nil {
		return nil, nil, fmt.Errorf("baton-junction: failed to marshal pagination bag: %w", err)
	}

	return grants, &resource.SyncOpResults{NextPageToken: nextPageToken}, nil
}

// Grant adds principal as a member of the group that owns en.
//
// Only user principals are accepted; the member entitlement is declared
// grantable to userResourceType in Entitlements. Any other principal type is
// rejected before the API is called, so a non-user ID is never sent to
// AddGroupMember as a user ID. Client failures are wrapped with a
// "baton-junction:" prefix.
func (o *groupBuilder) Grant(
	ctx context.Context,
	principal *v2.Resource,
	en *v2.Entitlement,
) (annotations.Annotations, error) {
	l := ctxzap.Extract(ctx)

	if principal.Id.ResourceType != userResourceType.Id {
		return nil, fmt.Errorf(
			"baton-junction: only users can be granted group membership, got principal type %s",
			principal.Id.ResourceType,
		)
	}

	groupID := en.Resource.Id.Resource
	userID := principal.Id.Resource

	l.Info("granting group membership",
		zap.String("group_id", groupID),
		zap.String("user_id", userID),
	)

	if err := o.client.AddGroupMember(ctx, groupID, userID); err != nil {
		return nil, fmt.Errorf("baton-junction: failed to grant group membership: %w", err)
	}

	return nil, nil
}

// Revoke removes the grant's principal from the group that owns the grant's
// entitlement.
//
// Only user principals are accepted, mirroring Grant; anything else is
// rejected before the API is called. If the client reports the membership as
// not found (per isNotFound), the revoke is treated as already done. In that
// case it succeeds and returns a GrantAlreadyRevoked annotation so the caller
// can tell this no-op apart from an actual removal. Other client failures are
// wrapped with a "baton-junction:" prefix.
func (o *groupBuilder) Revoke(
	ctx context.Context,
	g *v2.Grant,
) (annotations.Annotations, error) {
	l := ctxzap.Extract(ctx)

	if g.Principal.Id.ResourceType != userResourceType.Id {
		return nil, fmt.Errorf(
			"baton-junction: only user group memberships can be revoked, got principal type %s",
			g.Principal.Id.ResourceType,
		)
	}

	groupID := g.Entitlement.Resource.Id.Resource
	userID := g.Principal.Id.Resource

	l.Info("revoking group membership",
		zap.String("group_id", groupID),
		zap.String("user_id", userID),
	)

	if err := o.client.RemoveGroupMember(ctx, groupID, userID); err != nil {
		if !isNotFound(err) {
			return nil, fmt.Errorf("baton-junction: failed to revoke group membership: %w", err)
		}
		// The membership is already gone: succeed, but tell the caller.
		l.Debug("group membership already removed", zap.String("group_id", groupID), zap.String("user_id", userID))
		return annotations.New(&v2.GrantAlreadyRevoked{}), nil
	}

	return nil, nil
}

// Get fetches a single group by its ID (resourceId.Resource) and converts it
// into a group resource with groupResource. The parent resource ID is
// ignored, and no annotations are returned. Client errors are wrapped with
// the group ID; conversion errors are returned unchanged.
func (o *groupBuilder) Get(
	ctx context.Context,
	resourceId *v2.ResourceId,
	_ *v2.ResourceId,
) (*v2.Resource, annotations.Annotations, error) {
	groupID := resourceId.Resource

	group, err := o.client.GetGroup(ctx, groupID)
	if err != nil {
		return nil, nil, fmt.Errorf("baton-junction: failed to get group %s: %w", groupID, err)
	}

	logger := ctxzap.Extract(ctx)
	logger.Debug("fetched group", zap.String("group_id", groupID))

	built, buildErr := groupResource(group)
	if buildErr != nil {
		return nil, nil, buildErr
	}
	return built, nil, nil
}

// groupResource converts a client.Group into a baton group resource.
//
// The group's ID is used as the object ID and recorded in a RawId annotation.
// Its name becomes the display name. Its description is set both as the
// resource description and as the "description" key of the group trait
// profile. Construction errors are wrapped with the group ID.
func groupResource(g *client.Group) (*v2.Resource, error) {
	profile := map[string]interface{}{
		"description": g.Description,
	}
	traitOpts := []resource.GroupTraitOption{
		resource.WithGroupProfile(profile),
	}

	res, buildErr := resource.NewGroupResource(
		g.Name,
		groupResourceType,
		g.ID,
		traitOpts,
		resource.WithAnnotation(&v2.RawId{Id: g.ID}),
		resource.WithDescription(g.Description),
	)
	if buildErr != nil {
		return nil, fmt.Errorf("baton-junction: failed to build group resource %s: %w", g.ID, buildErr)
	}
	return res, nil
}