Fix some tests.
[lightning.git] / glm.go
1 // Copyright (C) The Lightning Authors. All rights reserved.
2 //
3 // SPDX-License-Identifier: AGPL-3.0
4
5 package lightning
6
7 import (
8         "fmt"
9         "io"
10         "log"
11         "math"
12
13         "github.com/kshedden/statmodel/glm"
14         "github.com/kshedden/statmodel/statmodel"
15         "gonum.org/v1/gonum/stat"
16         "gonum.org/v1/gonum/stat/distuv"
17 )
18
19 var glmConfig = &glm.Config{
20         Family:         glm.NewFamily(glm.BinomialFamily),
21         FitMethod:      "IRLS",
22         ConcurrentIRLS: 1000,
23         Log:            log.New(io.Discard, "", 0),
24 }
25
26 func normalize(a []float64) {
27         mean, std := stat.MeanStdDev(a, nil)
28         for i, x := range a {
29                 a[i] = (x - mean) / std
30         }
31 }
32
33 // Logistic regression.
34 //
35 // onehot is the observed outcome, in same order as sampleInfo, but
36 // shorter because it only has entries for samples with
37 // isTraining==true.
38 func glmPvalueFunc(sampleInfo []sampleInfo, nPCA int) func(onehot []bool) float64 {
39         pcaNames := make([]string, 0, nPCA)
40         data := make([][]statmodel.Dtype, 0, nPCA)
41         for pca := 0; pca < nPCA; pca++ {
42                 series := make([]statmodel.Dtype, 0, len(sampleInfo))
43                 for _, si := range sampleInfo {
44                         if si.isTraining {
45                                 series = append(series, si.pcaComponents[pca])
46                         }
47                 }
48                 normalize(series)
49                 data = append(data, series)
50                 pcaNames = append(pcaNames, fmt.Sprintf("pca%d", pca))
51         }
52
53         outcome := make([]statmodel.Dtype, 0, len(sampleInfo))
54         constants := make([]statmodel.Dtype, 0, len(sampleInfo))
55         row := 0
56         for _, si := range sampleInfo {
57                 if si.isTraining {
58                         if si.isCase {
59                                 outcome = append(outcome, 1)
60                         } else {
61                                 outcome = append(outcome, 0)
62                         }
63                         constants = append(constants, 1)
64                         row++
65                 }
66         }
67         data = append([][]statmodel.Dtype{outcome, constants}, data...)
68         names := append([]string{"outcome", "constants"}, pcaNames...)
69         dataset := statmodel.NewDataset(data, names)
70
71         model, err := glm.NewGLM(dataset, "outcome", names[1:], glmConfig)
72         if err != nil {
73                 log.Printf("%s", err)
74                 return func([]bool) float64 { return math.NaN() }
75         }
76         resultCov := model.Fit()
77         logCov := resultCov.LogLike()
78
79         return func(onehot []bool) (p float64) {
80                 defer func() {
81                         if recover() != nil {
82                                 // typically "matrix singular or near-singular with condition number +Inf"
83                                 p = math.NaN()
84                         }
85                 }()
86
87                 variant := make([]statmodel.Dtype, 0, len(sampleInfo))
88                 row := 0
89                 for _, si := range sampleInfo {
90                         if si.isTraining {
91                                 if onehot[row] {
92                                         variant = append(variant, 1)
93                                 } else {
94                                         variant = append(variant, 0)
95                                 }
96                                 row++
97                         }
98                 }
99
100                 data := append([][]statmodel.Dtype{data[0], variant}, data[1:]...)
101                 names := append([]string{"outcome", "variant"}, names[1:]...)
102                 dataset := statmodel.NewDataset(data, names)
103
104                 model, err := glm.NewGLM(dataset, "outcome", names[1:], glmConfig)
105                 if err != nil {
106                         return math.NaN()
107                 }
108                 resultComp := model.Fit()
109                 logComp := resultComp.LogLike()
110                 dist := distuv.ChiSquared{K: 1}
111                 return dist.Survival(-2 * (logCov - logComp))
112         }
113 }