From b513c8156ae1a20207479cab714a91733187a4fd Mon Sep 17 00:00:00 2001
From: Tom Clegg
Date: Tue, 29 Nov 2022 10:43:29 -0500
Subject: [PATCH] 19566: Logistic regression p-value.

Arvados-DCO-1.1-Signed-off-by: Tom Clegg
---
 glm.go        | 98 +++++++++++++++++++++++++++++++++++++++++++++++++++
 glm_test.go   | 97 ++++++++++++++++++++++++++++++++++++++++++++++++++
 go.mod        |  4 +++
 go.sum        |  7 ++++
 slicenumpy.go | 34 +++++++++++++-----
 5 files changed, 232 insertions(+), 8 deletions(-)
 create mode 100644 glm.go
 create mode 100644 glm_test.go

diff --git a/glm.go b/glm.go
new file mode 100644
index 0000000000..cc06f39e2c
--- /dev/null
+++ b/glm.go
@@ -0,0 +1,98 @@
+// Copyright (C) The Lightning Authors. All rights reserved.
+//
+// SPDX-License-Identifier: AGPL-3.0
+
+package lightning
+
+import (
+	"fmt"
+	"math"
+
+	"github.com/kshedden/statmodel/glm"
+	"github.com/kshedden/statmodel/statmodel"
+)
+
+// glmConfig is the configuration used for every model fitted by
+// pvalueGLM: binomial family, IRLS fit method.
+var glmConfig = &glm.Config{
+	Family:         glm.NewFamily(glm.BinomialFamily),
+	FitMethod:      "IRLS",
+	ConcurrentIRLS: 1000,
+}
+
+// pvalueGLM fits two binomial GLMs to the training samples (entries
+// of sampleInfo with isTraining set): one predicting case status
+// from the PCA components alone, and one that also uses the variant
+// indicator onehotPair[0][row]. It returns chisquared.Survival()
+// applied to -2 times the difference between the two
+// log-likelihoods.
+//
+// Only the first row of onehotPair is consulted; its entries are
+// indexed the same way as sampleInfo.
+//
+// The result is NaN if sampleInfo or onehotPair is empty, if a
+// training sample has no onehotPair[0] entry or has fewer PCA
+// components than sampleInfo[0], or if a model cannot be built.
+func pvalueGLM(sampleInfo []sampleInfo, onehotPair [][]bool) float64 {
+	if len(sampleInfo) == 0 || len(onehotPair) == 0 {
+		return math.NaN()
+	}
+	nPCA := len(sampleInfo[0].pcaComponents)
+	// Validate up front so malformed input yields NaN (like a
+	// model construction error) instead of an index panic below.
+	for row, si := range sampleInfo {
+		if si.isTraining && (row >= len(onehotPair[0]) || len(si.pcaComponents) < nPCA) {
+			return math.NaN()
+		}
+	}
+
+	pcaNames := make([]string, 0, nPCA)
+	// Room for one series per PCA component, plus variant and
+	// outcome.
+	data := make([][]statmodel.Dtype, 0, nPCA+2)
+	for pca := 0; pca < nPCA; pca++ {
+		series := make([]statmodel.Dtype, 0, len(sampleInfo))
+		for _, si := range sampleInfo {
+			if si.isTraining {
+				series = append(series, si.pcaComponents[pca])
+			}
+		}
+		data = append(data, series)
+		pcaNames = append(pcaNames, fmt.Sprintf("pca%d", pca))
+	}
+
+	variant := make([]statmodel.Dtype, 0, len(sampleInfo))
+	outcome := make([]statmodel.Dtype, 0, len(sampleInfo))
+	for row, si := range sampleInfo {
+		if si.isTraining {
+			if onehotPair[0][row] {
+				variant = append(variant, 1)
+			} else {
+				variant = append(variant, 0)
+			}
+			if si.isCase {
+				outcome = append(outcome, 1)
+			} else {
+				outcome = append(outcome, 0)
+			}
+		}
+	}
+	data = append(data, variant, outcome)
+
+	dataset := statmodel.NewDataset(data, append(pcaNames, "variant", "outcome"))
+	// Covariates-only model: outcome ~ PCA components.
+	model, err := glm.NewGLM(dataset, "outcome", pcaNames, glmConfig)
+	if err != nil {
+		return math.NaN()
+	}
+	resultCov := model.Fit()
+	logCov := resultCov.LogLike()
+	// Complete model: outcome ~ variant + PCA components.
+	model, err = glm.NewGLM(dataset, "outcome", append([]string{"variant"}, pcaNames...), glmConfig)
+	if err != nil {
+		return math.NaN()
+	}
+	resultComp := model.Fit()
+	logComp := resultComp.LogLike()
+	return chisquared.Survival(-2 * (logCov - logComp))
+}
diff --git a/glm_test.go b/glm_test.go
new file mode 100644
index 0000000000..875a96c1d5
--- /dev/null
+++ b/glm_test.go
@@ -0,0 +1,97 @@
+// Copyright (C) The Lightning Authors. All rights reserved.
+//
+// SPDX-License-Identifier: AGPL-3.0
+
+package lightning
+
+import (
+	"fmt"
+	"math/rand"
+
+	"gopkg.in/check.v1"
+)
+
+type glmSuite struct{}
+
+var _ = check.Suite(&glmSuite{})
+
+func (s *glmSuite) TestPvalue(c *check.C) {
+	c.Check(pvalueGLM([]sampleInfo{
+		{id: "sample1", isCase: false, isTraining: true, pcaComponents: []float64{-4, 1.2, -3}},
+		{id: "sample2", isCase: false, isTraining: true, pcaComponents: []float64{7, -1.2, 2}},
+		{id: "sample3", isCase: true, isTraining: true, pcaComponents: []float64{7, -1.2, 2}},
+		{id: "sample4", isCase: true, isTraining: true, pcaComponents: []float64{-4, 1.1, -2}},
+	}, [][]bool{
+		{false, false, true, true},
+		{false, false, true, true},
+	}), check.Equals, 0.09589096738494937)
+
+	c.Check(pvalueGLM([]sampleInfo{
+		{id: "sample1", isCase: false, isTraining: true, pcaComponents: []float64{1, 1.21, 2.37}},
+		{id: "sample2", isCase: false, isTraining: true, pcaComponents: []float64{2, 1.22, 2.38}},
+		{id: "sample3", isCase: false, isTraining: true, pcaComponents: []float64{3, 1.23, 2.39}},
+		{id: "sample4", isCase: false, isTraining: true, pcaComponents: []float64{1, 1.24, 2.33}},
+		{id: "sample5", isCase: false, isTraining: true, pcaComponents: []float64{2, 1.25, 2.34}},
+		{id: "sample6", isCase: true, isTraining: true, pcaComponents: []float64{3, 1.26, 2.35}},
+		{id: "sample7", isCase: true, isTraining: true, pcaComponents: []float64{1, 1.23, 2.36}},
+		{id: "sample8", isCase: true, isTraining: true, pcaComponents: []float64{2, 1.22, 2.32}},
+		{id: "sample9", isCase: true, isTraining: true, pcaComponents: []float64{3, 1.21, 2.31}},
+	}, [][]bool{
+		{false, false, false, false, false, true, true, true, true},
+		{false, false, false, false, false, true, true, true, true},
+	}), check.Equals, 0.001028375654911555)
+
+	c.Check(pvalueGLM([]sampleInfo{
+		{id: "sample1", isCase: false, isTraining: true, pcaComponents: []float64{1.001, -1.01, 2.39}},
+		{id: "sample2", isCase: false, isTraining: true, pcaComponents: []float64{1.002, -1.02, 2.38}},
+		{id: "sample3", isCase: false, isTraining: true, pcaComponents: []float64{1.003, -1.03, 2.37}},
+		{id: "sample4", isCase: false, isTraining: true, pcaComponents: []float64{1.004, -1.04, 2.36}},
+		{id: "sample5", isCase: false, isTraining: true, pcaComponents: []float64{1.005, -1.05, 2.35}},
+		{id: "sample6", isCase: false, isTraining: true, pcaComponents: []float64{1.006, -1.06, 2.34}},
+		{id: "sample7", isCase: false, isTraining: true, pcaComponents: []float64{1.007, -1.07, 2.33}},
+		{id: "sample8", isCase: false, isTraining: true, pcaComponents: []float64{1.008, -1.08, 2.32}},
+		{id: "sample9", isCase: false, isTraining: false, pcaComponents: []float64{2.000, 8.01, -2.01}},
+		{id: "sample10", isCase: true, isTraining: true, pcaComponents: []float64{2.001, 8.02, -2.02}},
+		{id: "sample11", isCase: true, isTraining: true, pcaComponents: []float64{2.002, 8.03, -2.03}},
+		{id: "sample12", isCase: true, isTraining: true, pcaComponents: []float64{2.003, 8.04, -2.04}},
+		{id: "sample13", isCase: true, isTraining: true, pcaComponents: []float64{2.004, 8.05, -2.05}},
+		{id: "sample14", isCase: true, isTraining: true, pcaComponents: []float64{2.005, 8.06, -2.06}},
+		{id: "sample15", isCase: true, isTraining: true, pcaComponents: []float64{2.006, 8.07, -2.07}},
+		{id: "sample16", isCase: true, isTraining: true, pcaComponents: []float64{2.007, 8.08, -2.08}},
+		{id: "sample17", isCase: true, isTraining: true, pcaComponents: []float64{2.008, 8.09, -2.09}},
+		{id: "sample18", isCase: true, isTraining: true, pcaComponents: []float64{2.009, 8.10, -2.10}},
+		{id: "sample19", isCase: true, isTraining: true, pcaComponents: []float64{2.010, 8.11, -2.11}},
+	}, [][]bool{
+		{false, false, false, false, false, false, false, false, false, true, true, true, true, true, true, true, true, true, true},
+		{false, false, false, false, false, false, false, false, false, true, true, true, true, true, true, true, true, true, true},
+	}), check.Equals, 0.9999944849940106)
+}
+
+var benchSamples, benchOnehot = func() ([]sampleInfo, [][]bool) {
+	pcaComponents := 10
+	samples := []sampleInfo{}
+	onehot := make([][]bool, 2)
+	r := make([]float64, pcaComponents)
+	for j := 0; j < 10000; j++ {
+		for i := 0; i < len(r); i++ {
+			r[i] = rand.Float64()
+		}
+		samples = append(samples, sampleInfo{
+			id:            fmt.Sprintf("sample%d", j),
+			isCase:        j%2 == 0 && j > 200,
+			isControl:     j%2 == 1 || j <= 200,
+			isTraining:    true,
+			pcaComponents: append([]float64(nil), r...),
+		})
+		onehot[0] = append(onehot[0], j%2 == 0)
+		onehot[1] = append(onehot[1], j%2 == 0)
+	}
+	return samples, onehot
+}()
+
+func (s *glmSuite) BenchmarkPvalue(c *check.C) {
+	for i := 0; i < c.N; i++ {
+		p := pvalueGLM(benchSamples, benchOnehot)
+		c.Check(p, check.Equals, 0.0)
+	}
+}
diff --git a/go.mod b/go.mod
index 758e7d3e8d..8f8400a3b8 100644
--- a/go.mod
+++ b/go.mod
@@ -7,6 +7,7 @@ require (
 	github.com/james-bowman/nlp v0.0.0-20200417075118-1e2772e0e1e5
 	github.com/klauspost/pgzip v1.2.5
 	github.com/kshedden/gonpy v0.0.0-20190510000443-66c21fac4672
+	github.com/kshedden/statmodel v0.0.0-20210519035403-ee97d3e48df1
 	github.com/mattn/go-isatty v0.0.12
 	github.com/sergi/go-diff v1.1.0
 	github.com/sirupsen/logrus v1.8.1
@@ -24,12 +25,14 @@ require (
 	github.com/ghodss/yaml v1.0.0 // indirect
 	github.com/gogo/protobuf v1.3.2 // indirect
 	github.com/golang/protobuf v1.5.0 // indirect
+	github.com/golang/snappy v0.0.4 // indirect
 	github.com/gonum/floats v0.0.0-20181209220543-c233463c7e82 // indirect
 	github.com/gonum/internal v0.0.0-20181124074243-f884aa714029 // indirect
 	github.com/james-bowman/sparse v0.0.0-20200514124614-ae250424e52d // indirect
 	github.com/klauspost/compress v1.15.11 // indirect
 	github.com/kr/pretty v0.2.1 // indirect
 	github.com/kr/text v0.1.0 // indirect
+	github.com/kshedden/dstream v0.0.0-20190512025041-c4c410631beb // indirect
 	github.com/matttproud/golang_protobuf_extensions v1.0.2-0.20181231171920-c182affec369 // indirect
 	github.com/prometheus/client_golang v1.7.1 // indirect
 	github.com/prometheus/client_model v0.2.0 // indirect
@@ -38,6 +41,7 @@ require (
 	github.com/spaolacci/murmur3 v1.1.0 // indirect
 	github.com/stretchr/testify v1.6.1 // indirect
 	golang.org/x/sys v0.0.0-20210809222454-d867a43fc93e // indirect
+	golang.org/x/tools v0.1.7 // indirect
 	google.golang.org/protobuf v1.27.1 // indirect
 	gopkg.in/yaml.v2 v2.4.0 // indirect
 )
diff --git a/go.sum b/go.sum
index 20691b15d4..158801e198 100644
--- a/go.sum
+++ b/go.sum
@@ -347,6 +347,8 @@ github.com/golang/protobuf v1.4.2/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw
 github.com/golang/protobuf v1.4.3/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI=
 github.com/golang/protobuf v1.5.0 h1:LUVKkCeviFUMKqHa4tXIIij/lbhnMbP7Fn5wKdKkRh4=
 github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk=
+github.com/golang/snappy v0.0.4 h1:yAGX7huGHXlcLOEtBnF4w7FQwA26wojNCwOYAEhLjQM=
+github.com/golang/snappy v0.0.4/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
 github.com/gonum/floats v0.0.0-20181209220543-c233463c7e82 h1:EvokxLQsaaQjcWVWSV38221VAK7qc2zhaO17bKys/18=
 github.com/gonum/floats v0.0.0-20181209220543-c233463c7e82/go.mod h1:PxC8OnwL11+aosOB5+iEPoV3picfs8tUpkVd0pDo+Kg=
 github.com/gonum/internal v0.0.0-20181124074243-f884aa714029 h1:8jtTdc+Nfj9AR+0soOeia9UZSvYBvETVHZrugUowJ7M=
@@ -453,8 +455,12 @@ github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
 github.com/kr/pty v1.1.5/go.mod h1:9r2w37qlBe7rQ6e1fg1S/9xpWHSnaqNdHD3WcMdbPDA=
 github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE=
 github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
+github.com/kshedden/dstream v0.0.0-20190512025041-c4c410631beb h1:Z5BVHFk/DLOIUAd2NycF0mLtKfhl7ynm4Uy5+AFhT48=
+github.com/kshedden/dstream v0.0.0-20190512025041-c4c410631beb/go.mod h1:+U+6yzfITr4/teU2YhxWhdyw6YzednT/16/UBMjlDrU=
 github.com/kshedden/gonpy v0.0.0-20190510000443-66c21fac4672 h1:LQLnybCU54zB8Gj8c1DPeZEheIAn3eZ8Cc9fYqM4ac8=
 github.com/kshedden/gonpy v0.0.0-20190510000443-66c21fac4672/go.mod h1:+uEXxXG0RlfBPqG1tq5QN/F2jRlcuY0dExSONLpEwcA=
+github.com/kshedden/statmodel v0.0.0-20210519035403-ee97d3e48df1 h1:UyIQ1VTQq/0CS/wLYjf3DV6uRKTd1xcsng3BccM4XCY=
+github.com/kshedden/statmodel v0.0.0-20210519035403-ee97d3e48df1/go.mod h1:uvVFnikBpVz7S1pdsyUI+BBRlz64vmU6Q+kviiB+fpU=
 github.com/lib/pq v1.0.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo=
 github.com/lib/pq v1.10.2 h1:AqzbZs4ZoCBp+GtejcpCpcxM3zlSMx29dXbUSeVtJb8=
 github.com/lib/pq v1.10.2/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
@@ -901,6 +907,7 @@ golang.org/x/tools v0.0.0-20200224181240-023911ca70b2/go.mod h1:TB2adYChydJhpapK
 golang.org/x/tools v0.0.0-20200304193943-95d2e580d8eb/go.mod h1:o4KQGtdN14AW+yjsvvwRTJJuXz8XRtIHtEnmAXLyFUw=
 golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE=
 golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA=
+golang.org/x/tools v0.1.7 h1:6j8CgantCy3yc8JGBqkDLMKWqZ0RDU2g1HVgacojGWQ=
 golang.org/x/tools v0.1.7/go.mod h1:LGqMHiF4EqQNHR1JncWGqT5BVaXmza+X+BDGol+dOxo=
 golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
 golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
diff --git a/slicenumpy.go b/slicenumpy.go
index f3f9cc7526..c555635585 100644
--- a/slicenumpy.go
+++ b/slicenumpy.go
@@ -83,12 +83,12 @@ func (cmd *sliceNumpy) run(prog string, args []string, stdin io.Reader, stdout,
 	onehotChunked := flags.Bool("chunked-onehot", false, "generate one-hot tile-based matrix per input chunk")
 	samplesFilename := flags.String("samples", "", "`samples.csv` file with training/validation and case/control groups (see 'lightning choose-samples')")
 	caseControlOnly := flags.Bool("case-control-only", false, "drop samples that are not in case/control groups")
-	onlyPCA := flags.Bool("pca", false, "generate pca matrix")
+	onlyPCA := flags.Bool("pca", false, "run principal component analysis, write components to pca.npy and samples.csv")
 	pcaComponents := flags.Int("pca-components", 4, "number of PCA components")
 	maxPCATiles := flags.Int("max-pca-tiles", 0, "maximum tiles to use as PCA input (filter, then drop every 2nd colum pair until below max)")
 	debugTag := flags.Int("debug-tag", -1, "log debugging details about specified tag")
 	flags.IntVar(&cmd.threads, "threads", 16, "number of memory-hungry assembly threads, and number of VCPUs to request for arvados container")
-	flags.Float64Var(&cmd.chi2PValue, "chi2-p-value", 1, "do Χ² test and omit columns with p-value above this threshold")
+	flags.Float64Var(&cmd.chi2PValue, "chi2-p-value", 1, "do Χ² test (or logistic regression if -samples file has PCA components) and omit columns with p-value above this threshold")
 	flags.BoolVar(&cmd.includeVariant1, "include-variant-1", false, "include most common variant when building one-hot matrix")
 	cmd.filter.Flags(flags)
 	err := flags.Parse(args)
@@ -1353,12 +1353,23 @@ func (cmd *sliceNumpy) loadSampleInfo(samplesFilename string) ([]sampleInfo, err
 		if idx != len(si) {
 			return nil, fmt.Errorf("%s line %d: index %d out of order", samplesFilename, lineNum, idx)
 		}
+		var pcaComponents []float64
+		if len(split) > 4 {
+			for _, s := range split[4:] {
+				f, err := strconv.ParseFloat(s, 64)
+				if err != nil {
+					return nil, fmt.Errorf("%s line %d: cannot parse float %q: %s", samplesFilename, lineNum, s, err)
+				}
+				pcaComponents = append(pcaComponents, f)
+			}
+		}
 		si = append(si, sampleInfo{
-			id:           split[1],
-			isCase:       split[2] == "1",
-			isControl:    split[2] == "0",
-			isTraining:   split[3] == "1",
-			isValidation: split[3] == "0",
+			id:            split[1],
+			isCase:        split[2] == "1",
+			isControl:     split[2] == "0",
+			isTraining:    split[3] == "1",
+			isValidation:  split[3] == "0",
+			pcaComponents: pcaComponents,
 		})
 	}
 	return si, nil
@@ -1590,7 +1601,14 @@ func (cmd *sliceNumpy) tv2homhet(cgs map[string]CompactGenome, maxv tileVariantI
 		if col < 4 && !cmd.includeVariant1 {
 			continue
 		}
-		p := pvalue(obs[col], cmd.chi2Cases)
+		var p float64
+		if len(cmd.samples) > 0 && len(cmd.samples[0].pcaComponents) > 0 {
+			// pvalueGLM only reads row 0; a wider slice could
+			// run past the end of obs on the last column.
+			p = pvalueGLM(cmd.samples, obs[col:col+1])
+		} else {
+			p = pvalue(obs[col], cmd.chi2Cases)
+		}
 		if cmd.chi2PValue < 1 && !(p < cmd.chi2PValue) {
 			continue
 		}
-- 
2.30.2