pkg/client/client.go
253 linesgo
package client

import (
	"context"
	"fmt"
	"net/http"
	"net/url"
	"strconv"

	"github.com/conductorone/baton-sdk/pkg/uhttp"
)

const (
	// defaultPageSize is sent as the "limit" query parameter on every GET
	// request issued by doGet.
	defaultPageSize = 100
)

// Client talks to the Junction REST API on behalf of the baton-junction
// connector. Construct it with New; the zero value has no HTTP client.
type Client struct {
	// httpClient performs the OAuth2-authenticated requests (built in New).
	httpClient *uhttp.BaseHttpClient

	// baseURL is the prefix joined with every request path such as "/api/users".
	baseURL string
}

// New builds a Client for the Junction API rooted at baseURL. Requests are
// authenticated with OAuth2 client credentials (clientID/clientSecret)
// exchanged at baseURL + "/oauth/token".
//
// baseURL must be an absolute URL with a scheme and host, for example
// "https://api.example.com". Trailing slashes are stripped, because every
// request path the client builds already starts with "/". Without this,
// "https://host/" would produce "https://host//api/...".
func New(ctx context.Context, clientID, clientSecret, baseURL string) (*Client, error) {
	parsedBase, err := url.Parse(baseURL)
	if err != nil {
		return nil, fmt.Errorf("baton-junction: invalid base URL: %w", err)
	}
	// Fail fast on values like "api.example.com". Otherwise they parse as a
	// bare path and only fail later with an obscure transport error.
	if parsedBase.Scheme == "" || parsedBase.Host == "" {
		return nil, fmt.Errorf("baton-junction: invalid base URL %q: scheme and host are required", baseURL)
	}

	// Scheme and host are non-empty here, so this loop cannot empty the string.
	for len(baseURL) > 0 && baseURL[len(baseURL)-1] == '/' {
		baseURL = baseURL[:len(baseURL)-1]
	}

	tokenURL, err := url.Parse(baseURL + "/oauth/token")
	if err != nil {
		return nil, fmt.Errorf("baton-junction: invalid base URL: %w", err)
	}

	creds := uhttp.NewOAuth2ClientCredentials(clientID, clientSecret, tokenURL, nil)
	httpClient, err := creds.GetClient(ctx)
	if err != nil {
		return nil, fmt.Errorf("baton-junction: failed to create OAuth client: %w", err)
	}

	baseClient, err := uhttp.NewBaseHttpClientWithContext(ctx, httpClient)
	if err != nil {
		return nil, fmt.Errorf("baton-junction: failed to create HTTP client: %w", err)
	}

	return &Client{
		httpClient: baseClient,
		baseURL:    baseURL,
	}, nil
}

// GetCurrentUser fetches the user the API resolves from "/api/users/me",
// which is presumably the identity behind the client's OAuth credentials.
func (c *Client) GetCurrentUser(ctx context.Context) (*User, error) {
	envelope := SingleResponse[User]{}
	err := c.doGet(ctx, "/api/users/me", "", &envelope)
	if err != nil {
		return nil, fmt.Errorf("baton-junction: failed to get current user: %w", err)
	}
	me := envelope.Data
	return &me, nil
}

// GetUser fetches a single user from /api/users/{userID}.
//
// userID is path-escaped. Without escaping, an ID containing "/", "?" or "#"
// would change which URL is requested instead of naming the intended user.
func (c *Client) GetUser(ctx context.Context, userID string) (*User, error) {
	var response SingleResponse[User]
	path := fmt.Sprintf("/api/users/%s", url.PathEscape(userID))
	if err := c.doGet(ctx, path, "", &response); err != nil {
		return nil, fmt.Errorf("baton-junction: failed to get user: %w", err)
	}
	return &response.Data, nil
}

// GetUsers returns one page of users from /api/users. Pass an empty cursor
// for the first page. The second return value is the API's next-page cursor
// (presumably empty on the last page).
func (c *Client) GetUsers(ctx context.Context, cursor string) ([]User, string, error) {
	page := ListResponse[User]{}
	err := c.doGet(ctx, "/api/users", cursor, &page)
	if err != nil {
		return nil, "", fmt.Errorf("baton-junction: failed to list users: %w", err)
	}
	users, next := page.Data, page.NextCursor
	return users, next, nil
}

