package test
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"strings"
"testing"
"example/baton-junction/pkg/client"
cfg "example/baton-junction/pkg/config"
"example/baton-junction/pkg/connector"
v2 "github.com/conductorone/baton-sdk/pb/c1/connector/v2"
"github.com/conductorone/baton-sdk/pkg/actions"
"github.com/conductorone/baton-sdk/pkg/annotations"
"github.com/conductorone/baton-sdk/pkg/connectorbuilder"
"github.com/conductorone/baton-sdk/pkg/pagination"
"github.com/conductorone/baton-sdk/pkg/types/resource"
"google.golang.org/protobuf/types/known/structpb"
)
// writeJSON encodes v as JSON onto w and sets Content-Type to application/json.
// Because the encoder writes the body first, the response status is the
// implicit 200 OK.
//
// It is called from httptest handler goroutines rather than the test
// goroutine, so an encoding failure is reported with t.Errorf: t.Fatalf
// (FailNow) must only be called from the goroutine running the test.
func writeJSON(t *testing.T, w http.ResponseWriter, v any) {
	t.Helper()
	w.Header().Set("Content-Type", "application/json")
	if err := json.NewEncoder(w).Encode(v); err != nil {
		t.Errorf("failed to write JSON response: %v", err)
	}
}
// mockAPI returns an httptest.Server that simulates the full target app REST API,
// including OAuth token exchange, paginated list endpoints, member lookups,
// and provisioning (grant/revoke) actions. The server is closed via t.Cleanup.
//
// Routes (HTTP methods are only checked where noted):
//
//	/oauth/token               static "integration-test-token" bearer token
//	/api/users/me              user u1
//	/api/users                 POST creates user "u-created"; otherwise lists the
//	                           six users two per page via cursors "", "users-page2",
//	                           "users-page3" (any other cursor is a 400)
//	/api/users/{id}            PATCH updates the user in place; otherwise returns it
//	/api/roles, /api/groups    full, unpaginated lists
//	/api/roles/{id}            the role; .../members, .../groups and .../roles list
//	                           its direct members and assignments
//	/api/groups/{id}           the group; .../members lists its members
//	.../members/{userID}       PUT (grant) or DELETE (revoke), answered with 204;
//	                           any other method is a 405
//
// Grants and revokes are acknowledged but do not change the membership
// fixtures; only PATCH /api/users/{id} mutates state. That state is shared
// by handler goroutines without locking. NOTE(review): this assumes the
// connector does not issue concurrent PATCH requests.
func mockAPI(t *testing.T) *httptest.Server {
	t.Helper()
	users := []client.User{
		{ID: "u1", Email: "alice@appville.com", FirstName: "Alice", LastName: "Smith", Username: "alice", Status: "active"},
		{ID: "u2", Email: "bob@appville.com", FirstName: "Bob", LastName: "Jones", Username: "bob", Status: "active"},
		{ID: "u3", Email: "carol@appville.com", FirstName: "Carol", LastName: "White", Username: "carol", Status: "active"},
		{ID: "u4", Email: "dave@appville.com", FirstName: "Dave", LastName: "Brown", Username: "dave", Status: "active"},
		{ID: "u5", Email: "eve@appville.com", FirstName: "Eve", LastName: "Davis", Username: "eve", Status: "active"},
		{ID: "u6", Email: "frank@appville.com", FirstName: "Frank", LastName: "Miller", Username: "frank", Status: "active"},
	}
	roles := []client.Role{
		{ID: "r1", Name: "Single Ride", Description: "One-time journey on any standard route"},
		{ID: "r2", Name: "Standard Day", Description: "Unlimited standard travel for the day"},
		{ID: "r3", Name: "Dining Access", Description: "Access to the dining car service"},
		{ID: "r4", Name: "First Class", Description: "Premium seating with complimentary dining"},
		{ID: "r5", Name: "Express", Description: "Access to express routes with priority boarding"},
	}
	groups := []client.Group{
		{ID: "g1", Name: "Regional Pass", Description: "Monthly pass for standard daily travel"},
		{ID: "g2", Name: "Express Pass", Description: "Monthly pass for express routes"},
		{ID: "g3", Name: "All-Access Pass", Description: "Premium monthly pass with first-class travel and dining"},
	}
	// Direct user members of each role and group.
	roleMembers := map[string][]client.Member{
		"r1": {{UserID: "u6"}},
		"r2": {{UserID: "u4"}},
		"r3": {{UserID: "u2"}, {UserID: "u3"}, {UserID: "u4"}},
		"r4": {{UserID: "u5"}},
	}
	groupMembers := map[string][]client.Member{
		"g1": {{UserID: "u2"}},
		"g2": {{UserID: "u3"}},
		"g3": {{UserID: "u1"}},
	}
	// Groups and roles assigned to a role. The grant tests in this file
	// expect these to surface as expandable grants.
	roleGroups := map[string][]client.GroupAssignment{
		"r2": {{GroupID: "g1"}},
		"r4": {{GroupID: "g3"}},
		"r5": {{GroupID: "g2"}},
	}
	roleRoles := map[string][]client.RoleAssignment{
		"r3": {{RoleID: "r4"}},
	}

	// acceptMembershipChange answers a grant (PUT) or revoke (DELETE) on a
	// .../members/{userID} path with 204 No Content and leaves the fixtures
	// untouched. Any other verb gets 405, so a connector that sends the wrong
	// method fails loudly instead of receiving an empty 200.
	acceptMembershipChange := func(w http.ResponseWriter, r *http.Request) {
		switch r.Method {
		case http.MethodPut, http.MethodDelete:
			w.WriteHeader(http.StatusNoContent)
		default:
			w.Header().Set("Allow", "PUT, DELETE")
			http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
		}
	}

	mux := http.NewServeMux()
	mux.HandleFunc("/oauth/token", func(w http.ResponseWriter, r *http.Request) {
		writeJSON(t, w, client.TokenResponse{
			AccessToken: "integration-test-token",
			ExpiresIn:   3600,
			TokenType:   "Bearer",
		})
	})
	mux.HandleFunc("/api/users/me", func(w http.ResponseWriter, r *http.Request) {
		writeJSON(t, w, client.SingleResponse[client.User]{Data: users[0]})
	})
	mux.HandleFunc("/api/users", func(w http.ResponseWriter, r *http.Request) {
		if r.Method == http.MethodPost {
			var req client.CreateUserRequest
			if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
				http.Error(w, "bad request", http.StatusBadRequest)
				return
			}
			// The Content-Type header must be set before WriteHeader(201),
			// so this path encodes directly instead of calling writeJSON.
			w.Header().Set("Content-Type", "application/json")
			w.WriteHeader(http.StatusCreated)
			if err := json.NewEncoder(w).Encode(client.SingleResponse[client.User]{
				Data: client.User{
					ID: "u-created", Email: req.Email,
					FirstName: req.FirstName, LastName: req.LastName,
					Username: req.Username, Status: "active",
				},
			}); err != nil {
				t.Errorf("failed to encode create user response: %v", err)
			}
			return
		}
		switch r.URL.Query().Get("cursor") {
		case "":
			writeJSON(t, w, client.ListResponse[client.User]{Data: users[:2], NextCursor: "users-page2"})
		case "users-page2":
			writeJSON(t, w, client.ListResponse[client.User]{Data: users[2:4], NextCursor: "users-page3"})
		case "users-page3":
			writeJSON(t, w, client.ListResponse[client.User]{Data: users[4:], NextCursor: ""})
		default:
			// The connector forwarded a cursor this mock never issued. Fail
			// loudly instead of returning an empty 200 body.
			http.Error(w, "unknown cursor", http.StatusBadRequest)
		}
	})
	mux.HandleFunc("/api/roles", func(w http.ResponseWriter, r *http.Request) {
		writeJSON(t, w, client.ListResponse[client.Role]{Data: roles})
	})
	mux.HandleFunc("/api/groups", func(w http.ResponseWriter, r *http.Request) {
		writeJSON(t, w, client.ListResponse[client.Group]{Data: groups})
	})
	mux.HandleFunc("/api/users/", func(w http.ResponseWriter, r *http.Request) {
		// "/api/users/me" has its own, more specific handler above, so only
		// real user IDs (or an empty one) reach this point.
		id := strings.TrimPrefix(r.URL.Path, "/api/users/")
		if id == "" {
			http.NotFound(w, r)
			return
		}
		if r.Method == http.MethodPatch {
			var attrs map[string]string
			if err := json.NewDecoder(r.Body).Decode(&attrs); err != nil {
				http.Error(w, "bad request", http.StatusBadRequest)
				return
			}
			for i, u := range users {
				if u.ID != id {
					continue
				}
				if v, ok := attrs["status"]; ok {
					users[i].Status = v
				}
				if v, ok := attrs["first_name"]; ok {
					users[i].FirstName = v
				}
				if v, ok := attrs["last_name"]; ok {
					users[i].LastName = v
				}
				if v, ok := attrs["email"]; ok {
					users[i].Email = v
				}
				if v, ok := attrs["department"]; ok {
					users[i].Department = v
				}
				writeJSON(t, w, client.SingleResponse[client.User]{Data: users[i]})
				return
			}
			http.NotFound(w, r)
			return
		}
		for _, u := range users {
			if u.ID == id {
				writeJSON(t, w, client.SingleResponse[client.User]{Data: u})
				return
			}
		}
		http.NotFound(w, r)
	})
	mux.HandleFunc("/api/roles/", func(w http.ResponseWriter, r *http.Request) {
		parts := strings.Split(strings.TrimPrefix(r.URL.Path, "/api/roles/"), "/")
		roleID := parts[0]
		switch {
		case len(parts) == 3 && parts[1] == "members":
			acceptMembershipChange(w, r)
		case len(parts) == 2 && parts[1] == "members":
			writeJSON(t, w, client.ListResponse[client.Member]{Data: roleMembers[roleID]})
		case len(parts) == 2 && parts[1] == "groups":
			writeJSON(t, w, client.ListResponse[client.GroupAssignment]{Data: roleGroups[roleID]})
		case len(parts) == 2 && parts[1] == "roles":
			writeJSON(t, w, client.ListResponse[client.RoleAssignment]{Data: roleRoles[roleID]})
		case len(parts) == 1:
			for _, rl := range roles {
				if rl.ID == roleID {
					writeJSON(t, w, client.SingleResponse[client.Role]{Data: rl})
					return
				}
			}
			http.NotFound(w, r)
		default:
			http.NotFound(w, r)
		}
	})
	mux.HandleFunc("/api/groups/", func(w http.ResponseWriter, r *http.Request) {
		parts := strings.Split(strings.TrimPrefix(r.URL.Path, "/api/groups/"), "/")
		groupID := parts[0]
		switch {
		case len(parts) == 3 && parts[1] == "members":
			acceptMembershipChange(w, r)
		case len(parts) == 2 && parts[1] == "members":
			writeJSON(t, w, client.ListResponse[client.Member]{Data: groupMembers[groupID]})
		case len(parts) == 1:
			for _, g := range groups {
				if g.ID == groupID {
					writeJSON(t, w, client.SingleResponse[client.Group]{Data: g})
					return
				}
			}
			http.NotFound(w, r)
		default:
			http.NotFound(w, r)
		}
	})
	ts := httptest.NewServer(mux)
	t.Cleanup(ts.Close)
	return ts
}
// TestIntegration_FullSync runs every resource syncer against the mock API
// the way a full sync would. It lists resources, following pagination for
// users, then fetches entitlements and grants for sample resources. It also
// pins the order of ResourceSyncers: user, group, role.
func TestIntegration_FullSync(t *testing.T) {
	ts := mockAPI(t)
	ctx := context.Background()
	cb, _, err := connector.New(ctx, &cfg.App{AppClientId: "test-id", AppClientSecret: "test-secret", BaseUrl: ts.URL}, nil)
	if err != nil {
		t.Fatalf("failed to create connector: %v", err)
	}
	syncers := cb.ResourceSyncers(ctx)
	if len(syncers) != 3 {
		t.Fatalf("expected 3 syncers, got %d", len(syncers))
	}
	t.Run("users sync", func(t *testing.T) {
		userSyncer := syncers[0]
		if userSyncer.ResourceType(ctx).Id != "user" {
			t.Fatalf("expected user syncer first, got %s", userSyncer.ResourceType(ctx).Id)
		}
		// The mock serves six users over three pages. Cap the page count so a
		// syncer that never clears its next-page token fails the test instead
		// of looping forever.
		const maxPages = 10
		var allResources []*v2.Resource
		pageToken := ""
		for page := 1; ; page++ {
			if page > maxPages {
				t.Fatalf("user pagination did not finish within %d pages", maxPages)
			}
			resources, results, err := userSyncer.List(ctx, nil, resource.SyncOpAttrs{PageToken: pagination.Token{Token: pageToken}})
			if err != nil {
				t.Fatalf("failed to list users: %v", err)
			}
			allResources = append(allResources, resources...)
			if results == nil || results.NextPageToken == "" {
				break
			}
			pageToken = results.NextPageToken
		}
		if len(allResources) != 6 {
			t.Fatalf("expected 6 users, got %d", len(allResources))
		}
		if allResources[0].DisplayName != "Alice Smith" {
			t.Errorf("expected first user Alice Smith, got %s", allResources[0].DisplayName)
		}
	})
	t.Run("roles sync", func(t *testing.T) {
		roleSyncer := syncers[2]
		if roleSyncer.ResourceType(ctx).Id != "role" {
			t.Fatalf("expected role syncer, got %s", roleSyncer.ResourceType(ctx).Id)
		}
		resources, _, err := roleSyncer.List(ctx, nil, resource.SyncOpAttrs{})
		if err != nil {
			t.Fatalf("failed to list roles: %v", err)
		}
		if len(resources) != 5 {
			t.Fatalf("expected 5 roles, got %d", len(resources))
		}
		ents, _, err := roleSyncer.Entitlements(ctx, resources[0], resource.SyncOpAttrs{})
		if err != nil {
			t.Fatalf("failed to get entitlements: %v", err)
		}
		if len(ents) != 1 {
			t.Fatalf("expected 1 entitlement per role, got %d", len(ents))
		}
		// r1 (Single Ride) has 1 direct user grant (u6).
		grants, _, err := roleSyncer.Grants(ctx, resources[0], resource.SyncOpAttrs{})
		if err != nil {
			t.Fatalf("failed to get grants: %v", err)
		}
		if len(grants) != 1 {
			t.Fatalf("expected 1 grant for Single Ride role, got %d", len(grants))
		}
		// Protobuf getters are nil-safe, so a grant missing its principal
		// fails the assertion instead of panicking.
		if got := grants[0].GetPrincipal().GetId().GetResource(); got != "u6" {
			t.Errorf("expected grant principal u6, got %s", got)
		}
		// r2 (Standard Day) has 1 direct user + 1 expandable group = 2 grants.
		grantsR2, _, err := roleSyncer.Grants(ctx, resources[1], resource.SyncOpAttrs{})
		if err != nil {
			t.Fatalf("failed to get Standard Day grants: %v", err)
		}
		if len(grantsR2) != 2 {
			t.Fatalf("expected 2 grants for Standard Day (1 user + 1 group), got %d", len(grantsR2))
		}
		// r3 (Dining Access) has 3 direct users + 1 expandable role = 4 grants.
		grantsR3, _, err := roleSyncer.Grants(ctx, resources[2], resource.SyncOpAttrs{})
		if err != nil {
			t.Fatalf("failed to get Dining Access grants: %v", err)
		}
		if len(grantsR3) != 4 {
			t.Fatalf("expected 4 grants for Dining Access (3 users + 1 role), got %d", len(grantsR3))
		}
	})
	t.Run("groups sync", func(t *testing.T) {
		groupSyncer := syncers[1]
		if groupSyncer.ResourceType(ctx).Id != "group" {
			t.Fatalf("expected group syncer, got %s", groupSyncer.ResourceType(ctx).Id)
		}
		resources, _, err := groupSyncer.List(ctx, nil, resource.SyncOpAttrs{})
		if err != nil {
			t.Fatalf("failed to list groups: %v", err)
		}
		if len(resources) != 3 {
			t.Fatalf("expected 3 groups, got %d", len(resources))
		}
		ents, _, err := groupSyncer.Entitlements(ctx, resources[0], resource.SyncOpAttrs{})
		if err != nil {
			t.Fatalf("failed to get entitlements: %v", err)
		}
		if len(ents) != 1 {
			t.Fatalf("expected 1 entitlement per group, got %d", len(ents))
		}
		// g1 (Regional Pass) has 1 member (u2).
		grants, _, err := groupSyncer.Grants(ctx, resources[0], resource.SyncOpAttrs{})
		if err != nil {
			t.Fatalf("failed to get grants: %v", err)
		}
		if len(grants) != 1 {
			t.Fatalf("expected 1 grant for Regional Pass group, got %d", len(grants))
		}
		if got := grants[0].GetPrincipal().GetId().GetResource(); got != "u2" {
			t.Errorf("expected grant principal u2, got %s", got)
		}
	})
}
// TestIntegration_Validate builds a connector pointed at the mock API and
// requires its Validate call to return no error.
func TestIntegration_Validate(t *testing.T) {
	ctx := context.Background()
	server := mockAPI(t)
	appConfig := &cfg.App{AppClientId: "test-id", AppClientSecret: "test-secret", BaseUrl: server.URL}

	conn, _, err := connector.New(ctx, appConfig, nil)
	if err != nil {
		t.Fatalf("failed to create connector: %v", err)
	}

	_, err = conn.Validate(ctx)
	if err != nil {
		t.Fatalf("validation failed: %v", err)
	}
}
func TestIntegration_Metadata(t *testing.T) {
ts := mockAPI(t)
ctx := context.Background()
cb, _, err := connector.New(ctx, &cfg.App{AppClientId: "test-id", AppClientSecret: "test-secret", BaseUrl: ts.URL}, nil)
if err != nil {
t.Fatalf("failed to create connector: %v", err)
}
md, err := cb.Metadata(ctx)
if err != nil {
t.Fatalf("metadata failed: %v", err)
}
if md.DisplayName != "Baton App" {
t.Errorf("expected display name Baton App, got %s", md.DisplayName)
}
}
// TestIntegration_GroupProvisioning uses the group syncer's Grant and Revoke
// methods on the "member" entitlement of group g1 (Regional Pass). The
// principal is user u-new, which is not among the mock's user fixtures. Both
// calls must return no error.
func TestIntegration_GroupProvisioning(t *testing.T) {
	ctx := context.Background()
	server := mockAPI(t)
	conn, _, err := connector.New(ctx, &cfg.App{AppClientId: "test-id", AppClientSecret: "test-secret", BaseUrl: server.URL}, nil)
	if err != nil {
		t.Fatalf("failed to create connector: %v", err)
	}

	// Grant/Revoke are found by a structural type assertion, so the test is
	// not tied to the concrete builder type.
	type provisioner interface {
		Grant(ctx context.Context, principal *v2.Resource, entitlement *v2.Entitlement) (annotations.Annotations, error)
		Revoke(ctx context.Context, grant *v2.Grant) (annotations.Annotations, error)
	}
	prov, ok := conn.ResourceSyncers(ctx)[1].(provisioner)
	if !ok {
		t.Fatal("group syncer does not implement provisioner interface")
	}

	principal := &v2.Resource{
		Id:          &v2.ResourceId{ResourceType: "user", Resource: "u-new"},
		DisplayName: "New User",
	}
	membership := &v2.Entitlement{
		Resource: &v2.Resource{
			Id:          &v2.ResourceId{ResourceType: "group", Resource: "g1"},
			DisplayName: "Regional Pass",
		},
		Slug: "member",
	}

	t.Run("grant membership", func(t *testing.T) {
		_, grantErr := prov.Grant(ctx, principal, membership)
		if grantErr != nil {
			t.Fatalf("grant failed: %v", grantErr)
		}
	})
	t.Run("revoke membership", func(t *testing.T) {
		_, revokeErr := prov.Revoke(ctx, &v2.Grant{Entitlement: membership, Principal: principal})
		if revokeErr != nil {
			t.Fatalf("revoke failed: %v", revokeErr)
		}
	})
}
// TestIntegration_RoleProvisioning uses the role syncer's Grant and Revoke
// methods to assign, then remove, role r1 (Single Ride) for user u2
// (Bob Jones). Both calls must return no error.
func TestIntegration_RoleProvisioning(t *testing.T) {
	ctx := context.Background()
	server := mockAPI(t)
	conn, _, err := connector.New(ctx, &cfg.App{AppClientId: "test-id", AppClientSecret: "test-secret", BaseUrl: server.URL}, nil)
	if err != nil {
		t.Fatalf("failed to create connector: %v", err)
	}

	// Grant/Revoke are found by a structural type assertion, so the test is
	// not tied to the concrete builder type.
	type provisioner interface {
		Grant(ctx context.Context, principal *v2.Resource, entitlement *v2.Entitlement) (annotations.Annotations, error)
		Revoke(ctx context.Context, grant *v2.Grant) (annotations.Annotations, error)
	}
	prov, ok := conn.ResourceSyncers(ctx)[2].(provisioner)
	if !ok {
		t.Fatal("role syncer does not implement provisioner interface")
	}

	bob := &v2.Resource{
		Id:          &v2.ResourceId{ResourceType: "user", Resource: "u2"},
		DisplayName: "Bob Jones",
	}
	singleRide := &v2.Entitlement{
		Resource: &v2.Resource{
			Id:          &v2.ResourceId{ResourceType: "role", Resource: "r1"},
			DisplayName: "Single Ride",
		},
		Slug: "member",
	}

	t.Run("grant role", func(t *testing.T) {
		_, grantErr := prov.Grant(ctx, bob, singleRide)
		if grantErr != nil {
			t.Fatalf("grant failed: %v", grantErr)
		}
	})
	t.Run("revoke role", func(t *testing.T) {
		_, revokeErr := prov.Revoke(ctx, &v2.Grant{Entitlement: singleRide, Principal: bob})
		if revokeErr != nil {
			t.Fatalf("revoke failed: %v", revokeErr)
		}
	})
}
// TestIntegration_PaginatedUserSync walks the user syncer one page at a time.
// It requires two users per page and a next-page token after the first and
// second pages, and no token after the third. Six users must be seen in total.
func TestIntegration_PaginatedUserSync(t *testing.T) {
	ctx := context.Background()
	server := mockAPI(t)
	conn, _, err := connector.New(ctx, &cfg.App{AppClientId: "test-id", AppClientSecret: "test-secret", BaseUrl: server.URL}, nil)
	if err != nil {
		t.Fatalf("failed to create connector: %v", err)
	}
	users := conn.ResourceSyncers(ctx)[0]

	// fetch lists the page identified by token (empty for the first page).
	// next is "" when the syncer reports no further page, including when it
	// returns nil results.
	fetch := func(token string) ([]*v2.Resource, string, error) {
		page, results, listErr := users.List(ctx, nil, resource.SyncOpAttrs{PageToken: pagination.Token{Token: token}})
		next := ""
		if results != nil {
			next = results.NextPageToken
		}
		return page, next, listErr
	}

	first, next, err := fetch("")
	if err != nil {
		t.Fatalf("failed to list first page: %v", err)
	}
	if len(first) != 2 {
		t.Fatalf("expected 2 users on first page, got %d", len(first))
	}
	if next == "" {
		t.Fatal("expected non-empty next token after first page")
	}

	second, next, err := fetch(next)
	if err != nil {
		t.Fatalf("failed to list second page: %v", err)
	}
	if len(second) != 2 {
		t.Fatalf("expected 2 users on second page, got %d", len(second))
	}
	if next == "" {
		t.Fatal("expected non-empty next token after second page")
	}

	third, next, err := fetch(next)
	if err != nil {
		t.Fatalf("failed to list third page: %v", err)
	}
	if len(third) != 2 {
		t.Fatalf("expected 2 users on third page, got %d", len(third))
	}
	if next != "" {
		t.Errorf("expected empty next token after last page, got %s", next)
	}

	total := len(first) + len(second) + len(third)
	if total != 6 {
		t.Errorf("expected 6 total users across pages, got %d", total)
	}
}
// TestIntegration_TargetedSync fetches a single user (u1), group (g1) and role
// (r1) by ID through each syncer's Get method. Each result must carry the
// expected display name.
func TestIntegration_TargetedSync(t *testing.T) {
	ctx := context.Background()
	server := mockAPI(t)
	conn, _, err := connector.New(ctx, &cfg.App{AppClientId: "test-id", AppClientSecret: "test-secret", BaseUrl: server.URL}, nil)
	if err != nil {
		t.Fatalf("failed to create connector: %v", err)
	}
	syncers := conn.ResourceSyncers(ctx)

	type targetedSyncer interface {
		Get(ctx context.Context, resourceId *v2.ResourceId, parentResourceId *v2.ResourceId) (*v2.Resource, annotations.Annotations, error)
	}
	// asTargeted type-asserts syncers[i] to targetedSyncer and fails the
	// (sub)test with msg when the builder has no Get method.
	asTargeted := func(t *testing.T, i int, msg string) targetedSyncer {
		t.Helper()
		tgt, ok := syncers[i].(targetedSyncer)
		if !ok {
			t.Fatal(msg)
		}
		return tgt
	}

	t.Run("get user by ID", func(t *testing.T) {
		tgt := asTargeted(t, 0, "user builder does not implement ResourceTargetedSyncer")
		got, _, getErr := tgt.Get(ctx, &v2.ResourceId{ResourceType: "user", Resource: "u1"}, nil)
		if getErr != nil {
			t.Fatalf("failed to get user: %v", getErr)
		}
		if got.DisplayName != "Alice Smith" {
			t.Errorf("expected Alice Smith, got %s", got.DisplayName)
		}
		if got.Id.Resource != "u1" {
			t.Errorf("expected u1, got %s", got.Id.Resource)
		}
	})
	t.Run("get group by ID", func(t *testing.T) {
		tgt := asTargeted(t, 1, "group builder does not implement ResourceTargetedSyncer")
		got, _, getErr := tgt.Get(ctx, &v2.ResourceId{ResourceType: "group", Resource: "g1"}, nil)
		if getErr != nil {
			t.Fatalf("failed to get group: %v", getErr)
		}
		if got.DisplayName != "Regional Pass" {
			t.Errorf("expected Regional Pass, got %s", got.DisplayName)
		}
	})
	t.Run("get role by ID", func(t *testing.T) {
		tgt := asTargeted(t, 2, "role builder does not implement ResourceTargetedSyncer")
		got, _, getErr := tgt.Get(ctx, &v2.ResourceId{ResourceType: "role", Resource: "r1"}, nil)
		if getErr != nil {
			t.Fatalf("failed to get role: %v", getErr)
		}
		if got.DisplayName != "Single Ride" {
			t.Errorf("expected Single Ride, got %s", got.DisplayName)
		}
	})
}
// TestIntegration_CreateAccount checks the user builder's account
// provisioning. The advertised capability details must name a preferred
// credential option and at least one supported option. Creating an account
// against the mock API must succeed with a non-nil result.
func TestIntegration_CreateAccount(t *testing.T) {
	ts := mockAPI(t)
	ctx := context.Background()
	cb, _, err := connector.New(ctx, &cfg.App{AppClientId: "test-id", AppClientSecret: "test-secret", BaseUrl: ts.URL}, nil)
	if err != nil {
		t.Fatalf("failed to create connector: %v", err)
	}
	syncers := cb.ResourceSyncers(ctx)
	userSyncer := syncers[0]
	type accountManager interface {
		CreateAccountCapabilityDetails(ctx context.Context) (*v2.CredentialDetailsAccountProvisioning, annotations.Annotations, error)
		CreateAccount(ctx context.Context, accountInfo *v2.AccountInfo, credentialOptions *v2.LocalCredentialOptions) (connectorbuilder.CreateAccountResponse, []*v2.PlaintextData, annotations.Annotations, error)
	}
	am, ok := userSyncer.(accountManager)
	if !ok {
		t.Fatal("user builder does not implement AccountManagerLimited")
	}
	t.Run("capability details", func(t *testing.T) {
		details, _, err := am.CreateAccountCapabilityDetails(ctx)
		if err != nil {
			t.Fatalf("unexpected error: %v", err)
		}
		// Fail clearly rather than panicking on the field accesses below.
		if details == nil {
			t.Fatal("expected non-nil capability details")
		}
		if details.PreferredCredentialOption == v2.CapabilityDetailCredentialOption_CAPABILITY_DETAIL_CREDENTIAL_OPTION_UNSPECIFIED {
			t.Error("preferred credential option should not be unspecified")
		}
		if len(details.SupportedCredentialOptions) == 0 {
			t.Error("expected at least one supported credential option")
		}
	})
	t.Run("create account", func(t *testing.T) {
		profile, err := structpb.NewStruct(map[string]any{
			"first_name": "New",
			"last_name":  "User",
		})
		if err != nil {
			t.Fatalf("failed to build account profile: %v", err)
		}
		accountInfo := &v2.AccountInfo{
			Login:   "newuser",
			Emails:  []*v2.AccountInfo_Email{{Address: "new@test.com"}},
			Profile: profile,
		}
		result, _, _, err := am.CreateAccount(ctx, accountInfo, nil)
		if err != nil {
			t.Fatalf("create account failed: %v", err)
		}
		if result == nil {
			t.Fatal("expected non-nil result")
		}
	})
}
// TestIntegration_GlobalActions registers the connector's global actions
// against a fake registry. It captures the "enable_user" and "disable_user"
// handlers and invokes each with a resource_id argument. Each handler must
// return no error and report success=true.
func TestIntegration_GlobalActions(t *testing.T) {
	ts := mockAPI(t)
	ctx := context.Background()
	cb, _, err := connector.New(ctx, &cfg.App{AppClientId: "test-id", AppClientSecret: "test-secret", BaseUrl: ts.URL}, nil)
	if err != nil {
		t.Fatalf("failed to create connector: %v", err)
	}
	type globalActionProvider interface {
		GlobalActions(ctx context.Context, registry actions.ActionRegistry) error
	}
	gap, ok := cb.(globalActionProvider)
	if !ok {
		t.Fatal("connector does not implement GlobalActionProvider")
	}

	// captureHandler runs GlobalActions against a fake registry and returns
	// the handler registered under name. It fails the test if no non-nil
	// handler was registered for that name.
	captureHandler := func(t *testing.T, name string) actions.ActionHandler {
		t.Helper()
		var handler actions.ActionHandler
		registry := &testActionRegistry{
			registerFn: func(_ context.Context, schema *v2.BatonActionSchema, h actions.ActionHandler) error {
				if schema.GetName() == name {
					handler = h
				}
				return nil
			},
		}
		if err := gap.GlobalActions(ctx, registry); err != nil {
			t.Fatalf("GlobalActions registration failed: %v", err)
		}
		if handler == nil {
			t.Fatalf("%s action was not registered", name)
		}
		return handler
	}

	// invokeForUser calls handler with {"resource_id": userID} and requires
	// a non-nil result whose "success" field is true.
	invokeForUser := func(t *testing.T, handler actions.ActionHandler, name, userID string) {
		t.Helper()
		args, err := structpb.NewStruct(map[string]any{"resource_id": userID})
		if err != nil {
			t.Fatalf("failed to build %s arguments: %v", name, err)
		}
		result, _, err := handler(ctx, args)
		if err != nil {
			t.Fatalf("%s action failed: %v", name, err)
		}
		if result == nil {
			t.Fatalf("%s action returned a nil result", name)
		}
		if !result.GetFields()["success"].GetBoolValue() {
			t.Error("expected success=true")
		}
	}

	t.Run("enable user", func(t *testing.T) {
		invokeForUser(t, captureHandler(t, "enable_user"), "enable_user", "u3")
	})
	t.Run("disable user", func(t *testing.T) {
		invokeForUser(t, captureHandler(t, "disable_user"), "disable_user", "u1")
	})
}
// TestIntegration_ResourceActions registers the user builder's resource
// actions against a fake registry and invokes the "update_profile" handler
// for user u1. The result must report success and echo the updated
// department and first_name.
func TestIntegration_ResourceActions(t *testing.T) {
	ts := mockAPI(t)
	ctx := context.Background()
	cb, _, err := connector.New(ctx, &cfg.App{AppClientId: "test-id", AppClientSecret: "test-secret", BaseUrl: ts.URL}, nil)
	if err != nil {
		t.Fatalf("failed to create connector: %v", err)
	}
	syncers := cb.ResourceSyncers(ctx)
	userSyncer := syncers[0]
	type resourceActionProvider interface {
		ResourceActions(ctx context.Context, registry actions.ActionRegistry) error
	}
	rap, ok := userSyncer.(resourceActionProvider)
	if !ok {
		t.Fatal("user builder does not implement ResourceActionProvider")
	}
	t.Run("update user profile", func(t *testing.T) {
		var capturedHandler actions.ActionHandler
		mockRegistry := &testActionRegistry{
			registerFn: func(_ context.Context, schema *v2.BatonActionSchema, handler actions.ActionHandler) error {
				if schema.GetName() == "update_profile" {
					capturedHandler = handler
				}
				return nil
			},
		}
		if err := rap.ResourceActions(ctx, mockRegistry); err != nil {
			t.Fatalf("ResourceActions registration failed: %v", err)
		}
		if capturedHandler == nil {
			t.Fatal("update_profile action was not registered")
		}
		// The target user is passed as a nested struct holding its resource
		// type and ID.
		userIDStruct, err := structpb.NewStruct(map[string]any{
			"resource_type_id": "user",
			"resource_id":      "u1",
		})
		if err != nil {
			t.Fatalf("failed to build user_id argument: %v", err)
		}
		args := &structpb.Struct{
			Fields: map[string]*structpb.Value{
				"user_id":    structpb.NewStructValue(userIDStruct),
				"department": structpb.NewStringValue("Engineering"),
				"first_name": structpb.NewStringValue("Alicia"),
			},
		}
		result, _, err := capturedHandler(ctx, args)
		if err != nil {
			t.Fatalf("update_profile action failed: %v", err)
		}
		if result == nil {
			t.Fatal("update_profile action returned a nil result")
		}
		fields := result.GetFields()
		if !fields["success"].GetBoolValue() {
			t.Error("expected success=true")
		}
		if got := fields["department"].GetStringValue(); got != "Engineering" {
			t.Errorf("expected department Engineering, got %s", got)
		}
		if got := fields["first_name"].GetStringValue(); got != "Alicia" {
			t.Errorf("expected first_name Alicia, got %s", got)
		}
	})
}
// TestIntegration_GrantExpansion checks grants whose principal is a group or a
// role rather than a user. Role r5 (Express) must yield a single grant to
// group g2, and role r3 (Dining Access) must include a grant to role r4.
// Each such grant must carry a deep (Shallow=false) GrantExpandable
// annotation naming the principal's "member" entitlement.
func TestIntegration_GrantExpansion(t *testing.T) {
	ctx := context.Background()
	server := mockAPI(t)
	conn, _, err := connector.New(ctx, &cfg.App{AppClientId: "test-id", AppClientSecret: "test-secret", BaseUrl: server.URL}, nil)
	if err != nil {
		t.Fatalf("failed to create connector: %v", err)
	}
	roleSyncer := conn.ResourceSyncers(ctx)[2]
	roleList, _, err := roleSyncer.List(ctx, nil, resource.SyncOpAttrs{})
	if err != nil {
		t.Fatalf("failed to list roles: %v", err)
	}

	// Index roles by ID, keeping the first occurrence of any duplicate ID.
	// A missing ID looks up as nil.
	byID := make(map[string]*v2.Resource, len(roleList))
	for _, rl := range roleList {
		if _, seen := byID[rl.Id.Resource]; !seen {
			byID[rl.Id.Resource] = rl
		}
	}

	// expandableOf extracts the GrantExpandable annotation from g, failing
	// with missingMsg when the annotation is absent.
	expandableOf := func(t *testing.T, g *v2.Grant, missingMsg string) *v2.GrantExpandable {
		t.Helper()
		expandable := &v2.GrantExpandable{}
		annos := annotations.Annotations(g.Annotations)
		if !annos.Contains(expandable) {
			t.Fatal(missingMsg)
		}
		if _, pickErr := annos.Pick(expandable); pickErr != nil {
			t.Fatalf("failed to pick GrantExpandable: %v", pickErr)
		}
		return expandable
	}

	t.Run("group-principal grant on Express", func(t *testing.T) {
		express := byID["r5"]
		if express == nil {
			t.Fatal("Express role (r5) not found")
		}
		grants, _, grantsErr := roleSyncer.Grants(ctx, express, resource.SyncOpAttrs{})
		if grantsErr != nil {
			t.Fatalf("failed to get Express grants: %v", grantsErr)
		}
		if len(grants) != 1 {
			t.Fatalf("expected 1 grant for Express (group only), got %d", len(grants))
		}
		principal := grants[0].Principal.Id
		if principal.ResourceType != "group" {
			t.Errorf("expected group principal, got %s", principal.ResourceType)
		}
		if principal.Resource != "g2" {
			t.Errorf("expected principal g2, got %s", principal.Resource)
		}
		exp := expandableOf(t, grants[0], "expected GrantExpandable annotation")
		if ids := exp.EntitlementIds; len(ids) != 1 || ids[0] != "group:g2:member" {
			t.Errorf("expected EntitlementIds [group:g2:member], got %v", ids)
		}
		if exp.Shallow {
			t.Error("expected deep expansion (Shallow=false)")
		}
	})
	t.Run("role-principal grant on Dining Access", func(t *testing.T) {
		dining := byID["r3"]
		if dining == nil {
			t.Fatal("Dining Access role (r3) not found")
		}
		grants, _, grantsErr := roleSyncer.Grants(ctx, dining, resource.SyncOpAttrs{})
		if grantsErr != nil {
			t.Fatalf("failed to get Dining Access grants: %v", grantsErr)
		}
		// Take the first grant whose principal is a role.
		var roleGrant *v2.Grant
		for i := 0; i < len(grants) && roleGrant == nil; i++ {
			if grants[i].Principal.Id.ResourceType == "role" {
				roleGrant = grants[i]
			}
		}
		if roleGrant == nil {
			t.Fatal("expected a role-principal grant on Dining Access")
		}
		if got := roleGrant.Principal.Id.Resource; got != "r4" {
			t.Errorf("expected principal r4 (First Class), got %s", got)
		}
		exp := expandableOf(t, roleGrant, "expected GrantExpandable annotation on role-principal grant")
		if ids := exp.EntitlementIds; len(ids) != 1 || ids[0] != "role:r4:member" {
			t.Errorf("expected EntitlementIds [role:r4:member], got %v", ids)
		}
		if exp.Shallow {
			t.Error("expected deep expansion (Shallow=false)")
		}
	})
}
// testActionRegistry is a fake actions.ActionRegistry that forwards every
// registration to registerFn, letting a test capture the schema and handler
// of the actions it cares about. A nil registerFn accepts and discards every
// registration, so the zero value is usable.
type testActionRegistry struct {
	registerFn func(ctx context.Context, schema *v2.BatonActionSchema, handler actions.ActionHandler) error
}

// Compile-time check that the fake satisfies the registry interface accepted
// by the GlobalActions/ResourceActions methods under test.
var _ actions.ActionRegistry = (*testActionRegistry)(nil)

// Register forwards schema and handler to registerFn and returns its error,
// or returns nil when registerFn is unset.
func (r *testActionRegistry) Register(ctx context.Context, schema *v2.BatonActionSchema, handler actions.ActionHandler) error {
	if r.registerFn == nil {
		return nil
	}
	return r.registerFn(ctx, schema, handler)
}

// RegisterAction drops its string argument and delegates to Register.
// NOTE(review): the dropped argument is presumably the action name; tests
// match on schema.GetName() instead.
func (r *testActionRegistry) RegisterAction(ctx context.Context, _ string, schema *v2.BatonActionSchema, handler actions.ActionHandler) error {
	return r.Register(ctx, schema, handler)
}