diff --git a/slackdump.go b/slackdump.go index 512993dc..ff44c115 100644 --- a/slackdump.go +++ b/slackdump.go @@ -161,13 +161,15 @@ func New(ctx context.Context, prov auth.Provider, opts ...Option) (*Session, err return nil, err } - if err := sd.initClient(ctx, prov); err != nil { + if err := sd.initClient(ctx, prov, sd.cfg.forceEnterprise); err != nil { return nil, err } return sd, nil } +// initWorkspaceInfo gets from the API and sets the workspace information for +// the session. func (s *Session) initWorkspaceInfo(ctx context.Context, cl Slacker) error { info, err := cl.AuthTestContext(ctx) if err != nil { @@ -177,8 +179,12 @@ func (s *Session) initWorkspaceInfo(ctx context.Context, cl Slacker) error { return nil } -// initClient initialises the client with the provided auth.Provider. -func (s *Session) initClient(ctx context.Context, prov auth.Provider) error { +// initClient initialises the client that is appropriate for the current +// workspace. It will use the initialised auth.Provider for credentials. If +// forceEdge is true, it will use th edge client regardless of whether it +// detects the enterprise instance or not. If the client was set by the +// WithClient option, it will not override it. +func (s *Session) initClient(ctx context.Context, prov auth.Provider, forceEdge bool) error { if s.client != nil { // already initialised, probably through options. return s.initWorkspaceInfo(ctx, s.client) @@ -188,13 +194,15 @@ func (s *Session) initClient(ctx context.Context, prov auth.Provider) error { if err != nil { return err } + // initialising default client cl := slack.New(prov.SlackToken(), slack.OptionHTTPClient(httpcl)) if err := s.initWorkspaceInfo(ctx, cl); err != nil { return err } - if s.cfg.forceEnterprise || s.wspInfo.EnterpriseID != "" { + isEnterpriseWsp := s.wspInfo.EnterpriseID != "" + if forceEdge || isEnterpriseWsp { // replace the client with the edge client ecl, err := edge.NewWithInfo(s.wspInfo, prov) if err != nil { diff --git a/slackdump_test.go b/slackdump_test.go index a469141c..904b812d 100644 --- a/slackdump_test.go +++ b/slackdump_test.go @@ -4,14 +4,21 @@ import ( "context" "log" "math" + "net/http" "os" "testing" + "testing/fstest" "time" "github.com/rusq/fsadapter" + "github.com/rusq/slack" "github.com/rusq/slackdump/v3/auth" + "github.com/rusq/slackdump/v3/internal/edge" + "github.com/rusq/slackdump/v3/internal/mocks/mock_auth" "github.com/rusq/slackdump/v3/internal/network" + "github.com/rusq/slackdump/v3/logger" "github.com/stretchr/testify/assert" + "go.uber.org/mock/gomock" ) func Test_newLimiter(t *testing.T) { @@ -134,3 +141,119 @@ func openTempFS() fsadapter.FSCloser { } return fsc } + +func TestSession_initWorkspaceInfo(t *testing.T) { + ctx := context.Background() + t.Run("ok", func(t *testing.T) { + ctrl := gomock.NewController(t) + mc := NewmockClienter(ctrl) + mc.EXPECT().AuthTestContext(gomock.Any()).Return(&slack.AuthTestResponse{ + TeamID: "TEST", + }, nil) + s := Session{ + client: nil, // it should use the provided client + } + + err := s.initWorkspaceInfo(ctx, mc) + assert.NoError(t, err, "unexpected initialisation error") + }) + t.Run("error", func(t *testing.T) { + ctrl := gomock.NewController(t) + mc := NewmockClienter(ctrl) + mc.EXPECT().AuthTestContext(gomock.Any()).Return(nil, assert.AnError) + s := Session{ + client: nil, // it should use the provided client + } + err := s.initWorkspaceInfo(ctx, mc) + assert.Error(t, err, "expected error") + }) +} + +func TestSession_initClient(t *testing.T) { + // fakeSlackAPI contains fake endpoints for the slack API. + fakeSlackAPI := fstest.MapFS{ + "api/auth.test": &fstest.MapFile{ + Data: []byte(`{"ok":true,"url":"https:\/\/test.slack.com\/","team":"TEST","user":"test","team_id":"T123456","user_id":"U123456"}`), + Mode: 0644, + }, + } + fakeEnterpriseSlackAPI := fstest.MapFS{ + "api/auth.test": &fstest.MapFile{ + Data: []byte(`{"ok":true,"url":"https:\/\/test.slack.com\/","team":"TEST","user":"test","team_id":"T123456","user_id":"U123456","enterprise_id":"E123456"}`), + }, + } + + expectAuthTestFn := func(mc *mockClienter, enterpriseID string) { + mc.EXPECT().AuthTestContext(gomock.Any()).Return(&slack.AuthTestResponse{ + TeamID: "TEST", + EnterpriseID: enterpriseID, + }, nil) + } + t.Run("pre-initialised client", func(t *testing.T) { + ctrl := gomock.NewController(t) + mc := NewmockClienter(ctrl) + expectAuthTestFn(mc, "") // not an anterprise instance + s := Session{ + client: mc, + } + err := s.initClient(context.Background(), nil, false) + assert.NoError(t, err, "unexpected error") + assert.IsType(t, &mockClienter{}, s.client) + }) + t.Run("standard client", func(t *testing.T) { + // http client will return the file from the fakeAPIFS. + cl := http.Client{ + Transport: http.NewFileTransportFS(fakeSlackAPI), + } + + ctrl := gomock.NewController(t) + mprov := mock_auth.NewMockProvider(ctrl) + mprov.EXPECT().SlackToken().Return("xoxb-...") + mprov.EXPECT().HTTPClient().Return(&cl, nil) + + s := Session{ + client: nil, + log: logger.Default, + } + err := s.initClient(context.Background(), mprov, false) + assert.NoError(t, err, "unexpected error") + assert.IsType(t, &slack.Client{}, s.client) + }) + + t.Run("enterprise client", func(t *testing.T) { + cl := http.Client{ + Transport: http.NewFileTransportFS(fakeEnterpriseSlackAPI), + } + + ctrl := gomock.NewController(t) + mprov := mock_auth.NewMockProvider(ctrl) + mprov.EXPECT().SlackToken().Return("xoxb-...").Times(2) + mprov.EXPECT().HTTPClient().Return(&cl, nil).Times(2) + + s := Session{ + client: nil, + log: logger.Default, + } + err := s.initClient(context.Background(), mprov, false) + assert.NoError(t, err, "unexpected error") + assert.IsType(t, &edge.Wrapper{}, s.client) + }) + t.Run("forced enterprise client", func(t *testing.T) { + cl := http.Client{ + Transport: http.NewFileTransportFS(fakeSlackAPI), + } + + ctrl := gomock.NewController(t) + mprov := mock_auth.NewMockProvider(ctrl) + mprov.EXPECT().SlackToken().Return("xoxb-...").Times(2) + mprov.EXPECT().HTTPClient().Return(&cl, nil).Times(2) + + s := Session{ + client: nil, + log: logger.Default, + } + err := s.initClient(context.Background(), mprov, true) + assert.NoError(t, err, "unexpected error") + assert.IsType(t, &edge.Wrapper{}, s.client) + }) +}