17574: Batch updates into transactions, skip when unchanged.
[arvados.git] / services / keep-balance / collection.go
index d67f2f0f090b173e16aa9e2a8b1b9094abc668fb..daedeb8bfcb82b94fbf3007fd1f1a13fe19df432 100644 (file)
@@ -1,17 +1,27 @@
+// Copyright (C) The Arvados Authors. All rights reserved.
+//
+// SPDX-License-Identifier: AGPL-3.0
+
 package main
 
 import (
+       "context"
+       "encoding/json"
        "fmt"
+       "runtime"
+       "sync"
+       "sync/atomic"
        "time"
 
-       "git.curoverse.com/arvados.git/sdk/go/arvados"
+       "git.arvados.org/arvados.git/sdk/go/arvados"
+       "github.com/jmoiron/sqlx"
 )
 
 func countCollections(c *arvados.Client, params arvados.ResourceListParams) (int, error) {
        var page arvados.CollectionList
        var zero int
        params.Limit = &zero
-       params.Count = true
+       params.Count = "exact"
        err := c.RequestAndDecode(&page, "GET", "arvados/v1/collections", nil, params)
        return page.ItemsAvailable, err
 }
@@ -23,82 +33,229 @@ func countCollections(c *arvados.Client, params arvados.ResourceListParams) (int
 // The progress function is called periodically with done (number of
 // times f has been called) and total (number of times f is expected
 // to be called).
-//
-// If pageSize > 0 it is used as the maximum page size in each API
-// call; otherwise the maximum allowed page size is requested.
-func EachCollection(c *arvados.Client, pageSize int, f func(arvados.Collection) error, progress func(done, total int)) error {
+func EachCollection(ctx context.Context, db *sqlx.DB, c *arvados.Client, f func(arvados.Collection) error, progress func(done, total int)) error {
        if progress == nil {
                progress = func(_, _ int) {}
        }
 
-       expectCount, err := countCollections(c, arvados.ResourceListParams{})
+       expectCount, err := countCollections(c, arvados.ResourceListParams{
+               IncludeTrash:       true,
+               IncludeOldVersions: true,
+       })
        if err != nil {
                return err
        }
+       var newestModifiedAt time.Time
 
-       limit := pageSize
-       if limit <= 0 {
-               // Use the maximum page size the server allows
-               limit = 1<<31 - 1
-       }
-       params := arvados.ResourceListParams{
-               Limit:  &limit,
-               Order:  "modified_at, uuid",
-               Count: false,
-               Select: []string{"uuid", "manifest_text", "modified_at", "portable_data_hash", "replication_desired"},
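+       // Stream every collection straight from the database in a
+       // single query, instead of paging through the API.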
+       rows, err := db.QueryxContext(ctx, `SELECT
+               uuid, manifest_text, modified_at, portable_data_hash,
+               replication_desired, replication_confirmed, replication_confirmed_at,
+               storage_classes_desired, storage_classes_confirmed, storage_classes_confirmed_at,
+               is_trashed
+               FROM collections`)
+       if err != nil {
+               return err
        }
-       var last arvados.Collection
-       var filterTime time.Time
+       defer rows.Close()
+       progressTicker := time.NewTicker(10 * time.Second)
+       defer progressTicker.Stop()
        callCount := 0
-       for {
-               progress(callCount, expectCount)
-               var page arvados.CollectionList
-               err := c.RequestAndDecode(&page, "GET", "arvados/v1/collections", nil, params)
+       for rows.Next() {
+               var coll arvados.Collection
+               var classesDesired, classesConfirmed []byte
+               err = rows.Scan(&coll.UUID, &coll.ManifestText, &coll.ModifiedAt, &coll.PortableDataHash,
+                       &coll.ReplicationDesired, &coll.ReplicationConfirmed, &coll.ReplicationConfirmedAt,
+                       &classesDesired, &classesConfirmed, &coll.StorageClassesConfirmedAt,
+                       &coll.IsTrashed)
                if err != nil {
                        return err
                }
-               for _, coll := range page.Items {
-                       if last.ModifiedAt != nil && *last.ModifiedAt == *coll.ModifiedAt && last.UUID >= coll.UUID {
-                               continue
-                       }
-                       callCount++
-                       err = f(coll)
-                       if err != nil {
-                               return err
-                       }
-                       last = coll
+
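+               // The storage_classes_* columns hold JSON arrays. A NULL
+               // column scans as zero-length bytes, in which case the
+               // json.Unmarshal error is ignored.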
+               err = json.Unmarshal(classesDesired, &coll.StorageClassesDesired)
+               if err != nil && len(classesDesired) > 0 {
+                       return err
                }
-               if last.ModifiedAt == nil || *last.ModifiedAt == filterTime {
-                       if page.ItemsAvailable > len(page.Items) {
-                               // TODO: use "mtime=X && UUID>Y"
-                               // filters to get all collections with
-                               // this timestamp, then use "mtime>X"
-                               // to get the next timestamp.
-                               return fmt.Errorf("BUG: Received an entire page with the same modified_at timestamp (%v), cannot make progress", filterTime)
-                       }
-                       break
+               err = json.Unmarshal(classesConfirmed, &coll.StorageClassesConfirmed)
+               if err != nil && len(classesConfirmed) > 0 {
+                       return err
+               }
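+               // Track the newest modified_at seen; it is used below to
+               // cross-check the row count against the API.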
+               if newestModifiedAt.IsZero() || newestModifiedAt.Before(coll.ModifiedAt) {
+                       newestModifiedAt = coll.ModifiedAt
+               }
+               callCount++
+               err = f(coll)
+               if err != nil {
+                       return err
+               }
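+               // Report progress at most once per tick, without
+               // blocking while waiting for the next tick.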
+               select {
+               case <-progressTicker.C:
+                       progress(callCount, expectCount)
+               default:
                }
-               filterTime = *last.ModifiedAt
-               params.Filters = []arvados.Filter{{
-                       Attr:     "modified_at",
-                       Operator: ">=",
-                       Operand:  filterTime,
-               }, {
-                       Attr:     "uuid",
-                       Operator: "!=",
-                       Operand:  last.UUID,
-               }}
        }
        progress(callCount, expectCount)
-
-       if checkCount, err := countCollections(c, arvados.ResourceListParams{Filters: []arvados.Filter{{
-               Attr:     "modified_at",
-               Operator: "<=",
-               Operand:  filterTime}}}); err != nil {
+       err = rows.Err()
+       if err != nil {
+               return err
+       }
+       err = rows.Close()
+       if err != nil {
+               return err
+       }
+       if checkCount, err := countCollections(c, arvados.ResourceListParams{
+               Filters: []arvados.Filter{{
+                       Attr:     "modified_at",
+                       Operator: "<=",
+                       Operand:  newestModifiedAt}},
+               IncludeTrash:       true,
+               IncludeOldVersions: true,
+       }); err != nil {
                return err
        } else if callCount < checkCount {
-               return fmt.Errorf("Retrieved %d collections with modtime <= T=%q, but server now reports there are %d collections with modtime <= T", callCount, filterTime, checkCount)
+               return fmt.Errorf("Retrieved %d collections with modtime <= T=%q, but server now reports there are %d collections with modtime <= T", callCount, newestModifiedAt, checkCount)
        }
 
        return nil
 }
+
+func (bal *Balancer) updateCollections(ctx context.Context, c *arvados.Client, cluster *arvados.Cluster) error {
+       ctx, cancel := context.WithCancel(ctx)
+       defer cancel()
+
+       defer bal.time("update_collections", "wall clock time to update collections")()
+       threshold := time.Now()
+       thresholdStr := threshold.Format(time.RFC3339Nano)
+
+       updated := int64(0)
+
+       errs := make(chan error, 1)
+       collQ := make(chan arvados.Collection, cluster.Collections.BalanceCollectionBuffers)
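+       // Producer: one goroutine streams collections from the database
+       // into collQ; the worker pool below applies the updates in
+       // batched transactions.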
+       go func() {
+               defer close(collQ)
+               err := EachCollection(ctx, bal.DB, c, func(coll arvados.Collection) error {
+                       if atomic.LoadInt64(&updated) >= int64(cluster.Collections.BalanceUpdateLimit) {
+                               bal.logf("reached BalanceUpdateLimit (%d)", cluster.Collections.BalanceUpdateLimit)
+                               cancel()
+                               return context.Canceled
+                       }
+                       collQ <- coll
+                       return nil
+               }, func(done, total int) {
+                       bal.logf("update collections: %d/%d (%d updated @ %.01f updates/s)", done, total, atomic.LoadInt64(&updated), float64(atomic.LoadInt64(&updated))/time.Since(threshold).Seconds())
+               })
+               if err != nil && err != context.Canceled {
+                       select {
+                       case errs <- err:
+                       default:
+                       }
+               }
+       }()
+
+       var wg sync.WaitGroup
+
+       // Use about 1 goroutine per 2 CPUs. Based on experiments with
+       // a 2-core host, using more concurrent database
+       // calls/transactions makes this process slower, not faster.
+       for i := 0; i < (runtime.NumCPU()+1)/2; i++ {
+               wg.Add(1)
+               goSendErr(errs, func() error {
+                       defer wg.Done()
+                       tx, err := bal.DB.Beginx()
+                       if err != nil {
+                               return err
+                       }
+                       txPending := 0
+                       flush := func(final bool) error {
+                               err := tx.Commit()
+                               if err != nil {
+                                       tx.Rollback()
+                                       return err
+                               }
+                               txPending = 0
+                               if final {
+                                       return nil
+                               }
+                               tx, err = bal.DB.Beginx()
+                               return err
+                       }
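+                       // Commit once per txBatch updates rather than
+                       // once per row.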
+                       txBatch := 100
+                       for coll := range collQ {
+                               if ctx.Err() != nil || len(errs) > 0 {
+                                       continue
+                               }
+                               blkids, err := coll.SizedDigests()
+                               if err != nil {
+                                       bal.logf("%s: %s", coll.UUID, err)
+                                       continue
+                               }
+                               repl := bal.BlockStateMap.GetConfirmedReplication(blkids, coll.StorageClassesDesired)
+
+                               desired := bal.DefaultReplication
+                               if coll.ReplicationDesired != nil {
+                                       desired = *coll.ReplicationDesired
+                               }
+                               if repl > desired {
+                                       // If actual>desired, confirm
+                                       // the desired number rather
+                                       // than actual to avoid
+                                       // flapping updates when
+                                       // replication increases
+                                       // temporarily.
+                                       repl = desired
+                               }
+                               classes, err := json.Marshal(coll.StorageClassesDesired)
+                               if err != nil {
+                                       bal.logf("BUG? json.Marshal(%v) failed: %s", classes, err)
+                                       continue
+                               }
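+                               // Skip the write entirely when the
+                               // confirmed values already match what
+                               // would be stored.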
+                               needUpdate := coll.ReplicationConfirmed == nil || *coll.ReplicationConfirmed != repl || len(coll.StorageClassesConfirmed) != len(coll.StorageClassesDesired)
+                               for i := range coll.StorageClassesDesired {
+                                       if !needUpdate && coll.StorageClassesDesired[i] != coll.StorageClassesConfirmed[i] {
+                                               needUpdate = true
+                                       }
+                               }
+                               if !needUpdate {
+                                       continue
+                               }
+                               _, err = tx.ExecContext(ctx, `update collections set
+                                       replication_confirmed=$1,
+                                       replication_confirmed_at=$2,
+                                       storage_classes_confirmed=$3,
+                                       storage_classes_confirmed_at=$2
+                                       where uuid=$4`,
+                                       repl, thresholdStr, classes, coll.UUID)
+                               if err != nil {
+                                       if err != context.Canceled {
+                                               bal.logf("%s: update failed: %s", coll.UUID, err)
+                                       }
+                                       continue
+                               }
+                               atomic.AddInt64(&updated, 1)
+                               if txPending++; txPending >= txBatch {
+                                       err = flush(false)
+                                       if err != nil {
+                                               return err
+                                       }
+                               }
+                       }
+                       return flush(true)
+               })
+       }
+       wg.Wait()
+       bal.logf("updated %d collections", updated)
+       if len(errs) > 0 {
+               return fmt.Errorf("error updating collections: %s", <-errs)
+       }
+       return nil
+}
+
+// Call f in a new goroutine. If it returns a non-nil error, send the
+// error to the errs channel (unless the channel is already full with
+// another error).
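+//
+// Usage sketch (doWork is a placeholder):
+//
+//     goSendErr(errs, func() error { return doWork() })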
+func goSendErr(errs chan<- error, f func() error) {
+       go func() {
+               err := f()
+               if err != nil {
+                       select {
+                       case errs <- err:
+                       default:
+                       }
+               }
+       }()
+}