19566: Option to limit pca components used in glm. Fix onehot use.
[lightning.git] / glm.go
1 // Copyright (C) The Lightning Authors. All rights reserved.
2 //
3 // SPDX-License-Identifier: AGPL-3.0
4
5 package lightning
6
7 import (
8         "fmt"
9         "math"
10
11         "github.com/kshedden/statmodel/glm"
12         "github.com/kshedden/statmodel/statmodel"
13 )
14
15 var glmConfig = &glm.Config{
16         Family:         glm.NewFamily(glm.BinomialFamily),
17         FitMethod:      "IRLS",
18         ConcurrentIRLS: 1000,
19 }
20
21 // Logistic regression.
22 //
23 // onehot is the observed outcome, in same order as sampleInfo, but
24 // shorter because it only has entries for samples with
25 // isTraining==true.
26 func pvalueGLM(sampleInfo []sampleInfo, onehot []bool, nPCA int) float64 {
27         pcaNames := make([]string, 0, nPCA)
28         data := make([][]statmodel.Dtype, 0, nPCA)
29         for pca := 0; pca < nPCA; pca++ {
30                 series := make([]statmodel.Dtype, 0, len(sampleInfo))
31                 for _, si := range sampleInfo {
32                         if si.isTraining {
33                                 series = append(series, si.pcaComponents[pca])
34                         }
35                 }
36                 data = append(data, series)
37                 pcaNames = append(pcaNames, fmt.Sprintf("pca%d", pca))
38         }
39
40         variant := make([]statmodel.Dtype, 0, len(sampleInfo))
41         outcome := make([]statmodel.Dtype, 0, len(sampleInfo))
42         row := 0
43         for _, si := range sampleInfo {
44                 if si.isTraining {
45                         if onehot[row] {
46                                 variant = append(variant, 1)
47                         } else {
48                                 variant = append(variant, 0)
49                         }
50                         if si.isCase {
51                                 outcome = append(outcome, 1)
52                         } else {
53                                 outcome = append(outcome, 0)
54                         }
55                         row++
56                 }
57         }
58         data = append(data, variant, outcome)
59
60         dataset := statmodel.NewDataset(data, append(pcaNames, "variant", "outcome"))
61         model, err := glm.NewGLM(dataset, "outcome", pcaNames, glmConfig)
62         if err != nil {
63                 return math.NaN()
64         }
65         resultCov := model.Fit()
66         logCov := resultCov.LogLike()
67         model, err = glm.NewGLM(dataset, "outcome", append([]string{"variant"}, pcaNames...), glmConfig)
68         if err != nil {
69                 return math.NaN()
70         }
71         resultComp := model.Fit()
72         logComp := resultComp.LogLike()
73         return chisquared.Survival(-2 * (logCov - logComp))
74 }