6ae98df928ac02a64d5381ff0a4f5bd8bfb0bfe5
[lightning.git] / glm.go
1 // Copyright (C) The Lightning Authors. All rights reserved.
2 //
3 // SPDX-License-Identifier: AGPL-3.0
4
5 package lightning
6
7 import (
8         "fmt"
9         "io"
10         "log"
11         "math"
12
13         "github.com/kshedden/statmodel/glm"
14         "github.com/kshedden/statmodel/statmodel"
15         "gonum.org/v1/gonum/stat"
16         "gonum.org/v1/gonum/stat/distuv"
17 )
18
19 var glmConfig = &glm.Config{
20         Family:         glm.NewFamily(glm.BinomialFamily),
21         FitMethod:      "IRLS",
22         ConcurrentIRLS: 1000,
23         Log:            log.New(io.Discard, "", 0),
24 }
25
26 func normalize(a []float64) {
27         mean, std := stat.MeanStdDev(a, nil)
28         for i, x := range a {
29                 a[i] = (x - mean) / std
30         }
31 }
32
33 // Logistic regression.
34 //
35 // onehot is the observed outcome, in same order as sampleInfo, but
36 // shorter because it only has entries for samples with
37 // isTraining==true.
38 func pvalueGLM(sampleInfo []sampleInfo, onehot []bool, nPCA int) (p float64) {
39         pcaNames := make([]string, 0, nPCA)
40         data := make([][]statmodel.Dtype, 0, nPCA)
41         for pca := 0; pca < nPCA; pca++ {
42                 series := make([]statmodel.Dtype, 0, len(sampleInfo))
43                 for _, si := range sampleInfo {
44                         if si.isTraining {
45                                 series = append(series, si.pcaComponents[pca])
46                         }
47                 }
48                 normalize(series)
49                 data = append(data, series)
50                 pcaNames = append(pcaNames, fmt.Sprintf("pca%d", pca))
51         }
52
53         variant := make([]statmodel.Dtype, 0, len(sampleInfo))
54         outcome := make([]statmodel.Dtype, 0, len(sampleInfo))
55         constants := make([]statmodel.Dtype, 0, len(sampleInfo))
56         row := 0
57         for _, si := range sampleInfo {
58                 if si.isTraining {
59                         if onehot[row] {
60                                 variant = append(variant, 1)
61                         } else {
62                                 variant = append(variant, 0)
63                         }
64                         if si.isCase {
65                                 outcome = append(outcome, 1)
66                         } else {
67                                 outcome = append(outcome, 0)
68                         }
69                         constants = append(constants, 1)
70                         row++
71                 }
72         }
73         data = append(data, variant, outcome, constants)
74         dataset := statmodel.NewDataset(data, append(pcaNames, "variant", "outcome", "constants"))
75
76         defer func() {
77                 if recover() != nil {
78                         // typically "matrix singular or near-singular with condition number +Inf"
79                         p = math.NaN()
80                 }
81         }()
82         model, err := glm.NewGLM(dataset, "outcome", append([]string{"constants"}, pcaNames...), glmConfig)
83         if err != nil {
84                 return math.NaN()
85         }
86         resultCov := model.Fit()
87         logCov := resultCov.LogLike()
88         model, err = glm.NewGLM(dataset, "outcome", append([]string{"constants", "variant"}, pcaNames...), glmConfig)
89         if err != nil {
90                 return math.NaN()
91         }
92         resultComp := model.Fit()
93         logComp := resultComp.LogLike()
94         dist := distuv.ChiSquared{K: 1}
95         return dist.Survival(-2 * (logCov - logComp))
96 }