14538: Move cancel-on-first-error logic to contextGroup.
authorTom Clegg <tclegg@veritasgenetics.com>
Tue, 27 Nov 2018 21:44:45 +0000 (16:44 -0500)
committerTom Clegg <tclegg@veritasgenetics.com>
Wed, 28 Nov 2018 20:43:07 +0000 (15:43 -0500)
Arvados-DCO-1.1-Signed-off-by: Tom Clegg <tclegg@veritasgenetics.com>

sdk/go/arvados/contextgroup.go [new file with mode: 0644]
sdk/go/arvados/fs_collection.go

diff --git a/sdk/go/arvados/contextgroup.go b/sdk/go/arvados/contextgroup.go
new file mode 100644 (file)
index 0000000..fa0de24
--- /dev/null
@@ -0,0 +1,95 @@
+package arvados
+
+import (
+       "context"
+       "sync"
+)
+
+// A contextGroup is a context-aware variation on sync.WaitGroup. It
+// provides a child context for the added funcs to use, so they can
+// exit early if another added func returns an error. Its Wait()
+// method returns the first error returned by any added func.
+//
+// Example:
+//
+//     err := errors.New("oops")
+//     cg := newContextGroup()
+//     defer cg.Cancel()
+//     cg.Go(func() error {
+//             someFuncWithContext(cg.Context())
+//             return nil
+//     })
+//     cg.Go(func() error {
+//             return err // this cancels cg.Context()
+//     })
+//     return cg.Wait() // returns err after both goroutines have ended
+type contextGroup struct {
+       ctx    context.Context
+       cancel context.CancelFunc
+       wg     sync.WaitGroup
+       err    error
+       mtx    sync.Mutex
+}
+
+// newContextGroup returns a new contextGroup. The caller must
+// eventually call the Cancel() method of the returned contextGroup.
+func newContextGroup(ctx context.Context) *contextGroup {
+       ctx, cancel := context.WithCancel(ctx)
+       return &contextGroup{
+               ctx:    ctx,
+               cancel: cancel,
+       }
+}
+
+// Cancel cancels the context group.
+func (cg *contextGroup) Cancel() {
+       cg.cancel()
+}
+
+// Context returns a context.Context which will be canceled when all
+// funcs have succeeded or one has failed.
+func (cg *contextGroup) Context() context.Context {
+       return cg.ctx
+}
+
+// Go calls f in a new goroutine. If f returns an error, the
+// contextGroup is canceled.
+//
+// If f notices cg.Context() is done, it should abandon further work
+// and return. In this case, f's return value will be ignored.
+func (cg *contextGroup) Go(f func() error) {
+       cg.mtx.Lock()
+       defer cg.mtx.Unlock()
+       if cg.err != nil {
+               return
+       }
+       cg.wg.Add(1)
+       go func() {
+               defer cg.wg.Done()
+               err := f()
+               cg.mtx.Lock()
+               defer cg.mtx.Unlock()
+               if err != nil && cg.err == nil {
+                       cg.err = err
+                       cg.cancel()
+               }
+       }()
+}
+
+// Wait waits for all added funcs to return, and returns the first
+// non-nil error.
+//
+// If the parent context is canceled before a func returns an error,
+// Wait returns the parent context's Err().
+//
+// Wait returns nil if all funcs return nil before the parent context
+// is canceled.
+func (cg *contextGroup) Wait() error {
+       cg.wg.Wait()
+       cg.mtx.Lock()
+       defer cg.mtx.Unlock()
+       if cg.err != nil {
+               return cg.err
+       }
+       return cg.ctx.Err()
+}
index 0a7f408f8f2bbff4e720f5cb5c4651b04fb1af36..58482142fcd1b21185aee96cfb624fc58db1c60a 100644 (file)
@@ -554,29 +554,22 @@ func (dn *dirnode) Child(name string, replace func(inode) (inode, error)) (inode
 // local persistent storage. Caller must have write lock on dn and the
 // named children.
 func (dn *dirnode) sync(ctx context.Context, names []string, throttle *throttle) error {
-       ctx, cancel := context.WithCancel(ctx)
-       defer cancel()
+       cg := newContextGroup(ctx)
+       defer cg.Cancel()
 
        type shortBlock struct {
                fn  *filenode
                idx int
        }
-       var pending []shortBlock
-       var pendingLen int
 
-       errors := make(chan error, 1)
-       var wg sync.WaitGroup
-       defer wg.Wait() // we have locks: unsafe to return until all goroutines finish
-
-       flush := func(sbs []shortBlock) {
-               defer wg.Done()
+       flush := func(sbs []shortBlock) error {
                if len(sbs) == 0 {
-                       return
+                       return nil
                }
                throttle.Acquire()
                defer throttle.Release()
-               if ctx.Err() != nil {
-                       return
+               if err := cg.Context().Err(); err != nil {
+                       return err
                }
                block := make([]byte, 0, maxBlockSize)
                for _, sb := range sbs {
@@ -584,11 +577,7 @@ func (dn *dirnode) sync(ctx context.Context, names []string, throttle *throttle)
                }
                locator, _, err := dn.fs.PutB(block)
                if err != nil {
-                       select {
-                       case errors <- err:
-                       default:
-                       }
-                       cancel()
+                       return err
                }
                off := 0
                for _, sb := range sbs {
@@ -603,8 +592,17 @@ func (dn *dirnode) sync(ctx context.Context, names []string, throttle *throttle)
                        off += len(data)
                        sb.fn.memsize -= int64(len(data))
                }
+               return nil
+       }
+
+       goFlush := func(sbs []shortBlock) {
+               cg.Go(func() error {
+                       return flush(sbs)
+               })
        }
 
+       var pending []shortBlock
+       var pendingLen int
        localLocator := map[string]string{}
        for _, name := range names {
                fn, ok := dn.inodes[name].(*filenode)
@@ -627,13 +625,11 @@ func (dn *dirnode) sync(ctx context.Context, names []string, throttle *throttle)
                                fn.segments[idx] = seg
                        case *memSegment:
                                if seg.Len() > maxBlockSize/2 {
-                                       wg.Add(1)
-                                       go flush([]shortBlock{{fn, idx}})
+                                       goFlush([]shortBlock{{fn, idx}})
                                        continue
                                }
                                if pendingLen+seg.Len() > maxBlockSize {
-                                       wg.Add(1)
-                                       go flush(pending)
+                                       goFlush(pending)
                                        pending = nil
                                        pendingLen = 0
                                }
@@ -644,19 +640,14 @@ func (dn *dirnode) sync(ctx context.Context, names []string, throttle *throttle)
                        }
                }
        }
-       wg.Add(1)
-       flush(pending)
-       go func() {
-               wg.Wait()
-               close(errors)
-       }()
-       return <-errors
+       goFlush(pending)
+       return cg.Wait()
 }
 
 // caller must have write lock.
 func (dn *dirnode) marshalManifest(ctx context.Context, prefix string, throttle *throttle) (string, error) {
-       ctx, cancel := context.WithCancel(ctx)
-       defer cancel()
+       cg := newContextGroup(ctx)
+       defer cg.Cancel()
 
        if len(dn.inodes) == 0 {
                if prefix == "." {
@@ -690,27 +681,18 @@ func (dn *dirnode) marshalManifest(ctx context.Context, prefix string, throttle
                }
        }
 
-       var wg sync.WaitGroup
-       errors := make(chan error, len(dirnames)+1)
        subdirs := make([]string, len(dirnames))
        rootdir := ""
        for i, name := range dirnames {
-               wg.Add(1)
-               go func(i int, name string) {
-                       defer wg.Done()
-                       var err error
-                       subdirs[i], err = dn.inodes[name].(*dirnode).marshalManifest(ctx, prefix+"/"+name, throttle)
-                       if err != nil {
-                               errors <- err
-                               cancel()
-                       }
-               }(i, name)
+               i, name := i, name
+               cg.Go(func() error {
+                       txt, err := dn.inodes[name].(*dirnode).marshalManifest(cg.Context(), prefix+"/"+name, throttle)
+                       subdirs[i] = txt
+                       return err
+               })
        }
 
-       wg.Add(1)
-       go func() {
-               defer wg.Done()
-
+       cg.Go(func() error {
                var streamLen int64
                type filepart struct {
                        name   string
@@ -720,10 +702,8 @@ func (dn *dirnode) marshalManifest(ctx context.Context, prefix string, throttle
 
                var fileparts []filepart
                var blocks []string
-               if err := dn.sync(ctx, names, throttle); err != nil {
-                       errors <- err
-                       cancel()
-                       return
+               if err := dn.sync(cg.Context(), names, throttle); err != nil {
+                       return err
                }
                for _, name := range filenames {
                        node := dn.inodes[name].(*filenode)
@@ -765,20 +745,15 @@ func (dn *dirnode) marshalManifest(ctx context.Context, prefix string, throttle
                        filetokens = append(filetokens, fmt.Sprintf("%d:%d:%s", s.offset, s.length, manifestEscape(s.name)))
                }
                if len(filetokens) == 0 {
-                       return
+                       return nil
                } else if len(blocks) == 0 {
                        blocks = []string{"d41d8cd98f00b204e9800998ecf8427e+0"}
                }
                rootdir = manifestEscape(prefix) + " " + strings.Join(blocks, " ") + " " + strings.Join(filetokens, " ") + "\n"
-       }()
-
-       wg.Wait()
-       select {
-       case err := <-errors:
-               return "", err
-       default:
-       }
-       return rootdir + strings.Join(subdirs, ""), nil
+               return nil
+       })
+       err := cg.Wait()
+       return rootdir + strings.Join(subdirs, ""), err
 }
 
 func (dn *dirnode) loadManifest(txt string) error {