Skip to content

Commit

Permalink
race condition and some other stuff
Browse files Browse the repository at this point in the history
  • Loading branch information
rusq committed Apr 27, 2023
1 parent 57eac5d commit 547a821
Show file tree
Hide file tree
Showing 9 changed files with 116 additions and 79 deletions.
6 changes: 5 additions & 1 deletion cmd/slackdump/internal/dump/dump.go
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,11 @@ func dumpv3_2(ctx context.Context, sess *slackdump.Session, fsa fsadapter.FS, li
if err != nil {
return fmt.Errorf("failed to create conversation processor: %w", err)
}
defer proc.Close()
defer func() {
if err := proc.Close(); err != nil {
lg.Printf("failed to close conversation processor: %v", err)
}
}()

if err := sess.Stream().AsyncConversations(ctx, proc, list.Generator(ctx), func(sr slackdump.StreamResult) error {
if sr.Err != nil {
Expand Down
2 changes: 2 additions & 0 deletions cmd/slackdump/internal/export/expproc/conversations.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,9 @@ func (cv *Conversations) ChannelInfo(ctx context.Context, ci *slack.Channel, isT
}

func (cv *Conversations) recorder(channelID string) (*baseproc, error) {
cv.mu.RLock()
r, ok := cv.cw[channelID]
cv.mu.RUnlock()
if ok {
return r.baseproc, nil
}
Expand Down
11 changes: 10 additions & 1 deletion internal/chunk/file.go
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,16 @@ func allForOffsets[T any](p *File, offsets []int64, fn func(c *Chunk) []T) ([]T,

// ChannelInfo returns the information for the given channel.
func (f *File) ChannelInfo(channelID string) (*slack.Channel, error) {
ofs, ok := f.Offsets(channelInfoID(channelID, false))
return f.channelInfo(channelID, false)
}

// ThreadInfo returns the channel information for the given thread.
func (f *File) ThreadInfo(channelID, threadTS string) (*slack.Channel, error) {
return f.channelInfo(channelID, true)
}

func (f *File) channelInfo(channelID string, thread bool) (*slack.Channel, error) {
ofs, ok := f.Offsets(channelInfoID(channelID, thread))
if !ok {
return nil, ErrNotFound
}
Expand Down
10 changes: 9 additions & 1 deletion internal/chunk/player.go
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,15 @@ func (p *Player) Reset() error {
// ChannelInfo returns the channel information for the given channel. It
// returns an error if the channel is not found within the chunkfile.
func (p *Player) ChannelInfo(id string) (*slack.Channel, error) {
chunk, err := p.next(channelInfoID(id, false))
return p.channelInfo(id, false)
}

func (p *Player) ThreadChannelInfo(id string) (*slack.Channel, error) {
return p.channelInfo(id, true)
}

func (p *Player) channelInfo(id string, isThread bool) (*slack.Channel, error) {
chunk, err := p.next(channelInfoID(id, isThread))
if err != nil {
return nil, err
}
Expand Down
18 changes: 0 additions & 18 deletions internal/chunk/processor/processor.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,21 +49,3 @@ type Channels interface {
}

var _ Channels = new(chunk.Recorder)

type options struct {
dumpFiles bool
}

// Option is a functional option for the processor.
type Option func(*options)

// DumpFiles disables the file processing (enabled by default). It may be
// useful on enterprise workspaces where the file download may be monitored.
// See [#191]
//
// [#191]: https://github.com/rusq/slackdump/discussions/191#discussioncomment-4953235
func DumpFiles(b bool) Option {
return func(o *options) {
o.dumpFiles = b
}
}
20 changes: 19 additions & 1 deletion internal/chunk/processor/standard.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,30 @@ type Standard struct {
opts options
}

type options struct {
dumpFiles bool
}

// Option is a functional option for the processor.
type Option func(*options)

// DumpFiles disables the file processing (enabled by default). It may be
// useful on enterprise workspaces where the file download may be monitored.
// See [#191]
//
// [#191]: https://github.com/rusq/slackdump/discussions/191#discussioncomment-4953235
func DumpFiles(b bool) Option {
return func(o *options) {
o.dumpFiles = b
}
}

// NewStandard creates a new standard processor. It will write the output to
// the given writer. The downloader is used to download files. The directory
// is the directory where the files will be downloaded to. The options are
// functional options. See the NoFiles option.
func NewStandard(ctx context.Context, w io.Writer, sess downloader.Downloader, dir string, opts ...Option) (*Standard, error) {
opt := options{dumpFiles: false}
opt := options{dumpFiles: true}
for _, o := range opts {
o(&opt)
}
Expand Down
44 changes: 31 additions & 13 deletions internal/chunk/transform/standard.go
Original file line number Diff line number Diff line change
Expand Up @@ -219,10 +219,11 @@ func loadState(st *state.State, basePath string) (io.ReadSeekCloser, error) {
}

type Standard2 struct {
cd *chunk.Directory
fsa fsadapter.FS
ids chan string
done chan struct{}
cd *chunk.Directory
fsa fsadapter.FS
idsC chan string
doneC chan struct{}
errC chan error
}

func NewStandard2(fsa fsadapter.FS, dir string) (*Standard2, error) {
Expand All @@ -231,33 +232,49 @@ func NewStandard2(fsa fsadapter.FS, dir string) (*Standard2, error) {
return nil, err
}
std := &Standard2{
cd: cd,
fsa: fsa,
ids: make(chan string),
done: make(chan struct{}),
cd: cd,
fsa: fsa,
idsC: make(chan string),
doneC: make(chan struct{}),
errC: make(chan error, 1),
}
go std.worker()
return std, nil
}

func (s *Standard2) worker() {
defer close(s.done)
for id := range s.ids {
defer close(s.errC)
for id := range s.idsC {
if err := stdConvert(s.fsa, s.cd, id); err != nil {
s.errC <- err
dlog.Printf("error converting %q: %v", id, err)
}
}
}

// Close closes the transformer.
func (s *Standard2) Close() error {
close(s.ids)
<-s.done
close(s.idsC)
for err := range s.errC {
if err != nil {
return err
}
}
return nil
}

func (s *Standard2) OnFinalise(ctx context.Context, channelID string) error {
s.ids <- channelID
select {
case err := <-s.errC:
return err
default:
}
select {
case <-ctx.Done():
return ctx.Err()
case s.idsC <- channelID:
// keep going
}
return nil
}

Expand All @@ -271,6 +288,7 @@ func stdConvert(fsa fsadapter.FS, cd *chunk.Directory, chID string) error {
if err != nil {
return err
}
// determine if this a thread.
mm, err := cf.AllMessages(chID)
if err != nil {
return err
Expand Down
82 changes: 39 additions & 43 deletions stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -151,8 +151,8 @@ func (cs *Stream) AsyncConversations(ctx context.Context, proc processor.Convers
defer task.End()

// create channels
chansC := make(chan channelRequest, msgChanSz)
threadsC := make(chan threadRequest, threadChanSz)
chansC := make(chan request, msgChanSz)
threadsC := make(chan request, threadChanSz)

resultsC := make(chan StreamResult, resultSz)

Expand All @@ -164,7 +164,7 @@ func (cs *Stream) AsyncConversations(ctx context.Context, proc processor.Convers
defer wg.Done()
cs.channelWorker(ctx, proc, resultsC, threadsC, chansC)
// we close threads here, instead of the main loop, because we want to
// close it after all the thread workers are done.
// close it after all the threads are sent by channels.
close(threadsC)
trace.Log(ctx, "async", "channel worker done")
}()
Expand Down Expand Up @@ -223,7 +223,7 @@ func (cs *Stream) AsyncConversations(ctx context.Context, proc processor.Convers
}

// processLink parses the link and sends it to the appropriate worker.
func (cs *Stream) processLink(chans chan<- channelRequest, threads chan<- threadRequest, link string) error {
func (cs *Stream) processLink(chans chan<- request, threads chan<- request, link string) error {
sl, err := structures.ParseLink(link)
if err != nil {
return err
Expand All @@ -232,24 +232,18 @@ func (cs *Stream) processLink(chans chan<- channelRequest, threads chan<- thread
return fmt.Errorf("invalid slack link: %s", link)
}
if sl.IsThread() {
threads <- threadRequest{channelID: sl.Channel, threadTS: sl.ThreadTS, needChanInfo: true}
threads <- request{sl: &sl, standalone: true}
} else {
chans <- channelRequest{channelID: sl.Channel}
chans <- request{sl: &sl}
}
return nil
}

type channelRequest struct {
channelID string
}

type threadRequest struct {
channelID string
threadTS string
// needChanInfo indicates whether the channel info is needed for the thread.
// This is true when we're fetching the standalone thread without the
// conversation.
needChanInfo bool
type request struct {
sl *structures.SlackLink
// standalone indicates that this is the thread directly requested by the
// user, and not a thread that was found in the channel.
standalone bool
}

func (we *StreamResult) Error() string {
Expand All @@ -260,7 +254,7 @@ func (we *StreamResult) Unwrap() error {
return we.Err
}

func (cs *Stream) channelWorker(ctx context.Context, proc processor.Conversations, results chan<- StreamResult, threadC chan<- threadRequest, reqs <-chan channelRequest) {
func (cs *Stream) channelWorker(ctx context.Context, proc processor.Conversations, results chan<- StreamResult, threadC chan<- request, reqs <-chan request) {
ctx, task := trace.NewTask(ctx, "channelWorker")
defer task.End()

Expand All @@ -273,21 +267,21 @@ func (cs *Stream) channelWorker(ctx context.Context, proc processor.Conversation
if !more {
return // channel closed
}
channel, err := cs.channelInfo(ctx, proc, req.channelID, false)
channel, err := cs.channelInfo(ctx, proc, req.sl.Channel, false)
if err != nil {
results <- StreamResult{Type: RTChannel, ChannelID: req.channelID, Err: err}
results <- StreamResult{Type: RTChannel, ChannelID: req.sl.Channel, Err: err}
}
last := false
threadCount := 0
if err := cs.channel(ctx, req.channelID, func(mm []slack.Message, isLast bool) error {
if err := cs.channel(ctx, req.sl.Channel, func(mm []slack.Message, isLast bool) error {
last = isLast
n, err := processChannelMessages(ctx, proc, threadC, channel, isLast, mm)
n, err := procChanMsg(ctx, proc, threadC, channel, isLast, mm)
threadCount = n
return err
}); err != nil {
results <- StreamResult{Type: RTChannel, ChannelID: req.channelID, Err: err}
results <- StreamResult{Type: RTChannel, ChannelID: req.sl.Channel, Err: err}
}
results <- StreamResult{Type: RTChannel, ChannelID: req.channelID, ThreadCount: threadCount, IsLast: last}
results <- StreamResult{Type: RTChannel, ChannelID: req.sl.Channel, ThreadCount: threadCount, IsLast: last}
}
}
}
Expand Down Expand Up @@ -339,7 +333,7 @@ func (cs *Stream) channel(ctx context.Context, id string, fn func(mm []slack.Mes
return nil
}

func (cs *Stream) threadWorker(ctx context.Context, proc processor.Conversations, results chan<- StreamResult, reqs <-chan threadRequest) {
func (cs *Stream) threadWorker(ctx context.Context, proc processor.Conversations, results chan<- StreamResult, threadReq <-chan request) {
ctx, task := trace.NewTask(ctx, "threadWorker")
defer task.End()

Expand All @@ -348,36 +342,38 @@ func (cs *Stream) threadWorker(ctx context.Context, proc processor.Conversations
case <-ctx.Done():
results <- StreamResult{Type: RTThread, Err: ctx.Err()}
return
case req, more := <-reqs:
case req, more := <-threadReq:
if !more {
return // channel closed
}
var channel = new(slack.Channel)
if req.needChanInfo {
if _, err := cs.channelInfo(ctx, proc, req.channelID, true); err != nil {
results <- StreamResult{Type: RTThread, ChannelID: req.channelID, ThreadTS: req.threadTS, Err: err}
if req.standalone {
var err error
if channel, err = cs.channelInfo(ctx, proc, req.sl.Channel, true); err != nil {
results <- StreamResult{Type: RTThread, ChannelID: req.sl.Channel, ThreadTS: req.sl.ThreadTS, Err: err}
continue
}
} else {
channel.ID = req.channelID
channel.ID = req.sl.Channel
}
var last bool
if err := cs.thread(ctx, req.channelID, req.threadTS, func(msgs []slack.Message, isLast bool) error {
if err := cs.thread(ctx, req.sl, func(msgs []slack.Message, isLast bool) error {
last = isLast
return processThreadMessages(ctx, proc, channel, req.threadTS, isLast, msgs)
return procThreadMsg(ctx, proc, channel, req.sl.ThreadTS, isLast, msgs)
}); err != nil {
results <- StreamResult{Type: RTThread, ChannelID: req.channelID, ThreadTS: req.threadTS, Err: err}
results <- StreamResult{Type: RTThread, ChannelID: req.sl.Channel, ThreadTS: req.sl.ThreadTS, Err: err}
}
results <- StreamResult{Type: RTThread, ChannelID: req.channelID, ThreadTS: req.threadTS, IsLast: last}
results <- StreamResult{Type: RTThread, ChannelID: req.sl.Channel, ThreadTS: req.sl.ThreadTS, IsLast: last}
}
}
}

func (cs *Stream) thread(ctx context.Context, id string, threadTS string, fn func(mm []slack.Message, isLast bool) error) error {
func (cs *Stream) thread(ctx context.Context, sl *structures.SlackLink, fn func(mm []slack.Message, isLast bool) error) error {
ctx, task := trace.NewTask(ctx, "thread")
defer task.End()

lg := logger.FromContext(ctx)
lg.Debugf("- getting: thread: id=%s, thread_ts=%s", id, threadTS)
lg.Debugf("- getting: %s", sl)

var cursor string
for {
Expand All @@ -388,8 +384,8 @@ func (cs *Stream) thread(ctx context.Context, id string, threadTS string, fn fun
if err := network.WithRetry(ctx, cs.limits.threads, cs.limits.tier.Tier3.Retries, func() error {
var apiErr error
msgs, hasmore, cursor, apiErr = cs.client.GetConversationRepliesContext(ctx, &slack.GetConversationRepliesParameters{
ChannelID: id,
Timestamp: threadTS,
ChannelID: sl.Channel,
Timestamp: sl.ThreadTS,
Cursor: cursor,
Limit: cs.limits.tier.Request.Replies,
Oldest: structures.FormatSlackTS(cs.oldest),
Expand Down Expand Up @@ -420,13 +416,13 @@ func (cs *Stream) thread(ctx context.Context, id string, threadTS string, fn fun
return nil
}

// processChannelMessages processes the messages in the channel and sends
// procChanMsg processes the messages in the channel and sends
// thread requests for the threads in the channel, if it discovers messages
// with threads. It returns thread count in the mm and error if any.
func processChannelMessages(ctx context.Context, proc processor.Conversations, threadC chan<- threadRequest, channel *slack.Channel, isLast bool, mm []slack.Message) (int, error) {
func procChanMsg(ctx context.Context, proc processor.Conversations, threadC chan<- request, channel *slack.Channel, isLast bool, mm []slack.Message) (int, error) {
lg := logger.FromContext(ctx)

var trs = make([]threadRequest, 0, len(mm))
var trs = make([]request, 0, len(mm))
for i := range mm {
// collecting threads to get their count. But we don't start
// processing them yet, before we send the messages with the number of
Expand All @@ -435,7 +431,7 @@ func processChannelMessages(ctx context.Context, proc processor.Conversations, t
// count, if it needs it.
if mm[i].Msg.ThreadTimestamp != "" && mm[i].Msg.SubType != "thread_broadcast" && mm[i].LatestReply != structures.NoRepliesLatestReply {
lg.Debugf("- message #%d/channel=%s,thread: id=%s, thread_ts=%s", i, channel.ID, mm[i].Timestamp, mm[i].Msg.ThreadTimestamp)
trs = append(trs, threadRequest{channelID: channel.ID, threadTS: mm[i].Msg.ThreadTimestamp})
trs = append(trs, request{sl: &structures.SlackLink{Channel: channel.ID, ThreadTS: mm[i].Msg.ThreadTimestamp}})
}
if len(mm[i].Files) > 0 {
if err := proc.Files(ctx, channel, mm[i], false, mm[i].Files); err != nil {
Expand All @@ -452,7 +448,7 @@ func processChannelMessages(ctx context.Context, proc processor.Conversations, t
return len(trs), nil
}

func processThreadMessages(ctx context.Context, proc processor.Conversations, channel *slack.Channel, threadTS string, isLast bool, msgs []slack.Message) error {
func procThreadMsg(ctx context.Context, proc processor.Conversations, channel *slack.Channel, threadTS string, isLast bool, msgs []slack.Message) error {
// extract files from thread messages
for _, m := range msgs[1:] {
if len(m.Files) > 0 {
Expand Down
Loading

0 comments on commit 547a821

Please sign in to comment.