Merge branch '19566-glm'
[lightning.git] / exportnumpy_test.go
1 // Copyright (C) The Lightning Authors. All rights reserved.
2 //
3 // SPDX-License-Identifier: AGPL-3.0
4
5 package lightning
6
7 import (
8         "bytes"
9         "io/ioutil"
10         "os"
11
12         "github.com/kshedden/gonpy"
13         "gopkg.in/check.v1"
14 )
15
16 type exportNumpySuite struct{}
17
18 var _ = check.Suite(&exportNumpySuite{})
19
20 func (s *exportNumpySuite) TestFastaToNumpy(c *check.C) {
21         tmpdir := c.MkDir()
22
23         err := ioutil.WriteFile(tmpdir+"/chr1-12-100.bed", []byte("chr1\t12\t100\ttest.1\n"), 0644)
24         c.Check(err, check.IsNil)
25
26         var buffer bytes.Buffer
27         exited := (&importer{}).RunCommand("import", []string{"-local=true", "-o", tmpdir + "/library.gob.gz", "-tag-library", "testdata/tags", "-output-tiles", "-save-incomplete-tiles", "testdata/a.1.fasta", "testdata/tinyref.fasta"}, &bytes.Buffer{}, os.Stderr, os.Stderr)
28         c.Assert(exited, check.Equals, 0)
29         exited = (&exportNumpy{}).RunCommand("export-numpy", []string{"-local=true", "-input-dir", tmpdir, "-output-dir", tmpdir, "-output-annotations", tmpdir + "/annotations.csv", "-regions", tmpdir + "/chr1-12-100.bed"}, &buffer, os.Stderr, os.Stderr)
30         c.Check(exited, check.Equals, 0)
31         f, err := os.Open(tmpdir + "/matrix.npy")
32         c.Assert(err, check.IsNil)
33         defer f.Close()
34         npy, err := gonpy.NewReader(f)
35         c.Assert(err, check.IsNil)
36         variants, err := npy.GetInt16()
37         c.Assert(err, check.IsNil)
38         c.Check(variants, check.HasLen, 6)
39         for i := 0; i < 4 && i < len(variants); i += 2 {
40                 if variants[i] == 1 {
41                         c.Check(variants[i+1], check.Equals, int16(2), check.Commentf("i=%d, v=%v", i, variants))
42                 } else {
43                         c.Check(variants[i], check.Equals, int16(2), check.Commentf("i=%d, v=%v", i, variants))
44                 }
45         }
46         for i := 4; i < 6 && i < len(variants); i += 2 {
47                 c.Check(variants[i], check.Equals, int16(1), check.Commentf("i=%d, v=%v", i, variants))
48         }
49         annotations, err := ioutil.ReadFile(tmpdir + "/annotations.csv")
50         c.Check(err, check.IsNil)
51         c.Logf("%s", string(annotations))
52         c.Check(string(annotations), check.Matches, `(?ms)(.*\n)?1,1,2,chr1:g.84_85insACTGCGATCTGA\n.*`)
53         c.Check(string(annotations), check.Matches, `(?ms)(.*\n)?1,1,1,chr1:g.87_96delinsGCATCTGCA\n.*`)
54 }
55
// sortUints canonicalizes, in place, the order of the two values inside
// each adjacent pair of variants: (variants[0], variants[1]),
// (variants[2], variants[3]), and so on.
//
// It scans the pairs in order. At the first pair whose first value is
// greater than its second, it swaps the two values within every pair (not
// just that one) and returns. If no pair is descending, variants is left
// unchanged. A trailing unpaired element (odd length) is never read or
// moved.
//
// NOTE(review): only the first descending pair triggers the flip, so an
// input such as {1, 2, 2, 1} becomes {2, 1, 1, 2}. Confirm this is the
// intended canonicalization (versus flipping based on the first pair whose
// values differ).
func sortUints(variants []int16) {
	for i := 0; i+1 < len(variants); i += 2 {
		if variants[i] <= variants[i+1] {
			continue
		}
		// Swap within each pair, stepping two at a time. Stepping by one
		// would shift values across pair boundaries and index past the
		// end of the slice.
		for j := 0; j+1 < len(variants); j += 2 {
			variants[j], variants[j+1] = variants[j+1], variants[j]
		}
		return
	}
}
66
67 func (s *exportNumpySuite) TestOnehot(c *check.C) {
68         for _, trial := range []struct {
69                 incols  int
70                 in      []int16
71                 outcols int
72                 out     []int16
73         }{
74                 {2, []int16{1, 1, 1, 1}, 2, []int16{1, 1, 1, 1}},
75                 {2, []int16{1, 1, 1, 2}, 3, []int16{1, 1, 0, 1, 0, 1}},
76                 {
77                         // 2nd column => 3 one-hot columns
78                         // 4th column => 0 one-hot columns
79                         4, []int16{
80                                 1, 1, 0, 0,
81                                 1, 2, 1, 0,
82                                 1, 3, 0, 0,
83                         }, 5, []int16{
84                                 1, 1, 0, 0, 0,
85                                 1, 0, 1, 0, 1,
86                                 1, 0, 0, 1, 0,
87                         },
88                 },
89         } {
90                 out, _, outcols := recodeOnehot(trial.in, trial.incols)
91                 c.Check(out, check.DeepEquals, trial.out)
92                 c.Check(outcols, check.Equals, trial.outcols)
93         }
94 }