// CreateUser POSTs req as a JSON body to /api/users. It returns the user
// decoded from the response's data envelope.
func (c *Client) CreateUser(ctx context.Context, req *CreateUserRequest) (*User, error) {
	endpoint, parseErr := url.Parse(c.baseURL + "/api/users")
	if parseErr != nil {
		return nil, fmt.Errorf("baton-junction: invalid URL: %w", parseErr)
	}

	httpReq, buildErr := c.httpClient.NewRequest(
		ctx,
		http.MethodPost,
		endpoint,
		uhttp.WithJSONBody(req),
		uhttp.WithAcceptJSONHeader(),
	)
	if buildErr != nil {
		return nil, fmt.Errorf("baton-junction: failed to build create user request: %w", buildErr)
	}

	created := SingleResponse[User]{}
	if _, doErr := c.httpClient.Do(httpReq, uhttp.WithJSONResponse(&created)); doErr != nil {
		return nil, fmt.Errorf("baton-junction: failed to create user: %w", doErr)
	}
	return &created.Data, nil
}

// UpdateUser PATCHes attrs, encoded as a JSON object, to /api/users/{userID}.
// It returns the user decoded from the response's data envelope.
//
// userID is path-escaped so reserved characters cannot redirect the request.
// An empty userID is rejected, because the PATCH would otherwise be sent to
// the "/api/users/" collection path.
func (c *Client) UpdateUser(ctx context.Context, userID string, attrs map[string]string) (*User, error) {
	if userID == "" {
		return nil, fmt.Errorf("baton-junction: update user requires a non-empty user ID")
	}

	u, err := url.Parse(fmt.Sprintf("%s/api/users/%s", c.baseURL, url.PathEscape(userID)))
	if err != nil {
		return nil, fmt.Errorf("baton-junction: invalid URL: %w", err)
	}

	httpReq, err := c.httpClient.NewRequest(ctx, http.MethodPatch, u,
		uhttp.WithJSONBody(attrs),
		uhttp.WithAcceptJSONHeader(),
	)
	if err != nil {
		return nil, fmt.Errorf("baton-junction: failed to build update user request: %w", err)
	}

	var response SingleResponse[User]
	if _, err := c.httpClient.Do(httpReq, uhttp.WithJSONResponse(&response)); err != nil {
		return nil, fmt.Errorf("baton-junction: failed to update user: %w", err)
	}

	return &response.Data, nil
}

// GetRole fetches a single role from /api/roles/{roleID}.
// roleID is path-escaped so reserved characters cannot alter the URL.
func (c *Client) GetRole(ctx context.Context, roleID string) (*Role, error) {
	var response SingleResponse[Role]
	path := fmt.Sprintf("/api/roles/%s", url.PathEscape(roleID))
	if err := c.doGet(ctx, path, "", &response); err != nil {
		return nil, fmt.Errorf("baton-junction: failed to get role: %w", err)
	}
	return &response.Data, nil
}

// GetRoles returns one page of roles from /api/roles, starting at cursor
// (empty for the first page), plus the next-page cursor reported by the API.
func (c *Client) GetRoles(ctx context.Context, cursor string) ([]Role, string, error) {
	page := ListResponse[Role]{}
	err := c.doGet(ctx, "/api/roles", cursor, &page)
	if err != nil {
		return nil, "", fmt.Errorf("baton-junction: failed to list roles: %w", err)
	}
	roles, next := page.Data, page.NextCursor
	return roles, next, nil
}

// GetRoleMembers returns one page of members of roleID from
// /api/roles/{roleID}/members. It also returns the API's next-page cursor.
// roleID is path-escaped so reserved characters cannot alter the URL.
func (c *Client) GetRoleMembers(ctx context.Context, roleID, cursor string) ([]Member, string, error) {
	var response ListResponse[Member]
	path := fmt.Sprintf("/api/roles/%s/members", url.PathEscape(roleID))
	if err := c.doGet(ctx, path, cursor, &response); err != nil {
		return nil, "", fmt.Errorf("baton-junction: failed to list role members: %w", err)
	}
	return response.Data, response.NextCursor, nil
}

// AddRoleMember adds userID as a member of roleID by sending a PUT to
// /api/roles/{roleID}/members/{userID}.
//
// Both IDs are path-escaped. Empty IDs are rejected up front, because an empty
// segment would send the PUT to a different endpoint. The error from doPut is
// returned as-is so callers can inspect its gRPC status code.
func (c *Client) AddRoleMember(ctx context.Context, roleID, userID string) error {
	if roleID == "" || userID == "" {
		return fmt.Errorf("baton-junction: add role member requires non-empty role and user IDs")
	}
	return c.doPut(ctx, fmt.Sprintf("/api/roles/%s/members/%s", url.PathEscape(roleID), url.PathEscape(userID)))
}

