Fix some tests.
[lightning.git] / filter.go
index 89aa778f8bf5e0b9c03b9e8bfe45fc0b7fc709c3..4c86c1b85b486f6b9bf2e6961ed6ce9606ae7c08 100644 (file)
--- a/filter.go
+++ b/filter.go
@@ -1,4 +1,8 @@
-package main
+// Copyright (C) The Lightning Authors. All rights reserved.
+//
+// SPDX-License-Identifier: AGPL-3.0
+
+package lightning
 
 import (
        "bufio"
@@ -11,16 +15,114 @@ import (
        "net/http"
        _ "net/http/pprof"
        "os"
+       "regexp"
+       "strings"
 
        "git.arvados.org/arvados.git/sdk/go/arvados"
        log "github.com/sirupsen/logrus"
 )
 
-type filterer struct {
+// filter holds the tile library filtering options. Flags registers
+// them as command line flags, Args renders them back as command line
+// arguments, and Apply applies them to a tileLibrary.
+//
+// MaxVariants: tile positions with more than this many variants are
+// zeroed out; a negative value (the flag default is -1) disables the
+// check.
+//
+// MinCoverage: coverage threshold, as a fraction of all haplotypes,
+// below which a tile position is zeroed out (flag default 0).
+//
+// MaxTag: when >= 0, Apply truncates tile data and genomes to the
+// first MaxTag tag positions (flag default -1, meaning no limit).
+// NOTE(review): the flag help says "tag ID > N" is dropped, while
+// Apply keeps only indexes < MaxTag -- confirm which is intended.
+//
+// MatchGenome: regular expression; genomes whose names do not match
+// it are dropped (the empty default matches every name).
+type filter struct {
+       MaxVariants int
+       MinCoverage float64
+       MaxTag      int
+       MatchGenome string
+}
+
+// Flags registers the filter options on the given flag set, binding
+// each flag to the corresponding field of f. The registered defaults
+// are -max-variants=-1, -min-coverage=0, -max-tag=-1, and an empty
+// -match-genome.
+func (f *filter) Flags(flags *flag.FlagSet) {
+       flags.IntVar(&f.MaxVariants, "max-variants", -1, "drop tiles with more than `N` variants")
+       flags.Float64Var(&f.MinCoverage, "min-coverage", 0, "drop tiles with coverage less than `P` across all haplotypes (0 < P ≤ 1)")
+       flags.IntVar(&f.MaxTag, "max-tag", -1, "drop tiles with tag ID > `N`")
+       flags.StringVar(&f.MatchGenome, "match-genome", "", "keep genomes whose names contain `regexp`, drop the rest")
+}
+
+// Args returns the current filter settings as command line arguments
+// in "-name=value" form, using the flag names that Flags registers,
+// so that parsing the result reproduces the same settings.
+func (f *filter) Args() []string {
+       return []string{
+               fmt.Sprintf("-max-variants=%d", f.MaxVariants),
+               // %v yields the shortest decimal string that parses
+               // back to exactly the same float64. %f would round to
+               // six decimal places (e.g., 0.9999999 would be passed
+               // on as 1.000000).
+               fmt.Sprintf("-min-coverage=%v", f.MinCoverage),
+               fmt.Sprintf("-max-tag=%d", f.MaxTag),
+               fmt.Sprintf("-match-genome=%s", f.MatchGenome),
+       }
+}
+
+// Apply filters tilelib in place, in this order:
+//
+// 1. If MaxVariants >= 0, zero the two entries (cg[2*tag] and
+// cg[2*tag+1]) of every compact genome at each tag position that has
+// more than MaxVariants variants.
+//
+// 2. Zero the same entries at each tag position whose count of
+// non-zero entries across all compact genomes does not reach the
+// threshold derived from MinCoverage.
+//
+// 3. If MaxTag >= 0, truncate tilelib.variant and every compact genome
+// to the first MaxTag tag positions.
+//
+// 4. Delete every compact genome whose name does not match the
+// MatchGenome regexp. If MatchGenome is not a valid regexp, an error
+// is logged and all genomes are deleted.
+func (f *filter) Apply(tilelib *tileLibrary) {
+       // Zero out variants at tile positions that have more than
+       // f.MaxVariants tile variants.
+       if f.MaxVariants >= 0 {
+               for tag, variants := range tilelib.variant {
+                       if f.MaxTag >= 0 && tag >= f.MaxTag {
+                               // Positions from MaxTag onward are
+                               // truncated below, so skip them.
+                               break
+                       }
+                       if len(variants) <= f.MaxVariants {
+                               continue
+                       }
+                       for _, cg := range tilelib.compactGenomes {
+                               if len(cg) > tag*2 {
+                                       cg[tag*2] = 0
+                                       cg[tag*2+1] = 0
+                               }
+                       }
+               }
+       }
+
+       // Zero out variants at tile positions that have less than
+       // f.MinCoverage.
+       //
+       // NOTE(review): with MinCoverage == 1 this threshold is 2N+1
+       // for N genomes, which exceeds the largest possible count
+       // (2N), so every position would be zeroed -- confirm that the
+       // "+ 1" is intended.
+       mincov := int(2*f.MinCoverage*float64(len(tilelib.compactGenomes)) + 1)
+TAG:
+       for tag := 0; tag < len(tilelib.variant) && (tag < f.MaxTag || f.MaxTag < 0); tag++ {
+               tagcov := 0
+               for _, cg := range tilelib.compactGenomes {
+                       if len(cg) < tag*2+2 {
+                               continue
+                       }
+                       if cg[tag*2] > 0 {
+                               tagcov++
+                       }
+                       if cg[tag*2+1] > 0 {
+                               tagcov++
+                       }
+                       if tagcov >= mincov {
+                               // Enough coverage: keep this
+                               // position as is.
+                               continue TAG
+                       }
+               }
+               for _, cg := range tilelib.compactGenomes {
+                       if len(cg) > tag*2 {
+                               cg[tag*2] = 0
+                               cg[tag*2+1] = 0
+                       }
+               }
+       }
+
+       // Truncate genomes and tile data to f.MaxTag (TODO: truncate
+       // refseqs too)
+       if f.MaxTag >= 0 {
+               if len(tilelib.variant) > f.MaxTag {
+                       tilelib.variant = tilelib.variant[:f.MaxTag]
+               }
+               for name, cg := range tilelib.compactGenomes {
+                       if len(cg) > 2*f.MaxTag {
+                               tilelib.compactGenomes[name] = cg[:2*f.MaxTag]
+                       }
+               }
+       }
+
+       // Drop genomes whose names do not match f.MatchGenome.
+       re, err := regexp.Compile(f.MatchGenome)
+       if err != nil {
+               log.Errorf("invalid regexp %q does not match anything, dropping all genomes", f.MatchGenome)
+       }
+       for name := range tilelib.compactGenomes {
+               // regexp.Compile returns a nil *Regexp on error, and
+               // calling MatchString on it would panic. Treat an
+               // invalid regexp as matching nothing, as logged above.
+               if re == nil || !re.MatchString(name) {
+                       delete(tilelib.compactGenomes, name)
+               }
+       }
+}
+
+// filtercmd is the receiver type of the filter command's RunCommand
+// method. It embeds filter, so the filter options (MaxVariants,
+// MinCoverage, MaxTag, MatchGenome) are available as fields of the
+// command; output is set to the command's stdout by RunCommand.
+type filtercmd struct {
        output io.Writer
+       filter
 }
 
-func (cmd *filterer) RunCommand(prog string, args []string, stdin io.Reader, stdout, stderr io.Writer) int {
+func (cmd *filtercmd) RunCommand(prog string, args []string, stdin io.Reader, stdout, stderr io.Writer) int {
        var err error
        defer func() {
                if err != nil {
@@ -35,15 +137,16 @@ func (cmd *filterer) RunCommand(prog string, args []string, stdin io.Reader, std
        priority := flags.Int("priority", 500, "container request priority")
        inputFilename := flags.String("i", "-", "input `file`")
        outputFilename := flags.String("o", "-", "output `file`")
-       maxvariants := flags.Int("max-variants", -1, "drop tiles with more than `N` variants")
-       mincoverage := flags.Float64("min-coverage", 0, "drop tiles with coverage less than `P` across all haplotypes (0 < P ≤ 1)")
-       maxtag := flags.Int("max-tag", -1, "drop tiles with tag ID > `N`")
+       cmd.filter.Flags(flags)
        err = flags.Parse(args)
        if err == flag.ErrHelp {
                err = nil
                return 0
        } else if err != nil {
                return 2
+       } else if flags.NArg() > 0 {
+               err = fmt.Errorf("errant command line arguments after parsed flags: %v", flags.Args())
+               return 2
        }
        cmd.output = stdout
 
@@ -73,9 +176,9 @@ func (cmd *filterer) RunCommand(prog string, args []string, stdin io.Reader, std
                runner.Args = []string{"filter", "-local=true",
                        "-i", *inputFilename,
                        "-o", "/mnt/output/library.gob",
-                       "-max-variants", fmt.Sprintf("%d", *maxvariants),
-                       "-min-coverage", fmt.Sprintf("%f", *mincoverage),
-                       "-max-tag", fmt.Sprintf("%d", *maxtag),
+                       "-max-variants", fmt.Sprintf("%d", cmd.MaxVariants),
+                       "-min-coverage", fmt.Sprintf("%f", cmd.MinCoverage),
+                       "-max-tag", fmt.Sprintf("%d", cmd.MaxTag),
                }
                var output string
                output, err = runner.Run()
@@ -97,7 +200,7 @@ func (cmd *filterer) RunCommand(prog string, args []string, stdin io.Reader, std
                defer infile.Close()
        }
        log.Print("reading")
-       cgs, err := ReadCompactGenomes(infile)
+       cgs, err := ReadCompactGenomes(infile, strings.HasSuffix(*inputFilename, ".gz"))
        if err != nil {
                return 1
        }
@@ -113,12 +216,11 @@ func (cmd *filterer) RunCommand(prog string, args []string, stdin io.Reader, std
                if ntags < len(cg.Variants)/2 {
                        ntags = len(cg.Variants) / 2
                }
-               if *maxvariants < 0 {
+               if cmd.MaxVariants < 0 {
                        continue
                }
-               maxVariantID := tileVariantID(*maxvariants)
                for idx, variant := range cg.Variants {
-                       if variant > maxVariantID {
+                       if variant > tileVariantID(cmd.MaxVariants) {
                                for _, cg := range cgs {
                                        if len(cg.Variants) > idx {
                                                cg.Variants[idx & ^1] = 0
@@ -129,17 +231,17 @@ func (cmd *filterer) RunCommand(prog string, args []string, stdin io.Reader, std
                }
        }
 
-       if *maxtag >= 0 && ntags > *maxtag {
-               ntags = *maxtag
+       if cmd.MaxTag >= 0 && ntags > cmd.MaxTag {
+               ntags = cmd.MaxTag
                for i, cg := range cgs {
-                       if len(cg.Variants) > *maxtag*2 {
-                               cgs[i].Variants = cg.Variants[:*maxtag*2]
+                       if len(cg.Variants) > cmd.MaxTag*2 {
+                               cgs[i].Variants = cg.Variants[:cmd.MaxTag*2]
                        }
                }
        }
 
-       if *mincoverage > 0 {
-               mincov := int(*mincoverage * float64(len(cgs)*2))
+       if cmd.MinCoverage > 0 {
+               mincov := int(cmd.MinCoverage * float64(len(cgs)*2))
                cov := make([]int, ntags)
                for _, cg := range cgs {
                        for idx, variant := range cg.Variants {