19566: Record number of p-value calculations performed.
[lightning.git] / slicenumpy.go
1 // Copyright (C) The Lightning Authors. All rights reserved.
2 //
3 // SPDX-License-Identifier: AGPL-3.0
4
5 package lightning
6
7 import (
8         "bufio"
9         "bytes"
10         "encoding/gob"
11         "encoding/json"
12         "errors"
13         "flag"
14         "fmt"
15         "io"
16         "io/ioutil"
17         "math"
18         "net/http"
19         _ "net/http/pprof"
20         "os"
21         "regexp"
22         "runtime"
23         "runtime/debug"
24         "sort"
25         "strconv"
26         "strings"
27         "sync/atomic"
28         "unsafe"
29
30         "git.arvados.org/arvados.git/sdk/go/arvados"
31         "github.com/arvados/lightning/hgvs"
32         "github.com/james-bowman/nlp"
33         "github.com/kshedden/gonpy"
34         "github.com/sirupsen/logrus"
35         log "github.com/sirupsen/logrus"
36         "golang.org/x/crypto/blake2b"
37         "gonum.org/v1/gonum/mat"
38 )
39
// annotationMaxTileSpan is an untyped tuning constant (100). It is not
// referenced in this part of the file; judging by its name it bounds how
// many tiles a single annotation may span. NOTE(review): confirm against
// its call sites.
const (
	annotationMaxTileSpan = 100
)
41
42 type sliceNumpy struct {
43         filter          filter
44         threads         int
45         chi2Cases       []bool
46         chi2PValue      float64
47         pcaComponents   int
48         minCoverage     int
49         includeVariant1 bool
50         debugTag        tagID
51
52         cgnames         []string
53         samples         []sampleInfo
54         trainingSet     []int // samples index => training set index, or -1 if not in training set
55         trainingSetSize int
56         pvalue          func(onehot []bool) float64
57         pvalueCallCount int64
58 }
59
60 func (cmd *sliceNumpy) RunCommand(prog string, args []string, stdin io.Reader, stdout, stderr io.Writer) int {
61         err := cmd.run(prog, args, stdin, stdout, stderr)
62         if err != nil {
63                 fmt.Fprintf(stderr, "%s\n", err)
64                 return 1
65         }
66         return 0
67 }
68
69 func (cmd *sliceNumpy) run(prog string, args []string, stdin io.Reader, stdout, stderr io.Writer) error {
70         flags := flag.NewFlagSet("", flag.ContinueOnError)
71         flags.SetOutput(stderr)
72         pprof := flags.String("pprof", "", "serve Go profile data at http://`[addr]:port`")
73         runlocal := flags.Bool("local", false, "run on local host (default: run in an arvados container)")
74         arvadosRAM := flags.Int("arvados-ram", 750000000000, "amount of memory to request for arvados container (`bytes`)")
75         arvadosVCPUs := flags.Int("arvados-vcpus", 96, "number of VCPUs to request for arvados container")
76         projectUUID := flags.String("project", "", "project `UUID` for output data")
77         priority := flags.Int("priority", 500, "container request priority")
78         inputDir := flags.String("input-dir", "./in", "input `directory`")
79         outputDir := flags.String("output-dir", "./out", "output `directory`")
80         ref := flags.String("ref", "", "reference name (if blank, choose last one that appears in input)")
81         regionsFilename := flags.String("regions", "", "only output columns/annotations that intersect regions in specified bed `file`")
82         expandRegions := flags.Int("expand-regions", 0, "expand specified regions by `N` base pairs on each side`")
83         mergeOutput := flags.Bool("merge-output", false, "merge output into one matrix.npy and one matrix.annotations.csv")
84         hgvsSingle := flags.Bool("single-hgvs-matrix", false, "also generate hgvs-based matrix")
85         hgvsChunked := flags.Bool("chunked-hgvs-matrix", false, "also generate hgvs-based matrix per chromosome")
86         onehotSingle := flags.Bool("single-onehot", false, "generate one-hot tile-based matrix")
87         onehotChunked := flags.Bool("chunked-onehot", false, "generate one-hot tile-based matrix per input chunk")
88         samplesFilename := flags.String("samples", "", "`samples.csv` file with training/validation and case/control groups (see 'lightning choose-samples')")
89         caseControlOnly := flags.Bool("case-control-only", false, "drop samples that are not in case/control groups")
90         onlyPCA := flags.Bool("pca", false, "run principal component analysis, write components to pca.npy and samples.csv")
91         flags.IntVar(&cmd.pcaComponents, "pca-components", 4, "number of PCA components to compute / use in logistic regression")
92         maxPCATiles := flags.Int("max-pca-tiles", 0, "maximum tiles to use as PCA input (filter, then drop every 2nd colum pair until below max)")
93         debugTag := flags.Int("debug-tag", -1, "log debugging details about specified tag")
94         flags.IntVar(&cmd.threads, "threads", 16, "number of memory-hungry assembly threads, and number of VCPUs to request for arvados container")
95         flags.Float64Var(&cmd.chi2PValue, "chi2-p-value", 1, "do Χ² test (or logistic regression if -samples file has PCA components) and omit columns with p-value above this threshold")
96         flags.BoolVar(&cmd.includeVariant1, "include-variant-1", false, "include most common variant when building one-hot matrix")
97         cmd.filter.Flags(flags)
98         err := flags.Parse(args)
99         if err == flag.ErrHelp {
100                 return nil
101         } else if err != nil {
102                 return err
103         } else if flags.NArg() > 0 {
104                 return fmt.Errorf("errant command line arguments after parsed flags: %v", flags.Args())
105         }
106
107         if *pprof != "" {
108                 go func() {
109                         log.Println(http.ListenAndServe(*pprof, nil))
110                 }()
111         }
112
113         if cmd.chi2PValue != 1 && *samplesFilename == "" {
114                 return fmt.Errorf("cannot use provided -chi2-p-value=%f because -samples= value is empty", cmd.chi2PValue)
115         }
116
117         cmd.debugTag = tagID(*debugTag)
118
119         if !*runlocal {
120                 runner := arvadosContainerRunner{
121                         Name:        "lightning slice-numpy",
122                         Client:      arvados.NewClientFromEnv(),
123                         ProjectUUID: *projectUUID,
124                         RAM:         int64(*arvadosRAM),
125                         VCPUs:       *arvadosVCPUs,
126                         Priority:    *priority,
127                         KeepCache:   2,
128                         APIAccess:   true,
129                 }
130                 err = runner.TranslatePaths(inputDir, regionsFilename, samplesFilename)
131                 if err != nil {
132                         return err
133                 }
134                 runner.Args = []string{"slice-numpy", "-local=true",
135                         "-pprof=:6060",
136                         "-input-dir=" + *inputDir,
137                         "-output-dir=/mnt/output",
138                         "-threads=" + fmt.Sprintf("%d", cmd.threads),
139                         "-regions=" + *regionsFilename,
140                         "-expand-regions=" + fmt.Sprintf("%d", *expandRegions),
141                         "-merge-output=" + fmt.Sprintf("%v", *mergeOutput),
142                         "-single-hgvs-matrix=" + fmt.Sprintf("%v", *hgvsSingle),
143                         "-chunked-hgvs-matrix=" + fmt.Sprintf("%v", *hgvsChunked),
144                         "-single-onehot=" + fmt.Sprintf("%v", *onehotSingle),
145                         "-chunked-onehot=" + fmt.Sprintf("%v", *onehotChunked),
146                         "-samples=" + *samplesFilename,
147                         "-case-control-only=" + fmt.Sprintf("%v", *caseControlOnly),
148                         "-pca=" + fmt.Sprintf("%v", *onlyPCA),
149                         "-pca-components=" + fmt.Sprintf("%d", cmd.pcaComponents),
150                         "-max-pca-tiles=" + fmt.Sprintf("%d", *maxPCATiles),
151                         "-chi2-p-value=" + fmt.Sprintf("%f", cmd.chi2PValue),
152                         "-include-variant-1=" + fmt.Sprintf("%v", cmd.includeVariant1),
153                         "-debug-tag=" + fmt.Sprintf("%d", cmd.debugTag),
154                 }
155                 runner.Args = append(runner.Args, cmd.filter.Args()...)
156                 var output string
157                 output, err = runner.Run()
158                 if err != nil {
159                         return err
160                 }
161                 fmt.Fprintln(stdout, output)
162                 return nil
163         }
164
165         infiles, err := allFiles(*inputDir, matchGobFile)
166         if err != nil {
167                 return err
168         }
169         if len(infiles) == 0 {
170                 err = fmt.Errorf("no input files found in %s", *inputDir)
171                 return err
172         }
173         sort.Strings(infiles)
174
175         var refseq map[string][]tileLibRef
176         var reftiledata = make(map[tileLibRef][]byte, 11000000)
177         in0, err := open(infiles[0])
178         if err != nil {
179                 return err
180         }
181
182         matchGenome, err := regexp.Compile(cmd.filter.MatchGenome)
183         if err != nil {
184                 err = fmt.Errorf("-match-genome: invalid regexp: %q", cmd.filter.MatchGenome)
185                 return err
186         }
187
188         if *samplesFilename != "" {
189                 cmd.samples, err = loadSampleInfo(*samplesFilename)
190                 if err != nil {
191                         return err
192                 }
193                 if len(cmd.samples[0].pcaComponents) > 0 {
194                         cmd.pvalue = glmPvalueFunc(cmd.samples, cmd.pcaComponents)
195                         // Unfortunately, statsmodel/glm lib logs
196                         // stuff to os.Stdout when it panics on an
197                         // unsolvable problem. We recover() from the
198                         // panic in glm.go, but we also need to
199                         // commandeer os.Stdout to avoid producing
200                         // large quantities of logs.
201                         stdoutWas := os.Stdout
202                         defer func() { os.Stdout = stdoutWas }()
203                         os.Stdout, err = os.Open(os.DevNull)
204                         if err != nil {
205                                 return err
206                         }
207                 }
208         } else if *caseControlOnly {
209                 return fmt.Errorf("-case-control-only does not make sense without -samples")
210         }
211
212         cmd.cgnames = nil
213         var tagset [][]byte
214         err = DecodeLibrary(in0, strings.HasSuffix(infiles[0], ".gz"), func(ent *LibraryEntry) error {
215                 if len(ent.TagSet) > 0 {
216                         tagset = ent.TagSet
217                 }
218                 for _, cseq := range ent.CompactSequences {
219                         if cseq.Name == *ref || *ref == "" {
220                                 refseq = cseq.TileSequences
221                         }
222                 }
223                 for _, cg := range ent.CompactGenomes {
224                         if matchGenome.MatchString(cg.Name) {
225                                 cmd.cgnames = append(cmd.cgnames, cg.Name)
226                         }
227                 }
228                 for _, tv := range ent.TileVariants {
229                         if tv.Ref {
230                                 reftiledata[tileLibRef{tv.Tag, tv.Variant}] = tv.Sequence
231                         }
232                 }
233                 return nil
234         })
235         if err != nil {
236                 return err
237         }
238         in0.Close()
239         if refseq == nil {
240                 err = fmt.Errorf("%s: reference sequence not found", infiles[0])
241                 return err
242         }
243         if len(tagset) == 0 {
244                 err = fmt.Errorf("tagset not found")
245                 return err
246         }
247
248         taglib := &tagLibrary{}
249         err = taglib.setTags(tagset)
250         if err != nil {
251                 return err
252         }
253         taglen := taglib.TagLen()
254         sort.Strings(cmd.cgnames)
255
256         if len(cmd.cgnames) == 0 {
257                 return fmt.Errorf("fatal: 0 matching samples in library, nothing to do")
258         }
259         cmd.trainingSet = make([]int, len(cmd.cgnames))
260         if *samplesFilename == "" {
261                 cmd.trainingSetSize = len(cmd.cgnames)
262                 for i, name := range cmd.cgnames {
263                         cmd.samples = append(cmd.samples, sampleInfo{
264                                 id:         trimFilenameForLabel(name),
265                                 isTraining: true,
266                         })
267                         cmd.trainingSet[i] = i
268                 }
269         } else if len(cmd.cgnames) != len(cmd.samples) {
270                 return fmt.Errorf("mismatched sample list: %d samples in library, %d in %s", len(cmd.cgnames), len(cmd.samples), *samplesFilename)
271         } else {
272                 for i, name := range cmd.cgnames {
273                         if s := trimFilenameForLabel(name); s != cmd.samples[i].id {
274                                 return fmt.Errorf("mismatched sample list: sample %d is %q in library, %q in %s", i, s, cmd.samples[i].id, *samplesFilename)
275                         }
276                 }
277                 if *caseControlOnly {
278                         for i := 0; i < len(cmd.samples); i++ {
279                                 if !cmd.samples[i].isTraining && !cmd.samples[i].isValidation {
280                                         if i+1 < len(cmd.samples) {
281                                                 copy(cmd.samples[i:], cmd.samples[i+1:])
282                                                 copy(cmd.cgnames[i:], cmd.cgnames[i+1:])
283                                         }
284                                         cmd.samples = cmd.samples[:len(cmd.samples)-1]
285                                         cmd.cgnames = cmd.cgnames[:len(cmd.cgnames)-1]
286                                         i--
287                                 }
288                         }
289                 }
290                 cmd.chi2Cases = nil
291                 cmd.trainingSetSize = 0
292                 for i := range cmd.cgnames {
293                         if cmd.samples[i].isTraining {
294                                 cmd.trainingSet[i] = cmd.trainingSetSize
295                                 cmd.trainingSetSize++
296                                 cmd.chi2Cases = append(cmd.chi2Cases, cmd.samples[i].isCase)
297                         } else {
298                                 cmd.trainingSet[i] = -1
299                         }
300                 }
301                 if cmd.pvalue == nil {
302                         cmd.pvalue = func(onehot []bool) float64 {
303                                 return pvalue(onehot, cmd.chi2Cases)
304                         }
305                 }
306         }
307         if cmd.filter.MinCoverage == 1 {
308                 // In the generic formula below, floating point
309                 // arithmetic can effectively push the coverage
310                 // threshold above 1.0, which is impossible/useless.
311                 // 1.0 needs to mean exactly 100% coverage.
312                 cmd.minCoverage = len(cmd.cgnames)
313         } else {
314                 cmd.minCoverage = int(math.Ceil(cmd.filter.MinCoverage * float64(len(cmd.cgnames))))
315         }
316
317         {
318                 samplesOutFilename := *outputDir + "/samples.csv"
319                 log.Infof("writing sample metadata to %s", samplesOutFilename)
320                 var f *os.File
321                 f, err = os.Create(samplesOutFilename)
322                 if err != nil {
323                         return err
324                 }
325                 defer f.Close()
326                 for i, si := range cmd.samples {
327                         var cc, tv string
328                         if si.isCase {
329                                 cc = "1"
330                         } else if si.isControl {
331                                 cc = "0"
332                         }
333                         if si.isTraining {
334                                 tv = "1"
335                         } else {
336                                 tv = "0"
337                         }
338                         _, err = fmt.Fprintf(f, "%d,%s,%s,%s\n", i, si.id, cc, tv)
339                         if err != nil {
340                                 err = fmt.Errorf("write %s: %w", samplesOutFilename, err)
341                                 return err
342                         }
343                 }
344                 err = f.Close()
345                 if err != nil {
346                         err = fmt.Errorf("close %s: %w", samplesOutFilename, err)
347                         return err
348                 }
349                 log.Print("done")
350         }
351
352         log.Info("indexing reference tiles")
353         type reftileinfo struct {
354                 variant  tileVariantID
355                 seqname  string // chr1
356                 pos      int    // distance from start of chromosome to starttag
357                 tiledata []byte // acgtggcaa...
358                 excluded bool   // true if excluded by regions file
359                 nexttag  tagID  // tagID of following tile (-1 for last tag of chromosome)
360         }
361         isdup := map[tagID]bool{}
362         reftile := map[tagID]*reftileinfo{}
363         for seqname, cseq := range refseq {
364                 pos := 0
365                 lastreftag := tagID(-1)
366                 for _, libref := range cseq {
367                         if cmd.filter.MaxTag >= 0 && libref.Tag > tagID(cmd.filter.MaxTag) {
368                                 continue
369                         }
370                         tiledata := reftiledata[libref]
371                         if len(tiledata) == 0 {
372                                 err = fmt.Errorf("missing tiledata for tag %d variant %d in %s in ref", libref.Tag, libref.Variant, seqname)
373                                 return err
374                         }
375                         foundthistag := false
376                         taglib.FindAll(tiledata[:len(tiledata)-1], func(tagid tagID, offset, _ int) {
377                                 if !foundthistag && tagid == libref.Tag {
378                                         foundthistag = true
379                                         return
380                                 }
381                                 if dupref, ok := reftile[tagid]; ok {
382                                         log.Printf("dropping reference tile %+v from %s @ %d, tag not unique, also found inside %+v from %s @ %d", tileLibRef{Tag: tagid, Variant: dupref.variant}, dupref.seqname, dupref.pos, libref, seqname, pos+offset+1)
383                                         delete(reftile, tagid)
384                                 } else {
385                                         log.Printf("found tag %d at offset %d inside tile variant %+v on %s @ %d", tagid, offset, libref, seqname, pos+offset+1)
386                                 }
387                                 isdup[tagid] = true
388                         })
389                         if isdup[libref.Tag] {
390                                 log.Printf("dropping reference tile %+v from %s @ %d, tag not unique", libref, seqname, pos)
391                         } else if reftile[libref.Tag] != nil {
392                                 log.Printf("dropping reference tile %+v from %s @ %d, tag not unique", tileLibRef{Tag: libref.Tag, Variant: reftile[libref.Tag].variant}, reftile[libref.Tag].seqname, reftile[libref.Tag].pos)
393                                 delete(reftile, libref.Tag)
394                                 log.Printf("dropping reference tile %+v from %s @ %d, tag not unique", libref, seqname, pos)
395                                 isdup[libref.Tag] = true
396                         } else {
397                                 reftile[libref.Tag] = &reftileinfo{
398                                         seqname:  seqname,
399                                         variant:  libref.Variant,
400                                         tiledata: tiledata,
401                                         pos:      pos,
402                                         nexttag:  -1,
403                                 }
404                                 if lastreftag >= 0 {
405                                         reftile[lastreftag].nexttag = libref.Tag
406                                 }
407                                 lastreftag = libref.Tag
408                         }
409                         pos += len(tiledata) - taglen
410                 }
411                 log.Printf("... %s done, len %d", seqname, pos+taglen)
412         }
413
414         var mask *mask
415         if *regionsFilename != "" {
416                 log.Printf("loading regions from %s", *regionsFilename)
417                 mask, err = makeMask(*regionsFilename, *expandRegions)
418                 if err != nil {
419                         return err
420                 }
421                 log.Printf("before applying mask, len(reftile) == %d", len(reftile))
422                 log.Printf("deleting reftile entries for regions outside %d intervals", mask.Len())
423                 for _, rt := range reftile {
424                         if !mask.Check(strings.TrimPrefix(rt.seqname, "chr"), rt.pos, rt.pos+len(rt.tiledata)) {
425                                 rt.excluded = true
426                         }
427                 }
428                 log.Printf("after applying mask, len(reftile) == %d", len(reftile))
429         }
430
431         type hgvsColSet map[hgvs.Variant][2][]int8
432         encodeHGVS := throttle{Max: len(refseq)}
433         encodeHGVSTodo := map[string]chan hgvsColSet{}
434         tmpHGVSCols := map[string]*os.File{}
435         if *hgvsChunked {
436                 for seqname := range refseq {
437                         var f *os.File
438                         f, err = os.Create(*outputDir + "/tmp." + seqname + ".gob")
439                         if err != nil {
440                                 return err
441                         }
442                         defer os.Remove(f.Name())
443                         bufw := bufio.NewWriterSize(f, 1<<24)
444                         enc := gob.NewEncoder(bufw)
445                         tmpHGVSCols[seqname] = f
446                         todo := make(chan hgvsColSet, 128)
447                         encodeHGVSTodo[seqname] = todo
448                         encodeHGVS.Go(func() error {
449                                 for colset := range todo {
450                                         err := enc.Encode(colset)
451                                         if err != nil {
452                                                 encodeHGVS.Report(err)
453                                                 for range todo {
454                                                 }
455                                                 return err
456                                         }
457                                 }
458                                 return bufw.Flush()
459                         })
460                 }
461         }
462
463         var toMerge [][]int16
464         if *mergeOutput || *hgvsSingle {
465                 toMerge = make([][]int16, len(infiles))
466         }
467         var onehotIndirect [][2][]uint32 // [chunkIndex][axis][index]
468         var onehotChunkSize []uint32
469         var onehotXrefs [][]onehotXref
470         if *onehotSingle || *onlyPCA {
471                 onehotIndirect = make([][2][]uint32, len(infiles))
472                 onehotChunkSize = make([]uint32, len(infiles))
473                 onehotXrefs = make([][]onehotXref, len(infiles))
474         }
475         chunkStartTag := make([]tagID, len(infiles))
476
477         throttleMem := throttle{Max: cmd.threads} // TODO: estimate using mem and data size
478         throttleNumpyMem := throttle{Max: cmd.threads/2 + 1}
479         log.Info("generating annotations and numpy matrix for each slice")
480         var errSkip = errors.New("skip infile")
481         var done int64
482         for infileIdx, infile := range infiles {
483                 infileIdx, infile := infileIdx, infile
484                 throttleMem.Go(func() error {
485                         seq := make(map[tagID][]TileVariant, 50000)
486                         cgs := make(map[string]CompactGenome, len(cmd.cgnames))
487                         f, err := open(infile)
488                         if err != nil {
489                                 return err
490                         }
491                         defer f.Close()
492                         log.Infof("%04d: reading %s", infileIdx, infile)
493                         err = DecodeLibrary(f, strings.HasSuffix(infile, ".gz"), func(ent *LibraryEntry) error {
494                                 for _, tv := range ent.TileVariants {
495                                         if tv.Ref {
496                                                 continue
497                                         }
498                                         // Skip tile with no
499                                         // corresponding ref tile, if
500                                         // mask is in play (we can't
501                                         // determine coordinates for
502                                         // these)
503                                         if mask != nil && reftile[tv.Tag] == nil {
504                                                 continue
505                                         }
506                                         // Skip tile whose
507                                         // corresponding ref tile is
508                                         // outside target regions --
509                                         // unless it's a potential
510                                         // spanning tile.
511                                         if mask != nil && reftile[tv.Tag].excluded &&
512                                                 (int(tv.Tag+1) >= len(tagset) ||
513                                                         (bytes.HasSuffix(tv.Sequence, tagset[tv.Tag+1]) && reftile[tv.Tag+1] != nil && !reftile[tv.Tag+1].excluded)) {
514                                                 continue
515                                         }
516                                         if tv.Tag == cmd.debugTag {
517                                                 log.Printf("infile %d %s tag %d variant %d hash %x", infileIdx, infile, tv.Tag, tv.Variant, tv.Blake2b[:3])
518                                         }
519                                         variants := seq[tv.Tag]
520                                         if len(variants) == 0 {
521                                                 variants = make([]TileVariant, 100)
522                                         }
523                                         for len(variants) <= int(tv.Variant) {
524                                                 variants = append(variants, TileVariant{})
525                                         }
526                                         variants[int(tv.Variant)] = tv
527                                         seq[tv.Tag] = variants
528                                 }
529                                 for _, cg := range ent.CompactGenomes {
530                                         if cmd.filter.MaxTag >= 0 && cg.StartTag > tagID(cmd.filter.MaxTag) {
531                                                 return errSkip
532                                         }
533                                         if !matchGenome.MatchString(cg.Name) {
534                                                 continue
535                                         }
536                                         // pad to full slice size
537                                         // to avoid out-of-bounds
538                                         // checks later
539                                         if sliceSize := 2 * int(cg.EndTag-cg.StartTag); len(cg.Variants) < sliceSize {
540                                                 cg.Variants = append(cg.Variants, make([]tileVariantID, sliceSize-len(cg.Variants))...)
541                                         }
542                                         cgs[cg.Name] = cg
543                                 }
544                                 return nil
545                         })
546                         if err == errSkip {
547                                 return nil
548                         } else if err != nil {
549                                 return fmt.Errorf("%04d: DecodeLibrary(%s): err", infileIdx, infile)
550                         }
551                         tagstart := cgs[cmd.cgnames[0]].StartTag
552                         tagend := cgs[cmd.cgnames[0]].EndTag
553                         chunkStartTag[infileIdx] = tagstart
554
555                         // TODO: filters
556
557                         log.Infof("%04d: renumber/dedup variants for tags %d-%d", infileIdx, tagstart, tagend)
558                         variantRemap := make([][]tileVariantID, tagend-tagstart)
559                         throttleCPU := throttle{Max: runtime.GOMAXPROCS(0)}
560                         for tag, variants := range seq {
561                                 tag, variants := tag, variants
562                                 throttleCPU.Go(func() error {
563                                         alleleCoverage := 0
564                                         count := make(map[[blake2b.Size256]byte]int, len(variants))
565
566                                         rt := reftile[tag]
567                                         if rt != nil {
568                                                 count[blake2b.Sum256(rt.tiledata)] = 0
569                                         }
570
571                                         for cgname, cg := range cgs {
572                                                 idx := int(tag-tagstart) * 2
573                                                 for allele := 0; allele < 2; allele++ {
574                                                         v := cg.Variants[idx+allele]
575                                                         if v > 0 && len(variants[v].Sequence) > 0 {
576                                                                 count[variants[v].Blake2b]++
577                                                                 alleleCoverage++
578                                                         }
579                                                         if v > 0 && tag == cmd.debugTag {
580                                                                 log.Printf("tag %d cg %s allele %d tv %d hash %x count is now %d", tag, cgname, allele, v, variants[v].Blake2b[:3], count[variants[v].Blake2b])
581                                                         }
582                                                 }
583                                         }
584                                         if alleleCoverage < cmd.minCoverage*2 {
585                                                 idx := int(tag-tagstart) * 2
586                                                 for _, cg := range cgs {
587                                                         cg.Variants[idx] = 0
588                                                         cg.Variants[idx+1] = 0
589                                                 }
590                                                 if tag == cmd.debugTag {
591                                                         log.Printf("tag %d alleleCoverage %d < min %d, sample data wiped", tag, alleleCoverage, cmd.minCoverage*2)
592                                                 }
593                                                 return nil
594                                         }
595
596                                         // hash[i] will be the hash of
597                                         // the variant(s) that should
598                                         // be at rank i (0-based).
599                                         hash := make([][blake2b.Size256]byte, 0, len(count))
600                                         for b := range count {
601                                                 hash = append(hash, b)
602                                         }
603                                         sort.Slice(hash, func(i, j int) bool {
604                                                 bi, bj := &hash[i], &hash[j]
605                                                 if ci, cj := count[*bi], count[*bj]; ci != cj {
606                                                         return ci > cj
607                                                 } else {
608                                                         return bytes.Compare((*bi)[:], (*bj)[:]) < 0
609                                                 }
610                                         })
611                                         // rank[b] will be the 1-based
612                                         // new variant number for
613                                         // variants whose hash is b.
614                                         rank := make(map[[blake2b.Size256]byte]tileVariantID, len(hash))
615                                         for i, h := range hash {
616                                                 rank[h] = tileVariantID(i + 1)
617                                         }
618                                         if tag == cmd.debugTag {
619                                                 for h, r := range rank {
620                                                         log.Printf("tag %d rank(%x) = %v", tag, h[:3], r)
621                                                 }
622                                         }
623                                         // remap[v] will be the new
624                                         // variant number for original
625                                         // variant number v.
626                                         remap := make([]tileVariantID, len(variants))
627                                         for i, tv := range variants {
628                                                 remap[i] = rank[tv.Blake2b]
629                                         }
630                                         if tag == cmd.debugTag {
631                                                 for in, out := range remap {
632                                                         if out > 0 {
633                                                                 log.Printf("tag %d remap %d => %d", tag, in, out)
634                                                         }
635                                                 }
636                                         }
637                                         variantRemap[tag-tagstart] = remap
638                                         if rt != nil {
639                                                 refrank := rank[blake2b.Sum256(rt.tiledata)]
640                                                 if tag == cmd.debugTag {
641                                                         log.Printf("tag %d reftile variant %d => %d", tag, rt.variant, refrank)
642                                                 }
643                                                 rt.variant = refrank
644                                         }
645                                         return nil
646                                 })
647                         }
648                         throttleCPU.Wait()
649
650                         var onehotChunk [][]int8
651                         var onehotXref []onehotXref
652
653                         var annotationsFilename string
654                         if *onlyPCA {
655                                 annotationsFilename = "/dev/null"
656                         } else {
657                                 annotationsFilename = fmt.Sprintf("%s/matrix.%04d.annotations.csv", *outputDir, infileIdx)
658                                 log.Infof("%04d: writing %s", infileIdx, annotationsFilename)
659                         }
660                         annof, err := os.Create(annotationsFilename)
661                         if err != nil {
662                                 return err
663                         }
664                         annow := bufio.NewWriterSize(annof, 1<<20)
665                         outcol := 0
666                         for tag := tagstart; tag < tagend; tag++ {
667                                 rt := reftile[tag]
668                                 if rt == nil && mask != nil {
669                                         // With no ref tile, we don't
670                                         // have coordinates to say
671                                         // this is in the desired
672                                         // regions -- so it's not.
673                                         // TODO: handle ref spanning
674                                         // tile case.
675                                         continue
676                                 }
677                                 if rt != nil && rt.excluded {
678                                         // TODO: don't skip yet --
679                                         // first check for spanning
680                                         // tile variants that
681                                         // intersect non-excluded ref
682                                         // tiles.
683                                         continue
684                                 }
685                                 if cmd.filter.MaxTag >= 0 && tag > tagID(cmd.filter.MaxTag) {
686                                         break
687                                 }
688                                 remap := variantRemap[tag-tagstart]
689                                 if remap == nil {
690                                         // was not assigned above,
691                                         // because minCoverage
692                                         outcol++
693                                         continue
694                                 }
695                                 maxv := tileVariantID(0)
696                                 for _, v := range remap {
697                                         if maxv < v {
698                                                 maxv = v
699                                         }
700                                 }
701                                 if *onehotChunked || *onehotSingle || *onlyPCA {
702                                         onehot, xrefs := cmd.tv2homhet(cgs, maxv, remap, tag, tagstart, seq)
703                                         if tag == cmd.debugTag {
704                                                 log.WithFields(logrus.Fields{
705                                                         "onehot": onehot,
706                                                         "xrefs":  xrefs,
707                                                 }).Info("tv2homhet()")
708                                         }
709                                         onehotChunk = append(onehotChunk, onehot...)
710                                         onehotXref = append(onehotXref, xrefs...)
711                                 }
712                                 if *onlyPCA {
713                                         outcol++
714                                         continue
715                                 }
716                                 if rt == nil {
717                                         // Reference does not use any
718                                         // variant of this tile
719                                         //
720                                         // TODO: diff against the
721                                         // relevant portion of the
722                                         // ref's spanning tile
723                                         outcol++
724                                         continue
725                                 }
726                                 fmt.Fprintf(annow, "%d,%d,%d,=,%s,%d,,,\n", tag, outcol, rt.variant, rt.seqname, rt.pos)
727                                 variants := seq[tag]
728                                 reftilestr := strings.ToUpper(string(rt.tiledata))
729
730                                 done := make([]bool, maxv+1)
731                                 variantDiffs := make([][]hgvs.Variant, maxv+1)
732                                 for v, tv := range variants {
733                                         v := remap[v]
734                                         if v == 0 || v == rt.variant || done[v] {
735                                                 continue
736                                         } else {
737                                                 done[v] = true
738                                         }
739                                         if len(tv.Sequence) < taglen {
740                                                 continue
741                                         }
742                                         // if reftilestr doesn't end
743                                         // in the same tag as tv,
744                                         // extend reftilestr with
745                                         // following ref tiles until
746                                         // it does (up to an arbitrary
747                                         // sanity-check limit)
748                                         reftilestr := reftilestr
749                                         endtagstr := strings.ToUpper(string(tv.Sequence[len(tv.Sequence)-taglen:]))
750                                         for i, rt := 0, rt; i < annotationMaxTileSpan && !strings.HasSuffix(reftilestr, endtagstr) && rt.nexttag >= 0; i++ {
751                                                 rt = reftile[rt.nexttag]
752                                                 if rt == nil {
753                                                         break
754                                                 }
755                                                 reftilestr += strings.ToUpper(string(rt.tiledata[taglen:]))
756                                         }
757                                         if mask != nil && !mask.Check(strings.TrimPrefix(rt.seqname, "chr"), rt.pos, rt.pos+len(reftilestr)) {
758                                                 continue
759                                         }
760                                         if !strings.HasSuffix(reftilestr, endtagstr) {
761                                                 fmt.Fprintf(annow, "%d,%d,%d,,%s,%d,,,\n", tag, outcol, v, rt.seqname, rt.pos)
762                                                 continue
763                                         }
764                                         if lendiff := len(reftilestr) - len(tv.Sequence); lendiff < -1000 || lendiff > 1000 {
765                                                 fmt.Fprintf(annow, "%d,%d,%d,,%s,%d,,,\n", tag, outcol, v, rt.seqname, rt.pos)
766                                                 continue
767                                         }
768                                         diffs, _ := hgvs.Diff(reftilestr, strings.ToUpper(string(tv.Sequence)), 0)
769                                         for i := range diffs {
770                                                 diffs[i].Position += rt.pos
771                                         }
772                                         for _, diff := range diffs {
773                                                 fmt.Fprintf(annow, "%d,%d,%d,%s:g.%s,%s,%d,%s,%s,%s\n", tag, outcol, v, rt.seqname, diff.String(), rt.seqname, diff.Position, diff.Ref, diff.New, diff.Left)
774                                         }
775                                         if *hgvsChunked {
776                                                 variantDiffs[v] = diffs
777                                         }
778                                 }
779                                 if *hgvsChunked {
780                                         // We can now determine, for each HGVS
781                                         // variant (diff) in this reftile
782                                         // region, whether a given genome
783                                         // phase/allele (1) has the variant, (0) has
784                                         // =ref or a different variant in that
785                                         // position, or (-1) is lacking
786                                         // coverage / couldn't be diffed.
787                                         hgvsCol := hgvsColSet{}
788                                         for _, diffs := range variantDiffs {
789                                                 for _, diff := range diffs {
790                                                         if _, ok := hgvsCol[diff]; ok {
791                                                                 continue
792                                                         }
793                                                         hgvsCol[diff] = [2][]int8{
794                                                                 make([]int8, len(cmd.cgnames)),
795                                                                 make([]int8, len(cmd.cgnames)),
796                                                         }
797                                                 }
798                                         }
799                                         for row, name := range cmd.cgnames {
800                                                 variants := cgs[name].Variants[(tag-tagstart)*2:]
801                                                 for ph := 0; ph < 2; ph++ {
802                                                         v := variants[ph]
803                                                         if int(v) >= len(remap) {
804                                                                 v = 0
805                                                         } else {
806                                                                 v = remap[v]
807                                                         }
808                                                         if v == rt.variant {
809                                                                 // hgvsCol[*][ph][row] is already 0
810                                                         } else if len(variantDiffs[v]) == 0 {
811                                                                 // lacking coverage / couldn't be diffed
812                                                                 for _, col := range hgvsCol {
813                                                                         col[ph][row] = -1
814                                                                 }
815                                                         } else {
816                                                                 for _, diff := range variantDiffs[v] {
817                                                                         hgvsCol[diff][ph][row] = 1
818                                                                 }
819                                                         }
820                                                 }
821                                         }
822                                         for diff, colpair := range hgvsCol {
823                                                 allele2homhet(colpair)
824                                                 if !cmd.filterHGVScolpair(colpair) {
825                                                         delete(hgvsCol, diff)
826                                                 }
827                                         }
828                                         if len(hgvsCol) > 0 {
829                                                 encodeHGVSTodo[rt.seqname] <- hgvsCol
830                                         }
831                                 }
832                                 outcol++
833                         }
834                         err = annow.Flush()
835                         if err != nil {
836                                 return err
837                         }
838                         err = annof.Close()
839                         if err != nil {
840                                 return err
841                         }
842
843                         if *onehotChunked {
844                                 // transpose onehotChunk[col][row] to numpy[row*ncols+col]
845                                 rows := len(cmd.cgnames)
846                                 cols := len(onehotChunk)
847                                 log.Infof("%04d: preparing onehot numpy (rows=%d, cols=%d, mem=%d)", infileIdx, rows, cols, rows*cols)
848                                 throttleNumpyMem.Acquire()
849                                 out := onehotcols2int8(onehotChunk)
850                                 fnm := fmt.Sprintf("%s/onehot.%04d.npy", *outputDir, infileIdx)
851                                 err = writeNumpyInt8(fnm, out, rows, cols)
852                                 if err != nil {
853                                         return err
854                                 }
855                                 fnm = fmt.Sprintf("%s/onehot-columns.%04d.npy", *outputDir, infileIdx)
856                                 err = writeNumpyInt32(fnm, onehotXref2int32(onehotXref), 4, len(onehotXref))
857                                 if err != nil {
858                                         return err
859                                 }
860                                 debug.FreeOSMemory()
861                                 throttleNumpyMem.Release()
862                         }
863                         if *onehotSingle || *onlyPCA {
864                                 onehotIndirect[infileIdx] = onehotChunk2Indirect(onehotChunk)
865                                 onehotChunkSize[infileIdx] = uint32(len(onehotChunk))
866                                 onehotXrefs[infileIdx] = onehotXref
867                                 n := len(onehotIndirect[infileIdx][0])
868                                 log.Infof("%04d: keeping onehot coordinates in memory (n=%d, mem=%d)", infileIdx, n, n*8*2)
869                         }
870                         if !(*onehotSingle || *onehotChunked || *onlyPCA) || *mergeOutput || *hgvsSingle {
871                                 log.Infof("%04d: preparing numpy (rows=%d, cols=%d)", infileIdx, len(cmd.cgnames), 2*outcol)
872                                 throttleNumpyMem.Acquire()
873                                 rows := len(cmd.cgnames)
874                                 cols := 2 * outcol
875                                 out := make([]int16, rows*cols)
876                                 for row, name := range cmd.cgnames {
877                                         outidx := row * cols
878                                         for col, v := range cgs[name].Variants {
879                                                 tag := tagstart + tagID(col/2)
880                                                 if cmd.filter.MaxTag >= 0 && tag > tagID(cmd.filter.MaxTag) {
881                                                         break
882                                                 }
883                                                 if rt := reftile[tag]; rt == nil || rt.excluded {
884                                                         continue
885                                                 }
886                                                 if v == 0 {
887                                                         out[outidx] = 0 // tag not found / spanning tile
888                                                 } else if variants, ok := seq[tag]; ok && int(v) < len(variants) && len(variants[v].Sequence) > 0 {
889                                                         out[outidx] = int16(variantRemap[tag-tagstart][v])
890                                                 } else {
891                                                         out[outidx] = -1 // low quality tile variant
892                                                 }
893                                                 if tag == cmd.debugTag {
894                                                         log.Printf("tag %d row %d col %d outidx %d v %d out %d", tag, row, col, outidx, v, out[outidx])
895                                                 }
896                                                 outidx++
897                                         }
898                                 }
899                                 seq = nil
900                                 cgs = nil
901                                 debug.FreeOSMemory()
902                                 throttleNumpyMem.Release()
903                                 if *mergeOutput || *hgvsSingle {
904                                         log.Infof("%04d: matrix fragment %d rows x %d cols", infileIdx, rows, cols)
905                                         toMerge[infileIdx] = out
906                                 }
907                                 if !*mergeOutput && !*onehotChunked && !*onehotSingle {
908                                         fnm := fmt.Sprintf("%s/matrix.%04d.npy", *outputDir, infileIdx)
909                                         err = writeNumpyInt16(fnm, out, rows, cols)
910                                         if err != nil {
911                                                 return err
912                                         }
913                                 }
914                         }
915                         debug.FreeOSMemory()
916                         log.Infof("%s: done (%d/%d)", infile, int(atomic.AddInt64(&done, 1)), len(infiles))
917                         return nil
918                 })
919         }
920         if err = throttleMem.Wait(); err != nil {
921                 return err
922         }
923
924         if *hgvsChunked {
925                 log.Info("flushing hgvsCols temp files")
926                 for seqname := range refseq {
927                         close(encodeHGVSTodo[seqname])
928                 }
929                 err = encodeHGVS.Wait()
930                 if err != nil {
931                         return err
932                 }
933                 for seqname := range refseq {
934                         log.Infof("%s: reading hgvsCols from temp file", seqname)
935                         f := tmpHGVSCols[seqname]
936                         _, err = f.Seek(0, io.SeekStart)
937                         if err != nil {
938                                 return err
939                         }
940                         var hgvsCols hgvsColSet
941                         dec := gob.NewDecoder(bufio.NewReaderSize(f, 1<<24))
942                         for err == nil {
943                                 err = dec.Decode(&hgvsCols)
944                         }
945                         if err != io.EOF {
946                                 return err
947                         }
948                         log.Infof("%s: sorting %d hgvs variants", seqname, len(hgvsCols))
949                         variants := make([]hgvs.Variant, 0, len(hgvsCols))
950                         for v := range hgvsCols {
951                                 variants = append(variants, v)
952                         }
953                         sort.Slice(variants, func(i, j int) bool {
954                                 vi, vj := &variants[i], &variants[j]
955                                 if vi.Position != vj.Position {
956                                         return vi.Position < vj.Position
957                                 } else if vi.Ref != vj.Ref {
958                                         return vi.Ref < vj.Ref
959                                 } else {
960                                         return vi.New < vj.New
961                                 }
962                         })
963                         rows := len(cmd.cgnames)
964                         cols := len(variants) * 2
965                         log.Infof("%s: building hgvs matrix (rows=%d, cols=%d, mem=%d)", seqname, rows, cols, rows*cols)
966                         out := make([]int8, rows*cols)
967                         for varIdx, variant := range variants {
968                                 hgvsCols := hgvsCols[variant]
969                                 for row := range cmd.cgnames {
970                                         for ph := 0; ph < 2; ph++ {
971                                                 out[row*cols+varIdx+ph] = hgvsCols[ph][row]
972                                         }
973                                 }
974                         }
975                         err = writeNumpyInt8(fmt.Sprintf("%s/hgvs.%s.npy", *outputDir, seqname), out, rows, cols)
976                         if err != nil {
977                                 return err
978                         }
979                         out = nil
980
981                         fnm := fmt.Sprintf("%s/hgvs.%s.annotations.csv", *outputDir, seqname)
982                         log.Infof("%s: writing hgvs column labels to %s", seqname, fnm)
983                         var hgvsLabels bytes.Buffer
984                         for varIdx, variant := range variants {
985                                 fmt.Fprintf(&hgvsLabels, "%d,%s:g.%s\n", varIdx, seqname, variant.String())
986                         }
987                         err = ioutil.WriteFile(fnm, hgvsLabels.Bytes(), 0666)
988                         if err != nil {
989                                 return err
990                         }
991                 }
992         }
993
994         if *mergeOutput || *hgvsSingle {
995                 var annow *bufio.Writer
996                 var annof *os.File
997                 if *mergeOutput {
998                         annoFilename := fmt.Sprintf("%s/matrix.annotations.csv", *outputDir)
999                         annof, err = os.Create(annoFilename)
1000                         if err != nil {
1001                                 return err
1002                         }
1003                         annow = bufio.NewWriterSize(annof, 1<<20)
1004                 }
1005
1006                 rows := len(cmd.cgnames)
1007                 cols := 0
1008                 for _, chunk := range toMerge {
1009                         cols += len(chunk) / rows
1010                 }
1011                 log.Infof("merging output matrix (rows=%d, cols=%d, mem=%d) and annotations", rows, cols, rows*cols*2)
1012                 var out []int16
1013                 if *mergeOutput {
1014                         out = make([]int16, rows*cols)
1015                 }
1016                 hgvsCols := map[string][2][]int16{} // hgvs -> [[g0,g1,g2,...], [g0,g1,g2,...]] (slice of genomes for each phase)
1017                 startcol := 0
1018                 for outIdx, chunk := range toMerge {
1019                         chunkcols := len(chunk) / rows
1020                         if *mergeOutput {
1021                                 for row := 0; row < rows; row++ {
1022                                         copy(out[row*cols+startcol:], chunk[row*chunkcols:(row+1)*chunkcols])
1023                                 }
1024                         }
1025                         toMerge[outIdx] = nil
1026
1027                         annotationsFilename := fmt.Sprintf("%s/matrix.%04d.annotations.csv", *outputDir, outIdx)
1028                         log.Infof("reading %s", annotationsFilename)
1029                         buf, err := os.ReadFile(annotationsFilename)
1030                         if err != nil {
1031                                 return err
1032                         }
1033                         if *mergeOutput {
1034                                 err = os.Remove(annotationsFilename)
1035                                 if err != nil {
1036                                         return err
1037                                 }
1038                         }
1039                         for _, line := range bytes.Split(buf, []byte{'\n'}) {
1040                                 if len(line) == 0 {
1041                                         continue
1042                                 }
1043                                 fields := bytes.SplitN(line, []byte{','}, 9)
1044                                 tag, _ := strconv.Atoi(string(fields[0]))
1045                                 incol, _ := strconv.Atoi(string(fields[1]))
1046                                 tileVariant, _ := strconv.Atoi(string(fields[2]))
1047                                 hgvsID := string(fields[3])
1048                                 seqname := string(fields[4])
1049                                 pos, _ := strconv.Atoi(string(fields[5]))
1050                                 refseq := fields[6]
1051                                 if hgvsID == "" {
1052                                         // Null entry for un-diffable
1053                                         // tile variant
1054                                         continue
1055                                 }
1056                                 if hgvsID == "=" {
1057                                         // Null entry for ref tile
1058                                         continue
1059                                 }
1060                                 if mask != nil && !mask.Check(strings.TrimPrefix(seqname, "chr"), pos, pos+len(refseq)) {
1061                                         // The tile intersects one of
1062                                         // the selected regions, but
1063                                         // this particular HGVS
1064                                         // variant does not.
1065                                         continue
1066                                 }
1067                                 hgvsColPair := hgvsCols[hgvsID]
1068                                 if hgvsColPair[0] == nil {
1069                                         // values in new columns start
1070                                         // out as -1 ("no data yet")
1071                                         // or 0 ("=ref") here, may
1072                                         // change to 1 ("hgvs variant
1073                                         // present") below, either on
1074                                         // this line or a future line.
1075                                         hgvsColPair = [2][]int16{make([]int16, len(cmd.cgnames)), make([]int16, len(cmd.cgnames))}
1076                                         rt, ok := reftile[tagID(tag)]
1077                                         if !ok {
1078                                                 err = fmt.Errorf("bug: seeing annotations for tag %d, but it has no reftile entry", tag)
1079                                                 return err
1080                                         }
1081                                         for ph := 0; ph < 2; ph++ {
1082                                                 for row := 0; row < rows; row++ {
1083                                                         v := chunk[row*chunkcols+incol*2+ph]
1084                                                         if tileVariantID(v) == rt.variant {
1085                                                                 hgvsColPair[ph][row] = 0
1086                                                         } else {
1087                                                                 hgvsColPair[ph][row] = -1
1088                                                         }
1089                                                 }
1090                                         }
1091                                         hgvsCols[hgvsID] = hgvsColPair
1092                                         if annow != nil {
1093                                                 hgvsref := hgvs.Variant{
1094                                                         Position: pos,
1095                                                         Ref:      string(refseq),
1096                                                         New:      string(refseq),
1097                                                 }
1098                                                 fmt.Fprintf(annow, "%d,%d,%d,%s:g.%s,%s,%d,%s,%s,%s\n", tag, incol+startcol/2, rt.variant, seqname, hgvsref.String(), seqname, pos, refseq, refseq, fields[8])
1099                                         }
1100                                 }
1101                                 if annow != nil {
1102                                         fmt.Fprintf(annow, "%d,%d,%d,%s,%s,%d,%s,%s,%s\n", tag, incol+startcol/2, tileVariant, hgvsID, seqname, pos, refseq, fields[7], fields[8])
1103                                 }
1104                                 for ph := 0; ph < 2; ph++ {
1105                                         for row := 0; row < rows; row++ {
1106                                                 v := chunk[row*chunkcols+incol*2+ph]
1107                                                 if int(v) == tileVariant {
1108                                                         hgvsColPair[ph][row] = 1
1109                                                 }
1110                                         }
1111                                 }
1112                         }
1113
1114                         startcol += chunkcols
1115                 }
1116                 if *mergeOutput {
1117                         err = annow.Flush()
1118                         if err != nil {
1119                                 return err
1120                         }
1121                         err = annof.Close()
1122                         if err != nil {
1123                                 return err
1124                         }
1125                         err = writeNumpyInt16(fmt.Sprintf("%s/matrix.npy", *outputDir), out, rows, cols)
1126                         if err != nil {
1127                                 return err
1128                         }
1129                 }
1130                 out = nil
1131
1132                 if *hgvsSingle {
1133                         cols = len(hgvsCols) * 2
1134                         log.Printf("building hgvs-based matrix: %d rows x %d cols", rows, cols)
1135                         out = make([]int16, rows*cols)
1136                         hgvsIDs := make([]string, 0, cols/2)
1137                         for hgvsID := range hgvsCols {
1138                                 hgvsIDs = append(hgvsIDs, hgvsID)
1139                         }
1140                         sort.Strings(hgvsIDs)
1141                         var hgvsLabels bytes.Buffer
1142                         for idx, hgvsID := range hgvsIDs {
1143                                 fmt.Fprintf(&hgvsLabels, "%d,%s\n", idx, hgvsID)
1144                                 for ph := 0; ph < 2; ph++ {
1145                                         hgvscol := hgvsCols[hgvsID][ph]
1146                                         for row, val := range hgvscol {
1147                                                 out[row*cols+idx*2+ph] = val
1148                                         }
1149                                 }
1150                         }
1151                         err = writeNumpyInt16(fmt.Sprintf("%s/hgvs.npy", *outputDir), out, rows, cols)
1152                         if err != nil {
1153                                 return err
1154                         }
1155
1156                         fnm := fmt.Sprintf("%s/hgvs.annotations.csv", *outputDir)
1157                         log.Printf("writing hgvs labels: %s", fnm)
1158                         err = ioutil.WriteFile(fnm, hgvsLabels.Bytes(), 0777)
1159                         if err != nil {
1160                                 return err
1161                         }
1162                 }
1163         }
1164         if *onehotSingle || *onlyPCA {
1165                 nzCount := 0
1166                 for _, part := range onehotIndirect {
1167                         nzCount += len(part[0])
1168                 }
1169                 onehot := make([]uint32, nzCount*2) // [r,r,r,...,c,c,c,...]
1170                 var xrefs []onehotXref
1171                 chunkOffset := uint32(0)
1172                 outcol := 0
1173                 for i, part := range onehotIndirect {
1174                         for i := range part[1] {
1175                                 part[1][i] += chunkOffset
1176                         }
1177                         copy(onehot[outcol:], part[0])
1178                         copy(onehot[outcol+nzCount:], part[1])
1179                         xrefs = append(xrefs, onehotXrefs[i]...)
1180
1181                         outcol += len(part[0])
1182                         chunkOffset += onehotChunkSize[i]
1183
1184                         part[0] = nil
1185                         part[1] = nil
1186                         onehotXrefs[i] = nil
1187                         debug.FreeOSMemory()
1188                 }
1189                 if *onehotSingle {
1190                         fnm := fmt.Sprintf("%s/onehot.npy", *outputDir)
1191                         err = writeNumpyUint32(fnm, onehot, 2, nzCount)
1192                         if err != nil {
1193                                 return err
1194                         }
1195                         fnm = fmt.Sprintf("%s/onehot-columns.npy", *outputDir)
1196                         err = writeNumpyInt32(fnm, onehotXref2int32(xrefs), 5, len(xrefs))
1197                         if err != nil {
1198                                 return err
1199                         }
1200                         fnm = fmt.Sprintf("%s/stats.json", *outputDir)
1201                         j, err := json.Marshal(map[string]interface{}{
1202                                 "pvalueCallCount": cmd.pvalueCallCount,
1203                         })
1204                         if err != nil {
1205                                 return err
1206                         }
1207                         err = os.WriteFile(fnm, j, 0777)
1208                         if err != nil {
1209                                 return err
1210                         }
1211                 }
1212                 if *onlyPCA {
1213                         cols := 0
1214                         for _, c := range onehot[nzCount:] {
1215                                 if int(c) >= cols {
1216                                         cols = int(c) + 1
1217                                 }
1218                         }
1219                         if cols == 0 {
1220                                 return fmt.Errorf("cannot do PCA: one-hot matrix is empty")
1221                         }
1222                         log.Printf("have %d one-hot cols", cols)
1223                         stride := 1
1224                         for *maxPCATiles > 0 && cols > *maxPCATiles*2 {
1225                                 cols = (cols + 1) / 2
1226                                 stride = stride * 2
1227                         }
1228                         if cols%2 == 1 {
1229                                 // we work with pairs of columns
1230                                 cols++
1231                         }
1232                         log.Printf("creating full matrix (%d rows) and training matrix (%d rows) with %d cols, stride %d", len(cmd.cgnames), cmd.trainingSetSize, cols, stride)
1233                         mtxFull := mat.NewDense(len(cmd.cgnames), cols, nil)
1234                         mtxTrain := mat.NewDense(cmd.trainingSetSize, cols, nil)
1235                         for i, c := range onehot[nzCount:] {
1236                                 if int(c/2)%stride == 0 {
1237                                         outcol := int(c/2)/stride*2 + int(c)%2
1238                                         mtxFull.Set(int(onehot[i]), outcol, 1)
1239                                         if trainRow := cmd.trainingSet[int(onehot[i])]; trainRow >= 0 {
1240                                                 mtxTrain.Set(trainRow, outcol, 1)
1241                                         }
1242                                 }
1243                         }
1244                         log.Print("fitting")
1245                         transformer := nlp.NewPCA(cmd.pcaComponents)
1246                         transformer.Fit(mtxTrain.T())
1247                         log.Printf("transforming")
1248                         pca, err := transformer.Transform(mtxFull.T())
1249                         if err != nil {
1250                                 return err
1251                         }
1252                         pca = pca.T()
1253                         outrows, outcols := pca.Dims()
1254                         log.Printf("copying result to numpy output array: %d rows, %d cols", outrows, outcols)
1255                         out := make([]float64, outrows*outcols)
1256                         for i := 0; i < outrows; i++ {
1257                                 for j := 0; j < outcols; j++ {
1258                                         out[i*outcols+j] = pca.At(i, j)
1259                                 }
1260                         }
1261                         fnm := fmt.Sprintf("%s/pca.npy", *outputDir)
1262                         log.Printf("writing numpy: %s", fnm)
1263                         output, err := os.OpenFile(fnm, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0777)
1264                         if err != nil {
1265                                 return err
1266                         }
1267                         npw, err := gonpy.NewWriter(nopCloser{output})
1268                         if err != nil {
1269                                 return fmt.Errorf("gonpy.NewWriter: %w", err)
1270                         }
1271                         npw.Shape = []int{outrows, outcols}
1272                         err = npw.WriteFloat64(out)
1273                         if err != nil {
1274                                 return fmt.Errorf("WriteFloat64: %w", err)
1275                         }
1276                         err = output.Close()
1277                         if err != nil {
1278                                 return err
1279                         }
1280                         log.Print("done")
1281
1282                         samplesOutFilename := *outputDir + "/samples.csv"
1283                         log.Infof("writing sample metadata to %s", samplesOutFilename)
1284                         var f *os.File
1285                         f, err = os.Create(samplesOutFilename)
1286                         if err != nil {
1287                                 return err
1288                         }
1289                         defer f.Close()
1290                         for i, si := range cmd.samples {
1291                                 var cc, tv string
1292                                 if si.isCase {
1293                                         cc = "1"
1294                                 } else if si.isControl {
1295                                         cc = "0"
1296                                 }
1297                                 if si.isTraining {
1298                                         tv = "1"
1299                                 } else {
1300                                         tv = "0"
1301                                 }
1302                                 var pcavals string
1303                                 for c := 0; c < outcols; c++ {
1304                                         pcavals += fmt.Sprintf(",%f", pca.At(i, c))
1305                                 }
1306                                 _, err = fmt.Fprintf(f, "%d,%s,%s,%s%s\n", i, si.id, cc, tv, pcavals)
1307                                 if err != nil {
1308                                         err = fmt.Errorf("write %s: %w", samplesOutFilename, err)
1309                                         return err
1310                                 }
1311                         }
1312                         err = f.Close()
1313                         if err != nil {
1314                                 err = fmt.Errorf("close %s: %w", samplesOutFilename, err)
1315                                 return err
1316                         }
1317                         log.Print("done")
1318                 }
1319         }
1320         if !*mergeOutput && !*onehotChunked && !*onehotSingle && !*onlyPCA {
1321                 tagoffsetFilename := *outputDir + "/chunk-tag-offset.csv"
1322                 log.Infof("writing tag offsets to %s", tagoffsetFilename)
1323                 var f *os.File
1324                 f, err = os.Create(tagoffsetFilename)
1325                 if err != nil {
1326                         return err
1327                 }
1328                 defer f.Close()
1329                 for idx, offset := range chunkStartTag {
1330                         _, err = fmt.Fprintf(f, "%q,%d\n", fmt.Sprintf("matrix.%04d.npy", idx), offset)
1331                         if err != nil {
1332                                 err = fmt.Errorf("write %s: %w", tagoffsetFilename, err)
1333                                 return err
1334                         }
1335                 }
1336                 err = f.Close()
1337                 if err != nil {
1338                         err = fmt.Errorf("close %s: %w", tagoffsetFilename, err)
1339                         return err
1340                 }
1341         }
1342
1343         return nil
1344 }
1345
// sampleInfo holds one parsed row of a samples.csv file: the sample
// ID, its case/control status, its training/validation assignment,
// and any numeric columns that follow those four.
type sampleInfo struct {
	id string

	// From the CaseControl column: "1" sets isCase, "0" sets
	// isControl; any other value leaves both false.
	isCase, isControl bool

	// From the TrainingValidation column: "1" sets isTraining, "0"
	// sets isValidation; any other value leaves both false.
	isTraining, isValidation bool

	// Values of any columns after the first four, in file order
	// (nil when there are none).
	pcaComponents []float64
}
1354
1355 // Read samples.csv file with case/control and training/validation
1356 // flags.
1357 func loadSampleInfo(samplesFilename string) ([]sampleInfo, error) {
1358         var si []sampleInfo
1359         f, err := open(samplesFilename)
1360         if err != nil {
1361                 return nil, err
1362         }
1363         buf, err := io.ReadAll(f)
1364         f.Close()
1365         if err != nil {
1366                 return nil, err
1367         }
1368         lineNum := 0
1369         for _, csv := range bytes.Split(buf, []byte{'\n'}) {
1370                 lineNum++
1371                 if len(csv) == 0 {
1372                         continue
1373                 }
1374                 split := strings.Split(string(csv), ",")
1375                 if len(split) < 4 {
1376                         return nil, fmt.Errorf("%d fields < 4 in %s line %d: %q", len(split), samplesFilename, lineNum, csv)
1377                 }
1378                 if split[0] == "Index" && split[1] == "SampleID" && split[2] == "CaseControl" && split[3] == "TrainingValidation" {
1379                         continue
1380                 }
1381                 idx, err := strconv.Atoi(split[0])
1382                 if err != nil {
1383                         if lineNum == 1 {
1384                                 return nil, fmt.Errorf("header does not look right: %q", csv)
1385                         }
1386                         return nil, fmt.Errorf("%s line %d: index: %s", samplesFilename, lineNum, err)
1387                 }
1388                 if idx != len(si) {
1389                         return nil, fmt.Errorf("%s line %d: index %d out of order", samplesFilename, lineNum, idx)
1390                 }
1391                 var pcaComponents []float64
1392                 if len(split) > 4 {
1393                         for _, s := range split[4:] {
1394                                 f, err := strconv.ParseFloat(s, 64)
1395                                 if err != nil {
1396                                         return nil, fmt.Errorf("%s line %d: cannot parse float %q: %s", samplesFilename, lineNum, s, err)
1397                                 }
1398                                 pcaComponents = append(pcaComponents, f)
1399                         }
1400                 }
1401                 si = append(si, sampleInfo{
1402                         id:            split[1],
1403                         isCase:        split[2] == "1",
1404                         isControl:     split[2] == "0",
1405                         isTraining:    split[3] == "1",
1406                         isValidation:  split[3] == "0",
1407                         pcaComponents: pcaComponents,
1408                 })
1409         }
1410         return si, nil
1411 }
1412
1413 func (cmd *sliceNumpy) filterHGVScolpair(colpair [2][]int8) bool {
1414         if cmd.chi2PValue >= 1 {
1415                 return true
1416         }
1417         col0 := make([]bool, 0, len(cmd.chi2Cases))
1418         col1 := make([]bool, 0, len(cmd.chi2Cases))
1419         cases := make([]bool, 0, len(cmd.chi2Cases))
1420         for i, c := range cmd.chi2Cases {
1421                 if colpair[0][i] < 0 {
1422                         continue
1423                 }
1424                 col0 = append(col0, colpair[0][i] != 0)
1425                 col1 = append(col1, colpair[1][i] != 0)
1426                 cases = append(cases, c)
1427         }
1428         return len(cases) >= cmd.minCoverage &&
1429                 (pvalue(col0, cases) <= cmd.chi2PValue || pvalue(col1, cases) <= cmd.chi2PValue)
1430 }
1431
1432 func writeNumpyUint32(fnm string, out []uint32, rows, cols int) error {
1433         output, err := os.Create(fnm)
1434         if err != nil {
1435                 return err
1436         }
1437         defer output.Close()
1438         bufw := bufio.NewWriterSize(output, 1<<26)
1439         npw, err := gonpy.NewWriter(nopCloser{bufw})
1440         if err != nil {
1441                 return err
1442         }
1443         log.WithFields(log.Fields{
1444                 "filename": fnm,
1445                 "rows":     rows,
1446                 "cols":     cols,
1447                 "bytes":    rows * cols * 4,
1448         }).Infof("writing numpy: %s", fnm)
1449         npw.Shape = []int{rows, cols}
1450         npw.WriteUint32(out)
1451         err = bufw.Flush()
1452         if err != nil {
1453                 return err
1454         }
1455         return output.Close()
1456 }
1457
1458 func writeNumpyInt32(fnm string, out []int32, rows, cols int) error {
1459         output, err := os.Create(fnm)
1460         if err != nil {
1461                 return err
1462         }
1463         defer output.Close()
1464         bufw := bufio.NewWriterSize(output, 1<<26)
1465         npw, err := gonpy.NewWriter(nopCloser{bufw})
1466         if err != nil {
1467                 return err
1468         }
1469         log.WithFields(log.Fields{
1470                 "filename": fnm,
1471                 "rows":     rows,
1472                 "cols":     cols,
1473                 "bytes":    rows * cols * 4,
1474         }).Infof("writing numpy: %s", fnm)
1475         npw.Shape = []int{rows, cols}
1476         npw.WriteInt32(out)
1477         err = bufw.Flush()
1478         if err != nil {
1479                 return err
1480         }
1481         return output.Close()
1482 }
1483
1484 func writeNumpyInt16(fnm string, out []int16, rows, cols int) error {
1485         output, err := os.Create(fnm)
1486         if err != nil {
1487                 return err
1488         }
1489         defer output.Close()
1490         bufw := bufio.NewWriterSize(output, 1<<26)
1491         npw, err := gonpy.NewWriter(nopCloser{bufw})
1492         if err != nil {
1493                 return err
1494         }
1495         log.WithFields(log.Fields{
1496                 "filename": fnm,
1497                 "rows":     rows,
1498                 "cols":     cols,
1499                 "bytes":    rows * cols * 2,
1500         }).Infof("writing numpy: %s", fnm)
1501         npw.Shape = []int{rows, cols}
1502         npw.WriteInt16(out)
1503         err = bufw.Flush()
1504         if err != nil {
1505                 return err
1506         }
1507         return output.Close()
1508 }
1509
1510 func writeNumpyInt8(fnm string, out []int8, rows, cols int) error {
1511         output, err := os.Create(fnm)
1512         if err != nil {
1513                 return err
1514         }
1515         defer output.Close()
1516         bufw := bufio.NewWriterSize(output, 1<<26)
1517         npw, err := gonpy.NewWriter(nopCloser{bufw})
1518         if err != nil {
1519                 return err
1520         }
1521         log.WithFields(log.Fields{
1522                 "filename": fnm,
1523                 "rows":     rows,
1524                 "cols":     cols,
1525                 "bytes":    rows * cols,
1526         }).Infof("writing numpy: %s", fnm)
1527         npw.Shape = []int{rows, cols}
1528         npw.WriteInt8(out)
1529         err = bufw.Flush()
1530         if err != nil {
1531                 return err
1532         }
1533         return output.Close()
1534 }
1535
// allele2homhet rewrites a pair of per-allele columns, in place, as a
// pair of hom/het columns. For each row:
//
//   - either input < 0 (no-call): both outputs become -1
//   - both inputs > 0 (hom): outputs become 1, 0
//   - exactly one input > 0 (het): outputs become 0, 1
//   - otherwise both inputs are 0 (ref, or a different variant in the
//     same position) and the row is left unchanged
func allele2homhet(colpair [2][]int8) {
	first, second := colpair[0], colpair[1]
	for row := range first {
		x, y := first[row], second[row]
		switch {
		case x < 0 || y < 0:
			first[row], second[row] = -1, -1
		case x > 0 && y > 0:
			first[row], second[row] = 1, 0
		case x > 0 || y > 0:
			first[row], second[row] = 0, 1
		}
	}
}
1555
// onehotXref describes one column of the one-hot output. It records
// the tag and tile variant the column came from, whether it is the
// homozygous or heterozygous column for that variant, and the p-value
// computed for it.
type onehotXref struct {
	tag     tagID
	variant tileVariantID
	hom     bool    // true for the homozygous column, false for the heterozygous one
	pvalue  float64 // p-value returned by cmd.pvalue for this column in tv2homhet
}

