19566: Option to limit pca components used in glm. Fix onehot use.
[lightning.git] / glm_test.go
1 // Copyright (C) The Lightning Authors. All rights reserved.
2 //
3 // SPDX-License-Identifier: AGPL-3.0
4
5 package lightning
6
7 import (
8         "fmt"
9         "math/rand"
10
11         "gopkg.in/check.v1"
12 )
13
14 type glmSuite struct{}
15
16 var _ = check.Suite(&glmSuite{})
17
18 func (s *glmSuite) TestPvalue(c *check.C) {
19         c.Check(pvalueGLM([]sampleInfo{
20                 {id: "sample1", isCase: false, isTraining: true, pcaComponents: []float64{-4, 1.2, -3}},
21                 {id: "sample2", isCase: false, isTraining: true, pcaComponents: []float64{7, -1.2, 2}},
22                 {id: "sample3", isCase: true, isTraining: true, pcaComponents: []float64{7, -1.2, 2}},
23                 {id: "sample4", isCase: true, isTraining: true, pcaComponents: []float64{-4, 1.1, -2}},
24         }, []bool{
25                 false, false, true, true,
26         }, 3), check.Equals, 0.09589096738494937)
27
28         c.Check(pvalueGLM([]sampleInfo{
29                 {id: "sample1", isCase: false, isTraining: true, pcaComponents: []float64{1, 1.21, 2.37}},
30                 {id: "sample2", isCase: false, isTraining: true, pcaComponents: []float64{2, 1.22, 2.38}},
31                 {id: "sample3", isCase: false, isTraining: true, pcaComponents: []float64{3, 1.23, 2.39}},
32                 {id: "sample4", isCase: false, isTraining: true, pcaComponents: []float64{1, 1.24, 2.33}},
33                 {id: "sample5", isCase: false, isTraining: true, pcaComponents: []float64{2, 1.25, 2.34}},
34                 {id: "sample6", isCase: true, isTraining: true, pcaComponents: []float64{3, 1.26, 2.35}},
35                 {id: "sample7", isCase: true, isTraining: true, pcaComponents: []float64{1, 1.23, 2.36}},
36                 {id: "sample8", isCase: true, isTraining: true, pcaComponents: []float64{2, 1.22, 2.32}},
37                 {id: "sample9", isCase: true, isTraining: true, pcaComponents: []float64{3, 1.21, 2.31}},
38         }, []bool{
39                 false, false, false, false, false, true, true, true, true,
40         }, 3), check.Equals, 0.001028375654911555)
41
42         c.Check(pvalueGLM([]sampleInfo{
43                 {id: "sample1", isCase: false, isTraining: true, pcaComponents: []float64{1.001, -1.01, 2.39}},
44                 {id: "sample2", isCase: false, isTraining: true, pcaComponents: []float64{1.002, -1.02, 2.38}},
45                 {id: "sample3", isCase: false, isTraining: true, pcaComponents: []float64{1.003, -1.03, 2.37}},
46                 {id: "sample4", isCase: false, isTraining: true, pcaComponents: []float64{1.004, -1.04, 2.36}},
47                 {id: "sample5", isCase: false, isTraining: true, pcaComponents: []float64{1.005, -1.05, 2.35}},
48                 {id: "sample6", isCase: false, isTraining: true, pcaComponents: []float64{1.006, -1.06, 2.34}},
49                 {id: "sample7", isCase: false, isTraining: true, pcaComponents: []float64{1.007, -1.07, 2.33}},
50                 {id: "sample8", isCase: false, isTraining: true, pcaComponents: []float64{1.008, -1.08, 2.32}},
51                 {id: "sample9", isCase: false, isTraining: false, pcaComponents: []float64{2.000, 8.01, -2.01}},
52                 {id: "sample10", isCase: true, isTraining: true, pcaComponents: []float64{2.001, 8.02, -2.02}},
53                 {id: "sample11", isCase: true, isTraining: true, pcaComponents: []float64{2.002, 8.03, -2.03}},
54                 {id: "sample12", isCase: true, isTraining: true, pcaComponents: []float64{2.003, 8.04, -2.04}},
55                 {id: "sample13", isCase: true, isTraining: true, pcaComponents: []float64{2.004, 8.05, -2.05}},
56                 {id: "sample14", isCase: true, isTraining: true, pcaComponents: []float64{2.005, 8.06, -2.06}},
57                 {id: "sample15", isCase: true, isTraining: true, pcaComponents: []float64{2.006, 8.07, -2.07}},
58                 {id: "sample16", isCase: true, isTraining: true, pcaComponents: []float64{2.007, 8.08, -2.08}},
59                 {id: "sample17", isCase: true, isTraining: true, pcaComponents: []float64{2.008, 8.09, -2.09}},
60                 {id: "sample18", isCase: true, isTraining: true, pcaComponents: []float64{2.009, 8.10, -2.10}},
61                 {id: "sample19", isCase: true, isTraining: true, pcaComponents: []float64{2.010, 8.11, -2.11}},
62         }, []bool{
63                 false, false, false, false, false, false, false, false, true, true, true, true, true, true, true, true, true, true,
64         }, 3), check.Equals, 0.9999944849940106)
65 }
66
67 var benchSamples, benchOnehot = func() ([]sampleInfo, []bool) {
68         pcaComponents := 10
69         samples := []sampleInfo{}
70         onehot := []bool{}
71         r := make([]float64, pcaComponents)
72         for j := 0; j < 10000; j++ {
73                 for i := 0; i < len(r); i++ {
74                         r[i] = rand.Float64()
75                 }
76                 samples = append(samples, sampleInfo{
77                         id:            fmt.Sprintf("sample%d", j),
78                         isCase:        j%2 == 0 && j > 200,
79                         isControl:     j%2 == 1 || j <= 200,
80                         isTraining:    true,
81                         pcaComponents: append([]float64(nil), r...),
82                 })
83                 onehot = append(onehot, j%2 == 0)
84         }
85         return samples, onehot
86 }()
87
88 func (s *glmSuite) BenchmarkPvalue(c *check.C) {
89         for i := 0; i < c.N; i++ {
90                 p := pvalueGLM(benchSamples, benchOnehot, len(benchSamples[0].pcaComponents))
91                 c.Check(p, check.Equals, 0.0)
92         }
93 }