Fix glm test.
[lightning.git] / glm.go
diff --git a/glm.go b/glm.go
index 6ae98df928ac02a64d5381ff0a4f5bd8bfb0bfe5..b68bdd1898419ffab3c15d5516c8259ad4646073 100644 (file)
--- a/glm.go
+++ b/glm.go
@@ -35,7 +35,7 @@ func normalize(a []float64) {
 // onehot is the observed outcome, in same order as sampleInfo, but
 // shorter because it only has entries for samples with
 // isTraining==true.
-func pvalueGLM(sampleInfo []sampleInfo, onehot []bool, nPCA int) (p float64) {
+func glmPvalueFunc(sampleInfo []sampleInfo, nPCA int) func(onehot []bool) float64 {
        pcaNames := make([]string, 0, nPCA)
        data := make([][]statmodel.Dtype, 0, nPCA)
        for pca := 0; pca < nPCA; pca++ {
@@ -50,17 +50,11 @@ func pvalueGLM(sampleInfo []sampleInfo, onehot []bool, nPCA int) (p float64) {
                pcaNames = append(pcaNames, fmt.Sprintf("pca%d", pca))
        }
 
-       variant := make([]statmodel.Dtype, 0, len(sampleInfo))
        outcome := make([]statmodel.Dtype, 0, len(sampleInfo))
        constants := make([]statmodel.Dtype, 0, len(sampleInfo))
        row := 0
        for _, si := range sampleInfo {
                if si.isTraining {
-                       if onehot[row] {
-                               variant = append(variant, 1)
-                       } else {
-                               variant = append(variant, 0)
-                       }
                        if si.isCase {
                                outcome = append(outcome, 1)
                        } else {
@@ -70,27 +64,50 @@ func pvalueGLM(sampleInfo []sampleInfo, onehot []bool, nPCA int) (p float64) {
                        row++
                }
        }
-       data = append(data, variant, outcome, constants)
-       dataset := statmodel.NewDataset(data, append(pcaNames, "variant", "outcome", "constants"))
+       data = append([][]statmodel.Dtype{outcome, constants}, data...)
+       names := append([]string{"outcome", "constants"}, pcaNames...)
+       dataset := statmodel.NewDataset(data, names)
 
-       defer func() {
-               if recover() != nil {
-                       // typically "matrix singular or near-singular with condition number +Inf"
-                       p = math.NaN()
-               }
-       }()
-       model, err := glm.NewGLM(dataset, "outcome", append([]string{"constants"}, pcaNames...), glmConfig)
+       model, err := glm.NewGLM(dataset, "outcome", names[1:], glmConfig)
        if err != nil {
-               return math.NaN()
+               log.Printf("%s", err)
+               return func([]bool) float64 { return math.NaN() }
        }
        resultCov := model.Fit()
        logCov := resultCov.LogLike()
-       model, err = glm.NewGLM(dataset, "outcome", append([]string{"constants", "variant"}, pcaNames...), glmConfig)
-       if err != nil {
-               return math.NaN()
+
+       return func(onehot []bool) (p float64) {
+               defer func() {
+                       if recover() != nil {
+                               // typically "matrix singular or near-singular with condition number +Inf"
+                               p = math.NaN()
+                       }
+               }()
+
+               variant := make([]statmodel.Dtype, 0, len(sampleInfo))
+               row := 0
+               for _, si := range sampleInfo {
+                       if si.isTraining {
+                               if onehot[row] {
+                                       variant = append(variant, 1)
+                               } else {
+                                       variant = append(variant, 0)
+                               }
+                               row++
+                       }
+               }
+
+               data := append([][]statmodel.Dtype{data[0], variant}, data[1:]...)
+               names := append([]string{"outcome", "variant"}, names[1:]...)
+               dataset := statmodel.NewDataset(data, names)
+
+               model, err := glm.NewGLM(dataset, "outcome", names[1:], glmConfig)
+               if err != nil {
+                       return math.NaN()
+               }
+               resultComp := model.Fit()
+               logComp := resultComp.LogLike()
+               dist := distuv.ChiSquared{K: 1}
+               return dist.Survival(-2 * (logCov - logComp))
        }
-       resultComp := model.Fit()
-       logComp := resultComp.LogLike()
-       dist := distuv.ChiSquared{K: 1}
-       return dist.Survival(-2 * (logCov - logComp))
 }