// RemoveRoleMember removes userID from roleID by sending a DELETE to
// /api/roles/{roleID}/members/{userID}.
//
// Both IDs are path-escaped. Empty IDs are rejected up front, because an empty
// segment would send the DELETE to a different endpoint such as
// ".../members/". The error from doDelete is returned as-is so callers can
// check it for codes.NotFound.
func (c *Client) RemoveRoleMember(ctx context.Context, roleID, userID string) error {
	if roleID == "" || userID == "" {
		return fmt.Errorf("baton-junction: remove role member requires non-empty role and user IDs")
	}
	return c.doDelete(ctx, fmt.Sprintf("/api/roles/%s/members/%s", url.PathEscape(roleID), url.PathEscape(userID)))
}

// GetGroup fetches a single group from /api/groups/{groupID}.
// groupID is path-escaped so reserved characters cannot alter the URL.
func (c *Client) GetGroup(ctx context.Context, groupID string) (*Group, error) {
	var response SingleResponse[Group]
	path := fmt.Sprintf("/api/groups/%s", url.PathEscape(groupID))
	if err := c.doGet(ctx, path, "", &response); err != nil {
		return nil, fmt.Errorf("baton-junction: failed to get group: %w", err)
	}
	return &response.Data, nil
}

// GetGroups returns one page of groups from /api/groups, starting at cursor
// (empty for the first page), plus the next-page cursor reported by the API.
func (c *Client) GetGroups(ctx context.Context, cursor string) ([]Group, string, error) {
	page := ListResponse[Group]{}
	err := c.doGet(ctx, "/api/groups", cursor, &page)
	if err != nil {
		return nil, "", fmt.Errorf("baton-junction: failed to list groups: %w", err)
	}
	groups, next := page.Data, page.NextCursor
	return groups, next, nil
}

// GetGroupMembers returns one page of members of groupID from
// /api/groups/{groupID}/members. It also returns the API's next-page cursor.
// groupID is path-escaped so reserved characters cannot alter the URL.
func (c *Client) GetGroupMembers(ctx context.Context, groupID, cursor string) ([]Member, string, error) {
	var response ListResponse[Member]
	path := fmt.Sprintf("/api/groups/%s/members", url.PathEscape(groupID))
	if err := c.doGet(ctx, path, cursor, &response); err != nil {
		return nil, "", fmt.Errorf("baton-junction: failed to list group members: %w", err)
	}
	return response.Data, response.NextCursor, nil
}

// AddGroupMember adds userID as a member of groupID by sending a PUT to
// /api/groups/{groupID}/members/{userID}.
//
// Both IDs are path-escaped. Empty IDs are rejected up front, because an empty
// segment would send the PUT to a different endpoint. The error from doPut is
// returned as-is so callers can inspect its gRPC status code.
func (c *Client) AddGroupMember(ctx context.Context, groupID, userID string) error {
	if groupID == "" || userID == "" {
		return fmt.Errorf("baton-junction: add group member requires non-empty group and user IDs")
	}
	return c.doPut(ctx, fmt.Sprintf("/api/groups/%s/members/%s", url.PathEscape(groupID), url.PathEscape(userID)))
}

// RemoveGroupMember removes userID from groupID by sending a DELETE to
// /api/groups/{groupID}/members/{userID}.
//
// Both IDs are path-escaped. Empty IDs are rejected up front, because an empty
// segment would send the DELETE to a different endpoint such as
// ".../members/". The error from doDelete is returned as-is so callers can
// check it for codes.NotFound.
func (c *Client) RemoveGroupMember(ctx context.Context, groupID, userID string) error {
	if groupID == "" || userID == "" {
		return fmt.Errorf("baton-junction: remove group member requires non-empty group and user IDs")
	}
	return c.doDelete(ctx, fmt.Sprintf("/api/groups/%s/members/%s", url.PathEscape(groupID), url.PathEscape(userID)))
}

// GetRoleGroups returns every group assignment of roleID from
// /api/roles/{roleID}/groups.
//
// doGet caps each response at defaultPageSize items. This method therefore
// follows NextCursor until the API reports no further page; the previous
// version silently dropped everything after the first page. roleID is
// path-escaped so reserved characters cannot alter the URL.
func (c *Client) GetRoleGroups(ctx context.Context, roleID string) ([]GroupAssignment, error) {
	path := fmt.Sprintf("/api/roles/%s/groups", url.PathEscape(roleID))

	var assignments []GroupAssignment
	cursor := ""
	for {
		var response ListResponse[GroupAssignment]
		if err := c.doGet(ctx, path, cursor, &response); err != nil {
			return nil, fmt.Errorf("baton-junction: failed to list role group assignments: %w", err)
		}
		assignments = append(assignments, response.Data...)

		// Stop on the last page. Also stop if the server echoes the same
		// cursor back, which would otherwise loop forever.
		if response.NextCursor == "" || response.NextCursor == cursor {
			return assignments, nil
		}
		cursor = response.NextCursor
	}
}

