19566: Test p-value vs. Python.
authorTom Clegg <tom@curii.com>
Mon, 12 Dec 2022 16:37:21 +0000 (11:37 -0500)
committerTom Clegg <tom@curii.com>
Mon, 12 Dec 2022 16:40:00 +0000 (11:40 -0500)
Arvados-DCO-1.1-Signed-off-by: Tom Clegg <tom@curii.com>

glm_test.go
slicenumpy.go

index b8dbdb489949f10c631775c8b882d8c3a5af7a77..9761b7b4f0f10388adcfe7cabff32276a474f6d7 100644 (file)
@@ -42,8 +42,27 @@ func (s *glmSuite) TestFit(c *check.C) {
        c.Check(math.Abs(result.LogLike()-expect) < 0.00000000001, check.Equals, true)
 }
 
+var pyImports = `
+import scipy
+import statsmodels.formula.api as smf
+import statsmodels.api as sm
+import numpy as np
+import pandas as pd
+`
+
+func checkVirtualenv(c *check.C) {
+       cmd := exec.Command("python3", "-")
+       cmd.Stdin = strings.NewReader(pyImports)
+       out, err := cmd.CombinedOutput()
+       if err != nil {
+               c.Logf("%s", out)
+               c.Skip("test requires python virtualenv with libraries installed")
+       }
+}
+
 func (s *glmSuite) TestFitDivergeFromPython(c *check.C) {
-       c.Skip("test is slow, and requires a python virtualenv with libraries installed")
+       checkVirtualenv(c)
+       c.Skip("slow test")
        data0 := [][]statmodel.Dtype{
                {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1},
                {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 2, 1, 4, 3, 6, 5, 7, 8, 9},
@@ -96,11 +115,7 @@ func (s *glmSuite) TestFitDivergeFromPython(c *check.C) {
                        pydata += "]"
                }
                pydata += "]"
-               py := `
-import statsmodels.formula.api as smf
-import statsmodels.api as sm
-import numpy as np
-import pandas as pd
+               py := pyImports + `
 data = np.array(` + pydata + `)
 df = pd.DataFrame(data.T, columns=['y','x1'])
 fit = smf.glm('y ~ x1', family=sm.families.Binomial(), data=df).fit()
@@ -119,6 +134,101 @@ print(fit.llf)
        }
 }
 
+func (s *glmSuite) TestPvalueRealDataVsPython(c *check.C) {
+       checkVirtualenv(c)
+       samples, err := loadSampleInfo("glm_test_samples.csv")
+       if err != nil {
+               c.Skip("test requires glm_test_samples.csv (not included)")
+       }
+       c.Logf("Nsamples = %d", len(samples))
+       nPCA := 5
+       // data series: y, rand, pca1, ..., pcaN
+       data := [][]statmodel.Dtype{nil, nil}
+       for i := 0; i < nPCA; i++ {
+               data = append(data, []statmodel.Dtype(nil))
+       }
+       onehot := []bool{}
+       for _, si := range samples {
+               if !si.isTraining {
+                       continue
+               }
+               if si.isCase {
+                       data[0] = append(data[0], 1)
+               } else {
+                       data[0] = append(data[0], 0)
+               }
+               r := rand.Int()&1 == 1
+               if rand.Int()&0x1f == 0 {
+                       // 1/32 samples have onehot==outcome, the rest
+                       // are random
+                       r = si.isCase
+               }
+               onehot = append(onehot, r)
+               if r {
+                       data[1] = append(data[1], 1)
+               } else {
+                       data[1] = append(data[1], 0)
+               }
+               for i := 0; i < nPCA; i++ {
+                       data[i+2] = append(data[i+2], si.pcaComponents[i])
+               }
+       }
+
+       pGo := pvalueGLM(samples, onehot, nPCA)
+       c.Logf("pGo = %g", pGo)
+
+       var pydata bytes.Buffer
+       pydata.WriteString("[")
+       for row, values := range data {
+               if row > 0 {
+                       pydata.WriteString(",")
+               }
+               pydata.WriteString("\n    [")
+               for col, v := range values {
+                       if col > 0 {
+                               pydata.WriteString(", ")
+                       }
+                       fmt.Fprintf(&pydata, "%v", v)
+               }
+               pydata.WriteString("]")
+       }
+       pydata.WriteString("]")
+       py := pyImports + `
+data = np.array(` + pydata.String() + `)
+columns = ['y','onehot']
+formula = ''
+for i in range(` + fmt.Sprintf("%d", nPCA) + `):
+    columns.append('x'+str(i+1))
+    if len(formula) > 0:
+        formula += ' + '
+    formula += 'x'+str(i+1)
+df = pd.DataFrame(data.T, columns=columns)
+
+mod1 = smf.glm('y ~ '+formula, family=sm.families.Binomial(), data=df).fit()
+# print(mod1.summary())
+print('mod1.llf = ', mod1.llf)
+
+mod2 = smf.glm('y ~ onehot + '+formula, family=sm.families.Binomial(), data=df).fit()
+# print(mod2.summary())
+print('mod2.llf = ', mod2.llf)
+
+df = 1
+p = 1 - scipy.stats.chi2.cdf(-2 * (mod1.llf - mod2.llf), df)
+print(p)
+`
+       c.Logf("python...")
+       cmd := exec.Command("python3", "-")
+       cmd.Stdin = strings.NewReader(py)
+       out, err := cmd.CombinedOutput()
+       c.Logf("%s", out)
+       c.Assert(err, check.IsNil)
+       outlines := bytes.Split(out, []byte{'\n'})
+       pPy, err := strconv.ParseFloat(string(outlines[len(outlines)-2]), 64)
+       c.Assert(err, check.IsNil)
+       c.Logf("pPy = %g", pPy)
+       c.Assert(math.Abs(pGo-pPy) < 0.000001, check.Equals, true)
+}
+
 func (s *glmSuite) TestPvalue(c *check.C) {
        // csv: casecontrol,onehot,pca1,pca2,...
        csv2test := func(csv string) (samples []sampleInfo, onehot []bool, npca int) {
index a718cbf8148035f9105d216c9f22a1c4e33931b3..fb3ae95187c5289e4a525caf9fc58e410838f666 100644 (file)
@@ -183,7 +183,7 @@ func (cmd *sliceNumpy) run(prog string, args []string, stdin io.Reader, stdout,
        }
 
        if *samplesFilename != "" {
-               cmd.samples, err = cmd.loadSampleInfo(*samplesFilename)
+               cmd.samples, err = loadSampleInfo(*samplesFilename)
                if err != nil {
                        return err
                }
@@ -1320,7 +1320,7 @@ type sampleInfo struct {
 
 // Read samples.csv file with case/control and training/validation
 // flags.
-func (cmd *sliceNumpy) loadSampleInfo(samplesFilename string) ([]sampleInfo, error) {
+func loadSampleInfo(samplesFilename string) ([]sampleInfo, error) {
        var si []sampleInfo
        f, err := open(samplesFilename)
        if err != nil {