package client
import (
	"context"
	"fmt"
	"net/http"
	"net/url"
	"strconv"
	"strings"

	"github.com/conductorone/baton-sdk/pkg/uhttp"
)
const (
	// defaultPageSize is sent as the "limit" query parameter on every GET
	// request the client issues.
	defaultPageSize = 100
)
// Client issues authenticated requests against a single API root using the
// baton-sdk HTTP client.
type Client struct {
	// httpClient is the OAuth2 client-credentials SDK client constructed by New.
	httpClient *uhttp.BaseHttpClient
	// baseURL is the API root that every request path is appended to.
	baseURL string
}
// New builds a Client whose requests are rooted at baseURL and authenticated
// with OAuth2 client credentials exchanged at baseURL + "/oauth/token".
//
// Trailing slashes are trimmed from baseURL so the fixed paths appended later
// ("/oauth/token", "/api/...") never produce "//" in the URL. baseURL must be
// absolute (scheme and host present); otherwise an error is returned here
// rather than on the first request.
func New(ctx context.Context, clientID, clientSecret, baseURL string) (*Client, error) {
	baseURL = strings.TrimRight(baseURL, "/")

	base, err := url.Parse(baseURL)
	if err != nil {
		return nil, fmt.Errorf("baton-junction: invalid base URL: %w", err)
	}
	// A relative URL parses without error but cannot be dialed, so reject it early.
	if base.Scheme == "" || base.Host == "" {
		return nil, fmt.Errorf("baton-junction: invalid base URL: scheme and host are required")
	}

	tokenURL, err := url.Parse(baseURL + "/oauth/token")
	if err != nil {
		return nil, fmt.Errorf("baton-junction: invalid base URL: %w", err)
	}

	creds := uhttp.NewOAuth2ClientCredentials(clientID, clientSecret, tokenURL, nil)
	httpClient, err := creds.GetClient(ctx)
	if err != nil {
		return nil, fmt.Errorf("baton-junction: failed to create OAuth client: %w", err)
	}

	baseClient, err := uhttp.NewBaseHttpClientWithContext(ctx, httpClient)
	if err != nil {
		return nil, fmt.Errorf("baton-junction: failed to create HTTP client: %w", err)
	}

	return &Client{
		httpClient: baseClient,
		baseURL:    baseURL,
	}, nil
}
// GetCurrentUser fetches /api/users/me and returns the user from the
// response envelope's Data field.
func (c *Client) GetCurrentUser(ctx context.Context) (*User, error) {
	envelope := new(SingleResponse[User])
	err := c.doGet(ctx, "/api/users/me", "", envelope)
	if err != nil {
		return nil, fmt.Errorf("baton-junction: failed to get current user: %w", err)
	}
	return &envelope.Data, nil
}
// GetUser fetches /api/users/{userID} and returns the user from the response
// envelope's Data field.
//
// userID is path-escaped so reserved characters ('/', '?', '#', '%') stay
// inside a single path segment instead of changing which resource is
// requested. An empty userID is rejected because it would address the
// /api/users/ collection instead of a single user.
func (c *Client) GetUser(ctx context.Context, userID string) (*User, error) {
	if userID == "" {
		return nil, fmt.Errorf("baton-junction: failed to get user: user ID is empty")
	}
	var response SingleResponse[User]
	path := fmt.Sprintf("/api/users/%s", url.PathEscape(userID))
	if err := c.doGet(ctx, path, "", &response); err != nil {
		return nil, fmt.Errorf("baton-junction: failed to get user: %w", err)
	}
	return &response.Data, nil
}
// GetUsers fetches one page of /api/users. Pass an empty cursor for the first
// page; the returned string is the response's NextCursor value.
func (c *Client) GetUsers(ctx context.Context, cursor string) ([]User, string, error) {
	page := new(ListResponse[User])
	if err := c.doGet(ctx, "/api/users", cursor, page); err != nil {
		return nil, "", fmt.Errorf("baton-junction: failed to list users: %w", err)
	}
	users, next := page.Data, page.NextCursor
	return users, next, nil
}
// CreateUser POSTs req as a JSON body to /api/users and returns the user
// decoded from the response envelope's Data field.
func (c *Client) CreateUser(ctx context.Context, req *CreateUserRequest) (*User, error) {
	endpoint, parseErr := url.Parse(c.baseURL + "/api/users")
	if parseErr != nil {
		return nil, fmt.Errorf("baton-junction: invalid URL: %w", parseErr)
	}

	httpReq, buildErr := c.httpClient.NewRequest(
		ctx,
		http.MethodPost,
		endpoint,
		uhttp.WithJSONBody(req),
		uhttp.WithAcceptJSONHeader(),
	)
	if buildErr != nil {
		return nil, fmt.Errorf("baton-junction: failed to build create user request: %w", buildErr)
	}

	created := new(SingleResponse[User])
	if _, doErr := c.httpClient.Do(httpReq, uhttp.WithJSONResponse(created)); doErr != nil {
		return nil, fmt.Errorf("baton-junction: failed to create user: %w", doErr)
	}
	return &created.Data, nil
}
// UpdateUser PATCHes /api/users/{userID} with attrs as the JSON body and
// returns the user decoded from the response envelope's Data field.
//
// userID is path-escaped so reserved characters ('/', '?', '#', '%') cannot
// change which resource is modified. An empty userID is rejected before any
// request is built, since it would otherwise PATCH the /api/users/ collection.
func (c *Client) UpdateUser(ctx context.Context, userID string, attrs map[string]string) (*User, error) {
	if userID == "" {
		return nil, fmt.Errorf("baton-junction: failed to update user: user ID is empty")
	}
	u, err := url.Parse(fmt.Sprintf("%s/api/users/%s", c.baseURL, url.PathEscape(userID)))
	if err != nil {
		return nil, fmt.Errorf("baton-junction: invalid URL: %w", err)
	}
	httpReq, err := c.httpClient.NewRequest(ctx, http.MethodPatch, u,
		uhttp.WithJSONBody(attrs),
		uhttp.WithAcceptJSONHeader(),
	)
	if err != nil {
		return nil, fmt.Errorf("baton-junction: failed to build update user request: %w", err)
	}
	var response SingleResponse[User]
	if _, err := c.httpClient.Do(httpReq, uhttp.WithJSONResponse(&response)); err != nil {
		return nil, fmt.Errorf("baton-junction: failed to update user: %w", err)
	}
	return &response.Data, nil
}
// GetRole fetches /api/roles/{roleID} and returns the role from the response
// envelope's Data field.
//
// roleID is path-escaped so reserved characters stay inside one path segment.
// An empty roleID is rejected because it would address the /api/roles/
// collection.
func (c *Client) GetRole(ctx context.Context, roleID string) (*Role, error) {
	if roleID == "" {
		return nil, fmt.Errorf("baton-junction: failed to get role: role ID is empty")
	}
	var response SingleResponse[Role]
	path := fmt.Sprintf("/api/roles/%s", url.PathEscape(roleID))
	if err := c.doGet(ctx, path, "", &response); err != nil {
		return nil, fmt.Errorf("baton-junction: failed to get role: %w", err)
	}
	return &response.Data, nil
}
// GetRoles fetches one page of /api/roles starting at cursor (empty for the
// first page) and returns the roles plus the response's NextCursor value.
func (c *Client) GetRoles(ctx context.Context, cursor string) ([]Role, string, error) {
	page := &ListResponse[Role]{}
	err := c.doGet(ctx, "/api/roles", cursor, page)
	if err != nil {
		return nil, "", fmt.Errorf("baton-junction: failed to list roles: %w", err)
	}
	return page.Data, page.NextCursor, nil
}
// GetRoleMembers fetches one page of /api/roles/{roleID}/members starting at
// cursor (empty for the first page) and returns the members plus the
// response's NextCursor value.
//
// roleID is path-escaped so reserved characters cannot alter the path. An
// empty roleID is rejected because it would produce "/api/roles//members".
func (c *Client) GetRoleMembers(ctx context.Context, roleID, cursor string) ([]Member, string, error) {
	if roleID == "" {
		return nil, "", fmt.Errorf("baton-junction: failed to list role members: role ID is empty")
	}
	var response ListResponse[Member]
	path := fmt.Sprintf("/api/roles/%s/members", url.PathEscape(roleID))
	if err := c.doGet(ctx, path, cursor, &response); err != nil {
		return nil, "", fmt.Errorf("baton-junction: failed to list role members: %w", err)
	}
	return response.Data, response.NextCursor, nil
}
// AddRoleMember adds userID as a member of roleID by sending a body-less PUT
// to /api/roles/{roleID}/members/{userID}.
//
// Both IDs are path-escaped, and empty IDs are rejected before any request is
// sent (an empty userID would otherwise target the members collection).
// Request errors are wrapped with %w, as in the other Client methods, so
// the underlying SDK error remains reachable through errors.Is / errors.As.
func (c *Client) AddRoleMember(ctx context.Context, roleID, userID string) error {
	if roleID == "" || userID == "" {
		return fmt.Errorf("baton-junction: failed to add role member: role ID and user ID are required")
	}
	path := fmt.Sprintf("/api/roles/%s/members/%s", url.PathEscape(roleID), url.PathEscape(userID))
	if err := c.doPut(ctx, path); err != nil {
		return fmt.Errorf("baton-junction: failed to add role member: %w", err)
	}
	return nil
}
// RemoveRoleMember removes userID from roleID by sending a DELETE to
// /api/roles/{roleID}/members/{userID}.
//
// Both IDs are path-escaped, and empty IDs are rejected so a DELETE is never
// sent to the members collection itself. Request errors are wrapped with %w,
// as in the other Client methods; callers tolerating 404 can still inspect the
// wrapped SDK error.
func (c *Client) RemoveRoleMember(ctx context.Context, roleID, userID string) error {
	if roleID == "" || userID == "" {
		return fmt.Errorf("baton-junction: failed to remove role member: role ID and user ID are required")
	}
	path := fmt.Sprintf("/api/roles/%s/members/%s", url.PathEscape(roleID), url.PathEscape(userID))
	if err := c.doDelete(ctx, path); err != nil {
		return fmt.Errorf("baton-junction: failed to remove role member: %w", err)
	}
	return nil
}
// GetGroup fetches /api/groups/{groupID} and returns the group from the
// response envelope's Data field.
//
// groupID is path-escaped so reserved characters stay inside one path segment.
// An empty groupID is rejected because it would address the /api/groups/
// collection.
func (c *Client) GetGroup(ctx context.Context, groupID string) (*Group, error) {
	if groupID == "" {
		return nil, fmt.Errorf("baton-junction: failed to get group: group ID is empty")
	}
	var response SingleResponse[Group]
	path := fmt.Sprintf("/api/groups/%s", url.PathEscape(groupID))
	if err := c.doGet(ctx, path, "", &response); err != nil {
		return nil, fmt.Errorf("baton-junction: failed to get group: %w", err)
	}
	return &response.Data, nil
}
// GetGroups fetches one page of /api/groups. An empty cursor requests the
// first page; the second return value is the response's NextCursor.
func (c *Client) GetGroups(ctx context.Context, cursor string) ([]Group, string, error) {
	result := new(ListResponse[Group])
	if listErr := c.doGet(ctx, "/api/groups", cursor, result); listErr != nil {
		return nil, "", fmt.Errorf("baton-junction: failed to list groups: %w", listErr)
	}
	groups := result.Data
	nextCursor := result.NextCursor
	return groups, nextCursor, nil
}
// GetGroupMembers fetches one page of /api/groups/{groupID}/members starting
// at cursor (empty for the first page) and returns the members plus the
// response's NextCursor value.
//
// groupID is path-escaped so reserved characters cannot alter the path. An
// empty groupID is rejected because it would produce "/api/groups//members".
func (c *Client) GetGroupMembers(ctx context.Context, groupID, cursor string) ([]Member, string, error) {
	if groupID == "" {
		return nil, "", fmt.Errorf("baton-junction: failed to list group members: group ID is empty")
	}
	var response ListResponse[Member]
	path := fmt.Sprintf("/api/groups/%s/members", url.PathEscape(groupID))
	if err := c.doGet(ctx, path, cursor, &response); err != nil {
		return nil, "", fmt.Errorf("baton-junction: failed to list group members: %w", err)
	}
	return response.Data, response.NextCursor, nil
}
// AddGroupMember adds userID as a member of groupID by sending a body-less
// PUT to /api/groups/{groupID}/members/{userID}.
//
// Both IDs are path-escaped, and empty IDs are rejected before any request is
// sent. Request errors are wrapped with %w, as in the other Client methods.
func (c *Client) AddGroupMember(ctx context.Context, groupID, userID string) error {
	if groupID == "" || userID == "" {
		return fmt.Errorf("baton-junction: failed to add group member: group ID and user ID are required")
	}
	path := fmt.Sprintf("/api/groups/%s/members/%s", url.PathEscape(groupID), url.PathEscape(userID))
	if err := c.doPut(ctx, path); err != nil {
		return fmt.Errorf("baton-junction: failed to add group member: %w", err)
	}
	return nil
}
// RemoveGroupMember removes userID from groupID by sending a DELETE to
// /api/groups/{groupID}/members/{userID}.
//
// Both IDs are path-escaped, and empty IDs are rejected so a DELETE is never
// sent to the members collection itself. Request errors are wrapped with %w,
// as in the other Client methods.
func (c *Client) RemoveGroupMember(ctx context.Context, groupID, userID string) error {
	if groupID == "" || userID == "" {
		return fmt.Errorf("baton-junction: failed to remove group member: group ID and user ID are required")
	}
	path := fmt.Sprintf("/api/groups/%s/members/%s", url.PathEscape(groupID), url.PathEscape(userID))
	if err := c.doDelete(ctx, path); err != nil {
		return fmt.Errorf("baton-junction: failed to remove group member: %w", err)
	}
	return nil
}
// GetRoleGroups returns every group assignment listed under
// /api/roles/{roleID}/groups.
//
// This method exposes no cursor to callers. doGet caps each request at
// defaultPageSize entries, so returning only the first page would silently
// drop the rest. It therefore follows NextCursor until the API returns an
// empty one. A NextCursor equal to the cursor just sent is treated as an
// error so a misbehaving server cannot cause an endless loop. roleID is
// path-escaped and must be non-empty.
func (c *Client) GetRoleGroups(ctx context.Context, roleID string) ([]GroupAssignment, error) {
	if roleID == "" {
		return nil, fmt.Errorf("baton-junction: failed to list role group assignments: role ID is empty")
	}
	path := fmt.Sprintf("/api/roles/%s/groups", url.PathEscape(roleID))

	var assignments []GroupAssignment
	cursor := ""
	for {
		var response ListResponse[GroupAssignment]
		if err := c.doGet(ctx, path, cursor, &response); err != nil {
			return nil, fmt.Errorf("baton-junction: failed to list role group assignments: %w", err)
		}
		assignments = append(assignments, response.Data...)

		next := response.NextCursor
		if next == "" {
			return assignments, nil
		}
		if next == cursor {
			return nil, fmt.Errorf("baton-junction: failed to list role group assignments: API repeated pagination cursor %q", next)
		}
		cursor = next
	}
}
// GetRoleRoles returns every role-to-role assignment listed under
// /api/roles/{roleID}/roles.
//
// This method exposes no cursor to callers. doGet caps each request at
// defaultPageSize entries, so returning only the first page would silently
// drop the rest. It therefore follows NextCursor until the API returns an
// empty one. A NextCursor equal to the cursor just sent is treated as an
// error so a misbehaving server cannot cause an endless loop. roleID is
// path-escaped and must be non-empty.
func (c *Client) GetRoleRoles(ctx context.Context, roleID string) ([]RoleAssignment, error) {
	if roleID == "" {
		return nil, fmt.Errorf("baton-junction: failed to list role-to-role assignments: role ID is empty")
	}
	path := fmt.Sprintf("/api/roles/%s/roles", url.PathEscape(roleID))

	var assignments []RoleAssignment
	cursor := ""
	for {
		var response ListResponse[RoleAssignment]
		if err := c.doGet(ctx, path, cursor, &response); err != nil {
			return nil, fmt.Errorf("baton-junction: failed to list role-to-role assignments: %w", err)
		}
		assignments = append(assignments, response.Data...)

		next := response.NextCursor
		if next == "" {
			return assignments, nil
		}
		if next == cursor {
			return nil, fmt.Errorf("baton-junction: failed to list role-to-role assignments: API repeated pagination cursor %q", next)
		}
		cursor = next
	}
}
// doGet sends a GET to c.baseURL+path and decodes the JSON body into
// response. All callers in this file pass a pointer to the destination.
//
// Every request carries limit=defaultPageSize. The cursor query parameter is
// added only when cursor is non-empty. The limit is also sent to
// single-resource endpoints (e.g. /api/users/me).
// NOTE(review): presumably ignored there; confirm against the API.
// The SDK's Do() handles body closing and HTTP error → gRPC status mapping.
// Errors from Do are therefore returned unwrapped; request-construction
// errors are wrapped, as in doPut and doDelete.
func (c *Client) doGet(ctx context.Context, path, cursor string, response any) error {
	u, err := url.Parse(c.baseURL + path)
	if err != nil {
		return fmt.Errorf("baton-junction: invalid URL: %w", err)
	}

	q := u.Query()
	q.Set("limit", strconv.Itoa(defaultPageSize))
	if cursor != "" {
		q.Set("cursor", cursor)
	}
	u.RawQuery = q.Encode()

	req, err := c.httpClient.NewRequest(ctx, http.MethodGet, u,
		uhttp.WithAcceptJSONHeader(),
	)
	if err != nil {
		return fmt.Errorf("baton-junction: failed to build request: %w", err)
	}

	_, err = c.httpClient.Do(req, uhttp.WithJSONResponse(response))
	return err
}
// doPut sends a body-less PUT to c.baseURL+path with the SDK's Accept-JSON
// header option and discards the response. It is used for idempotent
// membership adds. Errors from Do are returned as-is so the SDK's error
// mapping is preserved.
func (c *Client) doPut(ctx context.Context, path string) error {
	target, parseErr := url.Parse(c.baseURL + path)
	if parseErr != nil {
		return fmt.Errorf("baton-junction: invalid URL: %w", parseErr)
	}

	req, buildErr := c.httpClient.NewRequest(ctx, http.MethodPut, target, uhttp.WithAcceptJSONHeader())
	if buildErr != nil {
		return fmt.Errorf("baton-junction: failed to build request: %w", buildErr)
	}

	if _, doErr := c.httpClient.Do(req); doErr != nil {
		return doErr
	}
	return nil
}
// doDelete sends a DELETE to c.baseURL+path with the SDK's Accept-JSON header
// option and discards the response.
//
// The error from Do is returned without wrapping (the SDK-mapped error).
// Callers that need to tolerate 404 should check for codes.NotFound.
func (c *Client) doDelete(ctx context.Context, path string) error {
	target, parseErr := url.Parse(c.baseURL + path)
	if parseErr != nil {
		return fmt.Errorf("baton-junction: invalid URL: %w", parseErr)
	}

	req, buildErr := c.httpClient.NewRequest(ctx, http.MethodDelete, target, uhttp.WithAcceptJSONHeader())
	if buildErr != nil {
		return fmt.Errorf("baton-junction: failed to build request: %w", buildErr)
	}

	if _, doErr := c.httpClient.Do(req); doErr != nil {
		return doErr
	}
	return nil
}