19566: Logistic regression p-value.
author Tom Clegg <tom@curii.com>
Tue, 29 Nov 2022 15:43:29 +0000 (10:43 -0500)
committer Tom Clegg <tom@curii.com>
Tue, 29 Nov 2022 15:43:29 +0000 (10:43 -0500)
Arvados-DCO-1.1-Signed-off-by: Tom Clegg <tom@curii.com>

glm.go [new file with mode: 0644]
glm_test.go [new file with mode: 0644]
go.mod
go.sum
slicenumpy.go

diff --git a/glm.go b/glm.go
new file mode 100644 (file)
index 0000000..cc06f39
--- /dev/null
+++ b/glm.go
@@ -0,0 +1,68 @@
+// Copyright (C) The Lightning Authors. All rights reserved.
+//
+// SPDX-License-Identifier: AGPL-3.0
+
+package lightning
+
+import (
+       "fmt"
+       "math"
+
+       "github.com/kshedden/statmodel/glm"
+       "github.com/kshedden/statmodel/statmodel"
+)
+
+var glmConfig = &glm.Config{
+       Family:         glm.NewFamily(glm.BinomialFamily),
+       FitMethod:      "IRLS",
+       ConcurrentIRLS: 1000,
+}
+
+func pvalueGLM(sampleInfo []sampleInfo, onehotPair [][]bool) float64 {
+       nPCA := len(sampleInfo[0].pcaComponents)
+       pcaNames := make([]string, 0, nPCA)
+       data := make([][]statmodel.Dtype, 0, nPCA)
+       for pca := 0; pca < nPCA; pca++ {
+               series := make([]statmodel.Dtype, 0, len(sampleInfo))
+               for _, si := range sampleInfo {
+                       if si.isTraining {
+                               series = append(series, si.pcaComponents[pca])
+                       }
+               }
+               data = append(data, series)
+               pcaNames = append(pcaNames, fmt.Sprintf("pca%d", pca))
+       }
+
+       variant := make([]statmodel.Dtype, 0, len(sampleInfo))
+       outcome := make([]statmodel.Dtype, 0, len(sampleInfo))
+       for row, si := range sampleInfo {
+               if si.isTraining {
+                       if onehotPair[0][row] {
+                               variant = append(variant, 1)
+                       } else {
+                               variant = append(variant, 0)
+                       }
+                       if si.isCase {
+                               outcome = append(outcome, 1)
+                       } else {
+                               outcome = append(outcome, 0)
+                       }
+               }
+       }
+       data = append(data, variant, outcome)
+
+       dataset := statmodel.NewDataset(data, append(pcaNames, "variant", "outcome"))
+       model, err := glm.NewGLM(dataset, "outcome", pcaNames, glmConfig)
+       if err != nil {
+               return math.NaN()
+       }
+       resultCov := model.Fit()
+       logCov := resultCov.LogLike()
+       model, err = glm.NewGLM(dataset, "outcome", append([]string{"variant"}, pcaNames...), glmConfig)
+       if err != nil {
+               return math.NaN()
+       }
+       resultComp := model.Fit()
+       logComp := resultComp.LogLike()
+       return chisquared.Survival(-2 * (logCov - logComp))
+}
diff --git a/glm_test.go b/glm_test.go
new file mode 100644 (file)
index 0000000..875a96c
--- /dev/null
+++ b/glm_test.go
@@ -0,0 +1,97 @@
+// Copyright (C) The Lightning Authors. All rights reserved.
+//
+// SPDX-License-Identifier: AGPL-3.0
+
+package lightning
+
+import (
+       "fmt"
+       "math/rand"
+
+       "gopkg.in/check.v1"
+)
+
+type glmSuite struct{}
+
+var _ = check.Suite(&glmSuite{})
+
+func (s *glmSuite) TestPvalue(c *check.C) {
+       c.Check(pvalueGLM([]sampleInfo{
+               {id: "sample1", isCase: false, isTraining: true, pcaComponents: []float64{-4, 1.2, -3}},
+               {id: "sample2", isCase: false, isTraining: true, pcaComponents: []float64{7, -1.2, 2}},
+               {id: "sample3", isCase: true, isTraining: true, pcaComponents: []float64{7, -1.2, 2}},
+               {id: "sample4", isCase: true, isTraining: true, pcaComponents: []float64{-4, 1.1, -2}},
+       }, [][]bool{
+               {false, false, true, true},
+               {false, false, true, true},
+       }), check.Equals, 0.09589096738494937)
+
+       c.Check(pvalueGLM([]sampleInfo{
+               {id: "sample1", isCase: false, isTraining: true, pcaComponents: []float64{1, 1.21, 2.37}},
+               {id: "sample2", isCase: false, isTraining: true, pcaComponents: []float64{2, 1.22, 2.38}},
+               {id: "sample3", isCase: false, isTraining: true, pcaComponents: []float64{3, 1.23, 2.39}},
+               {id: "sample4", isCase: false, isTraining: true, pcaComponents: []float64{1, 1.24, 2.33}},
+               {id: "sample5", isCase: false, isTraining: true, pcaComponents: []float64{2, 1.25, 2.34}},
+               {id: "sample6", isCase: true, isTraining: true, pcaComponents: []float64{3, 1.26, 2.35}},
+               {id: "sample7", isCase: true, isTraining: true, pcaComponents: []float64{1, 1.23, 2.36}},
+               {id: "sample8", isCase: true, isTraining: true, pcaComponents: []float64{2, 1.22, 2.32}},
+               {id: "sample9", isCase: true, isTraining: true, pcaComponents: []float64{3, 1.21, 2.31}},
+       }, [][]bool{
+               {false, false, false, false, false, true, true, true, true},
+               {false, false, false, false, false, true, true, true, true},
+       }), check.Equals, 0.001028375654911555)
+
+       c.Check(pvalueGLM([]sampleInfo{
+               {id: "sample1", isCase: false, isTraining: true, pcaComponents: []float64{1.001, -1.01, 2.39}},
+               {id: "sample2", isCase: false, isTraining: true, pcaComponents: []float64{1.002, -1.02, 2.38}},
+               {id: "sample3", isCase: false, isTraining: true, pcaComponents: []float64{1.003, -1.03, 2.37}},
+               {id: "sample4", isCase: false, isTraining: true, pcaComponents: []float64{1.004, -1.04, 2.36}},
+               {id: "sample5", isCase: false, isTraining: true, pcaComponents: []float64{1.005, -1.05, 2.35}},
+               {id: "sample6", isCase: false, isTraining: true, pcaComponents: []float64{1.006, -1.06, 2.34}},
+               {id: "sample7", isCase: false, isTraining: true, pcaComponents: []float64{1.007, -1.07, 2.33}},
+               {id: "sample8", isCase: false, isTraining: true, pcaComponents: []float64{1.008, -1.08, 2.32}},
+               {id: "sample9", isCase: false, isTraining: false, pcaComponents: []float64{2.000, 8.01, -2.01}},
+               {id: "sample10", isCase: true, isTraining: true, pcaComponents: []float64{2.001, 8.02, -2.02}},
+               {id: "sample11", isCase: true, isTraining: true, pcaComponents: []float64{2.002, 8.03, -2.03}},
+               {id: "sample12", isCase: true, isTraining: true, pcaComponents: []float64{2.003, 8.04, -2.04}},
+               {id: "sample13", isCase: true, isTraining: true, pcaComponents: []float64{2.004, 8.05, -2.05}},
+               {id: "sample14", isCase: true, isTraining: true, pcaComponents: []float64{2.005, 8.06, -2.06}},
+               {id: "sample15", isCase: true, isTraining: true, pcaComponents: []float64{2.006, 8.07, -2.07}},
+               {id: "sample16", isCase: true, isTraining: true, pcaComponents: []float64{2.007, 8.08, -2.08}},
+               {id: "sample17", isCase: true, isTraining: true, pcaComponents: []float64{2.008, 8.09, -2.09}},
+               {id: "sample18", isCase: true, isTraining: true, pcaComponents: []float64{2.009, 8.10, -2.10}},
+               {id: "sample19", isCase: true, isTraining: true, pcaComponents: []float64{2.010, 8.11, -2.11}},
+       }, [][]bool{
+               {false, false, false, false, false, false, false, false, false, true, true, true, true, true, true, true, true, true, true},
+               {false, false, false, false, false, false, false, false, false, true, true, true, true, true, true, true, true, true, true},
+       }), check.Equals, 0.9999944849940106)
+}
+
+var benchSamples, benchOnehot = func() ([]sampleInfo, [][]bool) {
+       pcaComponents := 10
+       samples := []sampleInfo{}
+       onehot := make([][]bool, 2)
+       r := make([]float64, pcaComponents)
+       for j := 0; j < 10000; j++ {
+               for i := 0; i < len(r); i++ {
+                       r[i] = rand.Float64()
+               }
+               samples = append(samples, sampleInfo{
+                       id:            fmt.Sprintf("sample%d", j),
+                       isCase:        j%2 == 0 && j > 200,
+                       isControl:     j%2 == 1 || j <= 200,
+                       isTraining:    true,
+                       pcaComponents: append([]float64(nil), r...),
+               })
+               onehot[0] = append(onehot[0], j%2 == 0)
+               onehot[1] = append(onehot[1], j%2 == 0)
+       }
+       return samples, onehot
+}()
+
+func (s *glmSuite) BenchmarkPvalue(c *check.C) {
+       for i := 0; i < c.N; i++ {
+               p := pvalueGLM(benchSamples, benchOnehot)
+               c.Check(p, check.Equals, 0.0)
+       }
+}
diff --git a/go.mod b/go.mod
index 758e7d3e8df709def55c7b7dc635b860126f88f9..8f8400a3b8e351c5faf6fd92502f3463447f6b72 100644 (file)
--- a/go.mod
+++ b/go.mod
@@ -7,6 +7,7 @@ require (
        github.com/james-bowman/nlp v0.0.0-20200417075118-1e2772e0e1e5
        github.com/klauspost/pgzip v1.2.5
        github.com/kshedden/gonpy v0.0.0-20190510000443-66c21fac4672
+       github.com/kshedden/statmodel v0.0.0-20210519035403-ee97d3e48df1
        github.com/mattn/go-isatty v0.0.12
        github.com/sergi/go-diff v1.1.0
        github.com/sirupsen/logrus v1.8.1
@@ -24,12 +25,14 @@ require (
        github.com/ghodss/yaml v1.0.0 // indirect
        github.com/gogo/protobuf v1.3.2 // indirect
        github.com/golang/protobuf v1.5.0 // indirect
+       github.com/golang/snappy v0.0.4 // indirect
        github.com/gonum/floats v0.0.0-20181209220543-c233463c7e82 // indirect
        github.com/gonum/internal v0.0.0-20181124074243-f884aa714029 // indirect
        github.com/james-bowman/sparse v0.0.0-20200514124614-ae250424e52d // indirect
        github.com/klauspost/compress v1.15.11 // indirect
        github.com/kr/pretty v0.2.1 // indirect
        github.com/kr/text v0.1.0 // indirect
+       github.com/kshedden/dstream v0.0.0-20190512025041-c4c410631beb // indirect
        github.com/matttproud/golang_protobuf_extensions v1.0.2-0.20181231171920-c182affec369 // indirect
        github.com/prometheus/client_golang v1.7.1 // indirect
        github.com/prometheus/client_model v0.2.0 // indirect
@@ -38,6 +41,7 @@ require (
        github.com/spaolacci/murmur3 v1.1.0 // indirect
        github.com/stretchr/testify v1.6.1 // indirect
        golang.org/x/sys v0.0.0-20210809222454-d867a43fc93e // indirect
+       golang.org/x/tools v0.1.7 // indirect
        google.golang.org/protobuf v1.27.1 // indirect
        gopkg.in/yaml.v2 v2.4.0 // indirect
 )
diff --git a/go.sum b/go.sum
index 20691b15d41443ba3217e8ba094cb5d935fc980e..158801e198e9241193dba85c650ad8bb08cc8c02 100644 (file)
--- a/go.sum
+++ b/go.sum
@@ -347,6 +347,8 @@ github.com/golang/protobuf v1.4.2/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw
 github.com/golang/protobuf v1.4.3/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI=
 github.com/golang/protobuf v1.5.0 h1:LUVKkCeviFUMKqHa4tXIIij/lbhnMbP7Fn5wKdKkRh4=
 github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk=
+github.com/golang/snappy v0.0.4 h1:yAGX7huGHXlcLOEtBnF4w7FQwA26wojNCwOYAEhLjQM=
+github.com/golang/snappy v0.0.4/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
 github.com/gonum/floats v0.0.0-20181209220543-c233463c7e82 h1:EvokxLQsaaQjcWVWSV38221VAK7qc2zhaO17bKys/18=
 github.com/gonum/floats v0.0.0-20181209220543-c233463c7e82/go.mod h1:PxC8OnwL11+aosOB5+iEPoV3picfs8tUpkVd0pDo+Kg=
 github.com/gonum/internal v0.0.0-20181124074243-f884aa714029 h1:8jtTdc+Nfj9AR+0soOeia9UZSvYBvETVHZrugUowJ7M=
@@ -453,8 +455,12 @@ github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
 github.com/kr/pty v1.1.5/go.mod h1:9r2w37qlBe7rQ6e1fg1S/9xpWHSnaqNdHD3WcMdbPDA=
 github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE=
 github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
+github.com/kshedden/dstream v0.0.0-20190512025041-c4c410631beb h1:Z5BVHFk/DLOIUAd2NycF0mLtKfhl7ynm4Uy5+AFhT48=
+github.com/kshedden/dstream v0.0.0-20190512025041-c4c410631beb/go.mod h1:+U+6yzfITr4/teU2YhxWhdyw6YzednT/16/UBMjlDrU=
 github.com/kshedden/gonpy v0.0.0-20190510000443-66c21fac4672 h1:LQLnybCU54zB8Gj8c1DPeZEheIAn3eZ8Cc9fYqM4ac8=
 github.com/kshedden/gonpy v0.0.0-20190510000443-66c21fac4672/go.mod h1:+uEXxXG0RlfBPqG1tq5QN/F2jRlcuY0dExSONLpEwcA=
+github.com/kshedden/statmodel v0.0.0-20210519035403-ee97d3e48df1 h1:UyIQ1VTQq/0CS/wLYjf3DV6uRKTd1xcsng3BccM4XCY=
+github.com/kshedden/statmodel v0.0.0-20210519035403-ee97d3e48df1/go.mod h1:uvVFnikBpVz7S1pdsyUI+BBRlz64vmU6Q+kviiB+fpU=
 github.com/lib/pq v1.0.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo=
 github.com/lib/pq v1.10.2 h1:AqzbZs4ZoCBp+GtejcpCpcxM3zlSMx29dXbUSeVtJb8=
 github.com/lib/pq v1.10.2/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
@@ -901,6 +907,7 @@ golang.org/x/tools v0.0.0-20200224181240-023911ca70b2/go.mod h1:TB2adYChydJhpapK
 golang.org/x/tools v0.0.0-20200304193943-95d2e580d8eb/go.mod h1:o4KQGtdN14AW+yjsvvwRTJJuXz8XRtIHtEnmAXLyFUw=
 golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE=
 golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA=
+golang.org/x/tools v0.1.7 h1:6j8CgantCy3yc8JGBqkDLMKWqZ0RDU2g1HVgacojGWQ=
 golang.org/x/tools v0.1.7/go.mod h1:LGqMHiF4EqQNHR1JncWGqT5BVaXmza+X+BDGol+dOxo=
 golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
 golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
diff --git a/slicenumpy.go b/slicenumpy.go
index f3f9cc7526f3528ab7b6d12317e017d2cf938996..c555635585f2fc3a2f6aaead4227ff3a8aee104f 100644 (file)
--- a/slicenumpy.go
+++ b/slicenumpy.go
@@ -83,12 +83,12 @@ func (cmd *sliceNumpy) run(prog string, args []string, stdin io.Reader, stdout,
        onehotChunked := flags.Bool("chunked-onehot", false, "generate one-hot tile-based matrix per input chunk")
        samplesFilename := flags.String("samples", "", "`samples.csv` file with training/validation and case/control groups (see 'lightning choose-samples')")
        caseControlOnly := flags.Bool("case-control-only", false, "drop samples that are not in case/control groups")
-       onlyPCA := flags.Bool("pca", false, "generate pca matrix")
+       onlyPCA := flags.Bool("pca", false, "run principal component analysis, write components to pca.npy and samples.csv")
        pcaComponents := flags.Int("pca-components", 4, "number of PCA components")
        maxPCATiles := flags.Int("max-pca-tiles", 0, "maximum tiles to use as PCA input (filter, then drop every 2nd colum pair until below max)")
        debugTag := flags.Int("debug-tag", -1, "log debugging details about specified tag")
        flags.IntVar(&cmd.threads, "threads", 16, "number of memory-hungry assembly threads, and number of VCPUs to request for arvados container")
-       flags.Float64Var(&cmd.chi2PValue, "chi2-p-value", 1, "do Χ² test and omit columns with p-value above this threshold")
+       flags.Float64Var(&cmd.chi2PValue, "chi2-p-value", 1, "do Χ² test (or logistic regression if -samples file has PCA components) and omit columns with p-value above this threshold")
        flags.BoolVar(&cmd.includeVariant1, "include-variant-1", false, "include most common variant when building one-hot matrix")
        cmd.filter.Flags(flags)
        err := flags.Parse(args)
@@ -1353,12 +1353,23 @@ func (cmd *sliceNumpy) loadSampleInfo(samplesFilename string) ([]sampleInfo, err
                if idx != len(si) {
                        return nil, fmt.Errorf("%s line %d: index %d out of order", samplesFilename, lineNum, idx)
                }
+               var pcaComponents []float64
+               if len(split) > 4 {
+                       for _, s := range split[4:] {
+                               f, err := strconv.ParseFloat(s, 64)
+                               if err != nil {
+                                       return nil, fmt.Errorf("%s line %d: cannot parse float %q: %s", samplesFilename, lineNum, s, err)
+                               }
+                               pcaComponents = append(pcaComponents, f)
+                       }
+               }
                si = append(si, sampleInfo{
-                       id:           split[1],
-                       isCase:       split[2] == "1",
-                       isControl:    split[2] == "0",
-                       isTraining:   split[3] == "1",
-                       isValidation: split[3] == "0",
+                       id:            split[1],
+                       isCase:        split[2] == "1",
+                       isControl:     split[2] == "0",
+                       isTraining:    split[3] == "1",
+                       isValidation:  split[3] == "0",
+                       pcaComponents: pcaComponents,
                })
        }
        return si, nil
@@ -1590,7 +1601,12 @@ func (cmd *sliceNumpy) tv2homhet(cgs map[string]CompactGenome, maxv tileVariantI
                if col < 4 && !cmd.includeVariant1 {
                        continue
                }
-               p := pvalue(obs[col], cmd.chi2Cases)
+               var p float64
+               if len(cmd.samples[0].pcaComponents) > 0 {
+                       p = pvalueGLM(cmd.samples, obs[col:col+2])
+               } else {
+                       p = pvalue(obs[col], cmd.chi2Cases)
+               }
                if cmd.chi2PValue < 1 && !(p < cmd.chi2PValue) {
                        continue
                }