// onehotXrefSize is the in-memory size of one onehotXref, including
// any padding. It depends on the field order above, so reordering the
// fields changes this value.
const onehotXrefSize = unsafe.Sizeof(onehotXref{})
1564
// Build onehot matrix (m[tileVariantIndex][genome] == 0 or 1) for all
// variants of a single tile/tag#.
//
// For each column-numbered variant v in 1..maxv there is a hom column
// (index v*2) and a het column (index v*2+1). Each column has one int8
// per genome in cmd.cgnames: 1 if the genome has v on both phases
// (hom) or on exactly one phase (het), otherwise 0. The returned
// xrefs are parallel to the returned columns.
//
// Return nil if no tile variant passes Χ² filter. Also return nil if
// there are too few variants (see the maxv check) or if fewer than
// cmd.minCoverage genomes have both phases covered.
func (cmd *sliceNumpy) tv2homhet(cgs map[string]CompactGenome, maxv tileVariantID, remap []tileVariantID, tag, chunkstarttag tagID, seq map[tagID][]TileVariant) ([][]int8, []onehotXref) {
	if tag == cmd.debugTag {
		// Debug aid: log every genome's raw (un-remapped) variant
		// pair for the tag selected by cmd.debugTag.
		tv := make([]tileVariantID, len(cmd.cgnames)*2)
		for i, name := range cmd.cgnames {
			copy(tv[i*2:(i+1)*2], cgs[name].Variants[(tag-chunkstarttag)*2:])
		}
		log.WithFields(logrus.Fields{
			"cgs[i].Variants[tag*2+j]": tv,
			"maxv":                     maxv,
			"remap":                    remap,
			"tag":                      tag,
			"chunkstarttag":            chunkstarttag,
		}).Info("tv2homhet()")
	}
	if maxv < 1 || (maxv < 2 && !cmd.includeVariant1) {
		// everyone has the most common variant (of the variants we don't drop)
		return nil, nil
	}
	// Variants holds two entries (one per phase) per tag, relative
	// to the first tag of this chunk.
	tagoffset := tag - chunkstarttag
	// Count genomes whose two phases both refer to a known tile
	// variant with a non-empty sequence. The check uses the raw
	// variant numbers, before remap.
	//
	// NOTE(review): this iterates over every genome in cgs, while
	// the columns below are built only from cmd.cgnames. Confirm the
	// two sets are the same, or coverage may be overcounted.
	coverage := 0
	for _, cg := range cgs {
		alleles := 0
		for _, v := range cg.Variants[tagoffset*2 : tagoffset*2+2] {
			if v > 0 && int(v) < len(seq[tag]) && len(seq[tag][v].Sequence) > 0 {
				alleles++
			}
		}
		if alleles == 2 {
			coverage++
		}
	}
	if coverage < cmd.minCoverage {
		return nil, nil
	}
	// "observed" array for p-value calculation (training set
	// only)
	obs := make([][]bool, (maxv+1)*2) // 2 slices (hom + het) for each variant#
	// one-hot output (all samples)
	outcols := make([][]int8, (maxv+1)*2)
	for i := range obs {
		obs[i] = make([]bool, cmd.trainingSetSize)
		outcols[i] = make([]int8, len(cmd.cgnames))
	}
	for cgid, name := range cmd.cgnames {
		// tsid is this genome's row in the training set, or -1
		// if it is not in the training set.
		tsid := cmd.trainingSet[cgid]
		cgvars := cgs[name].Variants[tagoffset*2:]
		// remap translates raw variant numbers into the 1..maxv
		// numbering used for columns.
		tv0, tv1 := remap[cgvars[0]], remap[cgvars[1]]
		for v := tileVariantID(1); v <= maxv; v++ {
			if tv0 == v && tv1 == v {
				// hom: even column
				if tsid >= 0 {
					obs[v*2][tsid] = true
				}
				outcols[v*2][cgid] = 1
			} else if tv0 == v || tv1 == v {
				// het: odd column
				if tsid >= 0 {
					obs[v*2+1][tsid] = true
				}
				outcols[v*2+1][cgid] = 1
			}
		}
	}
	var onehot [][]int8
	var xref []onehotXref
	for col := 2; col < len(obs); col++ {
		// col 0,1 correspond to tile variant 0, i.e.,
		// no-call; col 2,3 correspond to the most common
		// variant; so we (normally) start at col 4.
		if col < 4 && !cmd.includeVariant1 {
			continue
		}
		// Incremented atomically. Presumably tv2homhet runs
		// concurrently for different tags; TODO confirm.
		atomic.AddInt64(&cmd.pvalueCallCount, 1)
		p := cmd.pvalue(obs[col])
		// With a threshold below 1, keep only columns with
		// p < chi2PValue. Written as !(p < x) so that a NaN p-value
		// is also dropped.
		if cmd.chi2PValue < 1 && !(p < cmd.chi2PValue) {
			continue
		}
		onehot = append(onehot, outcols[col])
		xref = append(xref, onehotXref{
			tag:     tag,
			variant: tileVariantID(col >> 1),
			hom:     col&1 == 0,
			pvalue:  p,
		})
	}
	return onehot, xref
}
1654
1655 // convert a []onehotXref with length N to a numpy-style []int32
1656 // matrix with N columns, one row per field of onehotXref struct.
1657 //
1658 // Hom/het row contains hom=0, het=1.
1659 //
1660 // P-value row contains 1000000x actual p-value.
1661 func onehotXref2int32(xrefs []onehotXref) []int32 {
1662         xcols := len(xrefs)
1663         xdata := make([]int32, 5*xcols)
1664         for i, xref := range xrefs {
1665                 xdata[i] = int32(xref.tag)
1666                 xdata[xcols+i] = int32(xref.variant)
1667                 if xref.hom {
1668                         xdata[xcols*2+i] = 1
1669                 }
1670                 xdata[xcols*3+i] = int32(xref.pvalue * 1000000)
1671                 xdata[xcols*4+i] = int32(-math.Log10(xref.pvalue) * 1000000)
1672         }
1673         return xdata
1674 }
1675
// onehotcols2int8 transposes column-indexed one-hot data in[col][row]
// into a flat numpy-style slice. Element (row, col) is stored at
// out[row*len(in)+col]. The row count is taken from in[0], and the
// result is nil when in is empty.
func onehotcols2int8(in [][]int8) []int8 {
	ncols := len(in)
	if ncols == 0 {
		return nil
	}
	nrows := len(in[0])
	flat := make([]int8, ncols*nrows)
	for c, column := range in {
		for r := 0; r < nrows; r++ {
			flat[r*ncols+c] = column[r]
		}
	}
	return flat
}
1693
// onehotChunk2Indirect lists the non-zero elements of matrixT, which
// is indexed as matrixT[col][row]. It returns
// [2][]uint32{rowIndices, colIndices}: the k-th non-zero element is at
// row rowIndices[k] and column colIndices[k]. Elements are listed by
// column, and by row within each column. Both slices are nil when
// matrixT has no non-zero elements.
func onehotChunk2Indirect(matrixT [][]int8) [2][]uint32 {
	var rowIdx, colIdx []uint32
	for c := range matrixT {
		column := matrixT[c]
		for r := range column {
			if column[r] == 0 {
				continue
			}
			rowIdx = append(rowIdx, uint32(r))
			colIdx = append(colIdx, uint32(c))
		}
	}
	return [2][]uint32{rowIdx, colIdx}
}