19566: Add constant, check GLM results against Python.
[lightning.git] / glm.go
1 // Copyright (C) The Lightning Authors. All rights reserved.
2 //
3 // SPDX-License-Identifier: AGPL-3.0
4
5 package lightning
6
7 import (
8         "fmt"
9         "io"
10         "log"
11         "math"
12
13         "github.com/kshedden/statmodel/glm"
14         "github.com/kshedden/statmodel/statmodel"
15         "gonum.org/v1/gonum/stat/distuv"
16 )
17
18 var glmConfig = &glm.Config{
19         Family:         glm.NewFamily(glm.BinomialFamily),
20         FitMethod:      "IRLS",
21         ConcurrentIRLS: 1000,
22         Log:            log.New(io.Discard, "", 0),
23 }
24
25 // Logistic regression.
26 //
27 // onehot is the observed outcome, in same order as sampleInfo, but
28 // shorter because it only has entries for samples with
29 // isTraining==true.
30 func pvalueGLM(sampleInfo []sampleInfo, onehot []bool, nPCA int) (p float64) {
31         pcaNames := make([]string, 0, nPCA)
32         data := make([][]statmodel.Dtype, 0, nPCA)
33         for pca := 0; pca < nPCA; pca++ {
34                 series := make([]statmodel.Dtype, 0, len(sampleInfo))
35                 for _, si := range sampleInfo {
36                         if si.isTraining {
37                                 series = append(series, si.pcaComponents[pca])
38                         }
39                 }
40                 data = append(data, series)
41                 pcaNames = append(pcaNames, fmt.Sprintf("pca%d", pca))
42         }
43
44         variant := make([]statmodel.Dtype, 0, len(sampleInfo))
45         outcome := make([]statmodel.Dtype, 0, len(sampleInfo))
46         constants := make([]statmodel.Dtype, 0, len(sampleInfo))
47         row := 0
48         for _, si := range sampleInfo {
49                 if si.isTraining {
50                         if onehot[row] {
51                                 variant = append(variant, 1)
52                         } else {
53                                 variant = append(variant, 0)
54                         }
55                         if si.isCase {
56                                 outcome = append(outcome, 1)
57                         } else {
58                                 outcome = append(outcome, 0)
59                         }
60                         constants = append(constants, 1)
61                         row++
62                 }
63         }
64         data = append(data, variant, outcome, constants)
65         dataset := statmodel.NewDataset(data, append(pcaNames, "variant", "outcome", "constants"))
66
67         defer func() {
68                 if recover() != nil {
69                         // typically "matrix singular or near-singular with condition number +Inf"
70                         p = math.NaN()
71                 }
72         }()
73         model, err := glm.NewGLM(dataset, "outcome", append([]string{"constants"}, pcaNames...), glmConfig)
74         if err != nil {
75                 return math.NaN()
76         }
77         resultCov := model.Fit()
78         logCov := resultCov.LogLike()
79         model, err = glm.NewGLM(dataset, "outcome", append([]string{"constants", "variant"}, pcaNames...), glmConfig)
80         if err != nil {
81                 return math.NaN()
82         }
83         resultComp := model.Fit()
84         logComp := resultComp.LogLike()
85         dist := distuv.ChiSquared{K: 1}
86         return dist.Survival(-2 * (logCov - logComp))
87 }