Merge branch '19785-add-cwl' into main
[lightning.git] / glm.go
diff --git a/glm.go b/glm.go
index 35ed37891898f989dc2619c13f17e8a74c1d6ae0..6ae98df928ac02a64d5381ff0a4f5bd8bfb0bfe5 100644 (file)
--- a/glm.go
+++ b/glm.go
@@ -6,20 +6,36 @@ package lightning
 
 import (
        "fmt"
+       "io"
+       "log"
        "math"
 
        "github.com/kshedden/statmodel/glm"
        "github.com/kshedden/statmodel/statmodel"
+       "gonum.org/v1/gonum/stat"
+       "gonum.org/v1/gonum/stat/distuv"
 )
 
 var glmConfig = &glm.Config{ // shared settings passed to every glm.NewGLM call in pvalueGLM
        Family:         glm.NewFamily(glm.BinomialFamily),
        FitMethod:      "IRLS",
        ConcurrentIRLS: 1000,
+       Log:            log.New(io.Discard, "", 0), // logger writes to io.Discard, so whatever is sent to it is dropped
 }
 
-func pvalueGLM(sampleInfo []sampleInfo, onehot []bool) float64 {
-       nPCA := len(sampleInfo[0].pcaComponents)
+func normalize(a []float64) { // normalize rescales a in place: each element has the mean subtracted and is divided by the standard deviation
+       mean, std := stat.MeanStdDev(a, nil) // nil weights: every element is weighted equally
+       for i, x := range a {
+               a[i] = (x - mean) / std // NOTE(review): if std is 0 (e.g. all elements equal) this yields NaN -- confirm callers tolerate that
+       }
+}
+
+// Logistic regression.
+//
+// onehot marks, per training sample, whether the variant was observed
+// (it fills the "variant" predictor column). It is in the same order as
+// sampleInfo, but shorter: only samples with isTraining==true have entries.
+func pvalueGLM(sampleInfo []sampleInfo, onehot []bool, nPCA int) (p float64) {
        pcaNames := make([]string, 0, nPCA)
        data := make([][]statmodel.Dtype, 0, nPCA)
        for pca := 0; pca < nPCA; pca++ {
@@ -29,13 +45,16 @@ func pvalueGLM(sampleInfo []sampleInfo, onehot []bool) float64 {
                                series = append(series, si.pcaComponents[pca])
                        }
                }
+               normalize(series)
                data = append(data, series)
                pcaNames = append(pcaNames, fmt.Sprintf("pca%d", pca))
        }
 
        variant := make([]statmodel.Dtype, 0, len(sampleInfo))
        outcome := make([]statmodel.Dtype, 0, len(sampleInfo))
-       for row, si := range sampleInfo {
+       constants := make([]statmodel.Dtype, 0, len(sampleInfo))
+       row := 0
+       for _, si := range sampleInfo {
                if si.isTraining {
                        if onehot[row] {
                                variant = append(variant, 1)
@@ -47,22 +66,31 @@ func pvalueGLM(sampleInfo []sampleInfo, onehot []bool) float64 {
                        } else {
                                outcome = append(outcome, 0)
                        }
+                       constants = append(constants, 1)
+                       row++
                }
        }
-       data = append(data, variant, outcome)
+       data = append(data, variant, outcome, constants)
+       dataset := statmodel.NewDataset(data, append(pcaNames, "variant", "outcome", "constants"))
 
-       dataset := statmodel.NewDataset(data, append(pcaNames, "variant", "outcome"))
-       model, err := glm.NewGLM(dataset, "outcome", pcaNames, glmConfig)
+       defer func() {
+               if recover() != nil {
+                       // typically "matrix singular or near-singular with condition number +Inf"
+                       p = math.NaN()
+               }
+       }()
+       model, err := glm.NewGLM(dataset, "outcome", append([]string{"constants"}, pcaNames...), glmConfig)
        if err != nil {
                return math.NaN()
        }
        resultCov := model.Fit()
        logCov := resultCov.LogLike()
-       model, err = glm.NewGLM(dataset, "outcome", append([]string{"variant"}, pcaNames...), glmConfig)
+       model, err = glm.NewGLM(dataset, "outcome", append([]string{"constants", "variant"}, pcaNames...), glmConfig)
        if err != nil {
                return math.NaN()
        }
        resultComp := model.Fit()
        logComp := resultComp.LogLike()
-       return chisquared.Survival(-2 * (logCov - logComp))
+       dist := distuv.ChiSquared{K: 1}
+       return dist.Survival(-2 * (logCov - logComp))
 }