// GetRoleRoles returns every role-to-role assignment of roleID from
// /api/roles/{roleID}/roles.
//
// doGet caps each response at defaultPageSize items. This method therefore
// follows NextCursor until the API reports no further page; the previous
// version silently dropped everything after the first page. roleID is
// path-escaped so reserved characters cannot alter the URL.
func (c *Client) GetRoleRoles(ctx context.Context, roleID string) ([]RoleAssignment, error) {
	path := fmt.Sprintf("/api/roles/%s/roles", url.PathEscape(roleID))

	var assignments []RoleAssignment
	cursor := ""
	for {
		var response ListResponse[RoleAssignment]
		if err := c.doGet(ctx, path, cursor, &response); err != nil {
			return nil, fmt.Errorf("baton-junction: failed to list role-to-role assignments: %w", err)
		}
		assignments = append(assignments, response.Data...)

		// Stop on the last page. Also stop if the server echoes the same
		// cursor back, which would otherwise loop forever.
		if response.NextCursor == "" || response.NextCursor == cursor {
			return assignments, nil
		}
		cursor = response.NextCursor
	}
}

// doGet issues a GET for path (relative to c.baseURL) and decodes the JSON
// response body into response. All callers pass a pointer such as &response.
//
// Every request carries limit=defaultPageSize. When cursor is non-empty it is
// sent as the "cursor" query parameter. Single-resource endpoints receive the
// same limit parameter, which the API presumably ignores (confirm).
//
// The SDK's Do() handles body closing and HTTP error → gRPC status mapping, so
// its error is returned unwrapped. The request-build error is wrapped the same
// way as in doPut and doDelete.
func (c *Client) doGet(ctx context.Context, path, cursor string, response any) error {
	u, err := url.Parse(c.baseURL + path)
	if err != nil {
		return fmt.Errorf("baton-junction: invalid URL: %w", err)
	}

	q := u.Query()
	q.Set("limit", strconv.Itoa(defaultPageSize))
	if cursor != "" {
		q.Set("cursor", cursor)
	}
	u.RawQuery = q.Encode()

	req, err := c.httpClient.NewRequest(ctx, http.MethodGet, u,
		uhttp.WithAcceptJSONHeader(),
	)
	if err != nil {
		return fmt.Errorf("baton-junction: failed to build request: %w", err)
	}

	_, err = c.httpClient.Do(req, uhttp.WithJSONResponse(response))
	return err
}

// doPut issues a PUT with no request body for path, relative to c.baseURL. The
// membership-add helpers use it. The error from the SDK's Do() is returned
// unwrapped so callers keep access to its gRPC status.
func (c *Client) doPut(ctx context.Context, path string) error {
	target, parseErr := url.Parse(c.baseURL + path)
	if parseErr != nil {
		return fmt.Errorf("baton-junction: invalid URL: %w", parseErr)
	}

	httpReq, buildErr := c.httpClient.NewRequest(ctx, http.MethodPut, target, uhttp.WithAcceptJSONHeader())
	if buildErr != nil {
		return fmt.Errorf("baton-junction: failed to build request: %w", buildErr)
	}

	if _, doErr := c.httpClient.Do(httpReq); doErr != nil {
		return doErr
	}
	return nil
}

// doDelete issues a DELETE for path (relative to c.baseURL). The SDK-mapped
// error from Do() is returned as-is, not wrapped. Callers that want to
// tolerate a missing resource can check it for codes.NotFound.
func (c *Client) doDelete(ctx context.Context, path string) error {
	target, parseErr := url.Parse(c.baseURL + path)
	if parseErr != nil {
		return fmt.Errorf("baton-junction: invalid URL: %w", parseErr)
	}

	httpReq, buildErr := c.httpClient.NewRequest(ctx, http.MethodDelete, target, uhttp.WithAcceptJSONHeader())
	if buildErr != nil {
		return fmt.Errorf("baton-junction: failed to build request: %w", buildErr)
	}

	if _, doErr := c.httpClient.Do(httpReq); doErr != nil {
		return doErr
	}
	return nil
}