Merge branch '19868-pca-in-ml' into main
[lightning.git] / glm_test.go
1 // Copyright (C) The Lightning Authors. All rights reserved.
2 //
3 // SPDX-License-Identifier: AGPL-3.0
4
5 package lightning
6
7 import (
8         "bytes"
9         "fmt"
10         "math"
11         "math/rand"
12         "os/exec"
13         "strconv"
14         "strings"
15
16         "github.com/kshedden/statmodel/glm"
17         "github.com/kshedden/statmodel/statmodel"
18         "gopkg.in/check.v1"
19 )
20
21 type glmSuite struct{}
22
23 var _ = check.Suite(&glmSuite{})
24
25 func (s *glmSuite) TestFit(c *check.C) {
26         data := [][]statmodel.Dtype{
27                 {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 1, 1, 1, 0, 0, 1, 0, 0, 1, 1, 1, 1, 0, 1, 0, 0, 1, 1, 1, 1, 0, 1, 0, 0, 1, 0, 1, 0, 0, 1, 1, 1, 0, 0, 1, 0, 0, 0, 1, 1, 1, 0, 1, 1, 0, 0, 1, 1, 1, 0, 0, 1, 1, 1, 1, 0, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 1, 0, 0, 1, 1, 1, 0, 0, 1, 0, 1, 0, 0, 1, 0, 0, 1, 1, 0, 1, 1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0, 0, 1, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 1, 1, 0, 1, 1, 0, 0, 0, 1, 0, 1, 0, 1, 1, 1, 0, 1, 1, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 1, 0, 1, 1, 0, 1, 0, 0, 0, 0, 1, 1, 0, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 0, 1, 1, 0, 1, 1, 0, 0, 1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 1, 0, 1, 1, 0, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1, 0, 1, 1, 1, 1, 0, 0, 0, 1, 1, 1, 1, 0, 1, 0, 1, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 0, 1, 0, 1, 1, 1, 1, 1, 0, 1, 1, 0, 1, 0, 1, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 1, 0, 1, 1, 1, 1, 1, 0, 0, 1, 0, 1, 0, 1, 1, 1, 1, 1, 0, 1, 1, 0, 1, 0, 1, 0, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 1},
28                 {17.99, 20.57, 19.69, 11.42, 20.29, 12.45, 18.25, 13.71, 13.0, 12.46, 16.02, 15.78, 19.17, 15.85, 13.73, 14.54, 14.68, 16.13, 19.81, 13.54, 13.08, 9.504, 15.34, 21.16, 16.65, 17.14, 14.58, 18.61, 15.3, 17.57, 18.63, 11.84, 17.02, 19.27, 16.13, 16.74, 14.25, 13.03, 14.99, 13.48, 13.44, 10.95, 19.07, 13.28, 13.17, 18.65, 8.196, 13.17, 12.05, 13.49, 11.76, 13.64, 11.94, 18.22, 15.1, 11.52, 19.21, 14.71, 13.05, 8.618, 10.17, 8.598, 14.25, 9.173, 12.68, 14.78, 9.465, 11.31, 9.029, 12.78, 18.94, 8.888, 17.2, 13.8, 12.31, 16.07, 13.53, 18.05, 20.18, 12.86, 11.45, 13.34, 25.22, 19.1, 12.0, 18.46, 14.48, 19.02, 12.36, 14.64, 14.62, 15.37, 13.27, 13.45, 15.06, 20.26, 12.18, 9.787, 11.6, 14.42, 13.61, 6.981, 12.18, 9.876, 10.49, 13.11, 11.64, 12.36, 22.27, 11.34, 9.777, 12.63, 14.26, 10.51, 8.726, 11.93, 8.95, 14.87, 15.78, 17.95, 11.41, 18.66, 24.25, 14.5, 13.37, 13.85, 13.61, 19.0, 15.1, 19.79, 12.19, 15.46, 16.16, 15.71, 18.45, 12.77, 11.71, 11.43, 14.95, 11.28, 9.738, 16.11, 11.43, 12.9, 10.75, 11.9, 11.8, 14.95, 14.44, 13.74, 13.0, 8.219, 9.731, 11.15, 13.15, 12.25, 17.68, 16.84, 12.06, 10.9, 11.75, 19.19, 19.59, 12.34, 23.27, 14.97, 10.8, 16.78, 17.47, 14.97, 12.32, 13.43, 15.46, 11.08, 10.66, 8.671, 9.904, 16.46, 13.01, 12.81, 27.22, 21.09, 15.7, 11.41, 15.28, 10.08, 18.31, 11.71, 11.81, 12.3, 14.22, 12.77, 9.72, 12.34, 14.86, 12.91, 13.77, 18.08, 19.18, 14.45, 12.23, 17.54, 23.29, 13.81, 12.47, 15.12, 9.876, 17.01, 13.11, 15.27, 20.58, 11.84, 28.11, 17.42, 14.19, 13.86, 11.89, 10.2, 19.8, 19.53, 13.65, 13.56, 10.18, 15.75, 13.27, 14.34, 10.44, 15.0, 12.62, 12.83, 17.05, 11.32, 11.22, 20.51, 9.567, 14.03, 23.21, 20.48, 14.22, 17.46, 13.64, 12.42, 11.3, 13.75, 19.4, 10.48, 13.2, 12.89, 10.65, 11.52, 20.94, 11.5, 19.73, 17.3, 19.45, 13.96, 19.55, 15.32, 15.66, 15.53, 20.31, 17.35, 17.29, 15.61, 17.19, 20.73, 10.6, 13.59, 12.87, 10.71, 14.29, 11.29, 21.75, 9.742, 17.93, 11.89, 11.33, 18.81, 13.59, 13.85, 19.16, 11.74, 19.4, 16.24, 12.89, 12.58, 11.94, 12.89, 11.26, 11.37, 14.41, 14.96, 12.95, 11.85, 12.72, 13.77, 10.91, 11.76, 14.26, 10.51, 19.53, 12.46, 20.09, 10.49, 11.46, 11.6, 13.2, 9.0, 13.5, 13.05, 11.7, 14.61, 12.76, 11.54, 8.597, 12.49, 12.18, 18.22, 9.042, 12.43, 10.25, 20.16, 12.86, 20.34, 12.2, 12.67, 14.11, 12.03, 16.27, 16.26, 16.03, 12.98, 11.22, 11.25, 12.3, 17.06, 12.99, 18.77, 10.05, 23.51, 14.42, 9.606, 11.06, 19.68, 11.71, 10.26, 12.06, 14.76, 11.47, 11.95, 11.66, 15.75, 25.73, 15.08, 11.14, 12.56, 13.05, 13.87, 8.878, 9.436, 12.54, 13.3, 12.76, 16.5, 13.4, 20.44, 20.2, 12.21, 21.71, 22.01, 16.35, 15.19, 21.37, 20.64, 13.69, 16.17, 10.57, 13.46, 13.66, 11.08, 11.27, 11.04, 12.05, 12.39, 13.28, 14.6, 12.21, 13.88, 11.27, 19.55, 10.26, 8.734, 15.49, 21.61, 12.1, 14.06, 13.51, 12.8, 11.06, 11.8, 17.91, 11.93, 12.96, 12.94, 12.34, 10.94, 16.14, 12.85, 17.99, 12.27, 11.36, 11.04, 9.397, 14.99, 15.13, 11.89, 9.405, 15.5, 12.7, 11.16, 11.57, 14.69, 11.61, 13.66, 9.742, 10.03, 10.48, 10.8, 11.13, 12.72, 14.9, 12.4, 20.18, 18.82, 14.86, 13.98, 12.87, 14.04, 13.85, 14.02, 10.97, 17.27, 13.78, 10.57, 18.03, 11.99, 17.75, 14.8, 14.53, 21.1, 11.87, 19.59, 12.0, 14.53, 12.62, 13.38, 11.63, 13.21, 13.0, 9.755, 17.08, 27.42, 14.4, 11.6, 13.17, 13.24, 13.14, 9.668, 17.6, 11.62, 9.667, 12.04, 14.92, 12.27, 10.88, 12.83, 14.2, 13.9, 11.49, 16.25, 12.16, 13.9, 13.47, 13.7, 15.73, 12.45, 14.64, 19.44, 11.68, 16.69, 12.25, 17.85, 18.01, 12.46, 13.16, 14.87, 12.65, 12.47, 18.49, 20.59, 15.04, 13.82, 12.54, 23.09, 9.268, 9.676, 12.22, 11.06, 16.3, 15.46, 11.74, 14.81, 13.4, 14.58, 15.05, 11.34, 18.31, 19.89, 12.88, 12.75, 9.295, 24.63, 11.26, 13.71, 9.847, 8.571, 13.46, 12.34, 13.94, 12.07, 11.75, 11.67, 13.68, 20.47, 10.96, 20.55, 14.27, 11.69, 7.729, 7.691, 11.54, 14.47, 14.74, 13.21, 13.87, 13.62, 10.32, 10.26, 9.683, 10.82, 10.86, 11.13, 12.77, 9.333, 12.88, 10.29, 10.16, 9.423, 14.59, 11.51, 14.05, 11.2, 15.22, 20.92, 21.56, 20.13, 16.6, 20.6, 7.76},
29                 {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1},
30         }
31         dataset := statmodel.NewDataset(data, []string{"y", "x1", "const"})
32         model, err := glm.NewGLM(dataset, "y", []string{"const", "x1"}, glmConfig)
33         c.Assert(err, check.IsNil)
34         result := model.Fit()
35         c.Logf("%s", result.Summary())
36         c.Logf("VCov\t%v", result.VCov())
37         c.Logf("Params\t%v", result.Params())
38         c.Logf("StdErr\t%v", result.StdErr())
39         c.Logf("ZScores\t%v", result.ZScores())
40         c.Logf("LogLike\t%v", result.LogLike())
41         expect := -165.00542199378245 // from python
42         c.Check(math.Abs(result.LogLike()-expect) < 0.00000000001, check.Equals, true)
43 }
44
45 var pyImports = `
46 import scipy
47 import statsmodels.formula.api as smf
48 import statsmodels.api as sm
49 import numpy as np
50 import pandas as pd
51 `
52
53 func checkVirtualenv(c *check.C) {
54         cmd := exec.Command("python3", "-")
55         cmd.Stdin = strings.NewReader(pyImports)
56         out, err := cmd.CombinedOutput()
57         if err != nil {
58                 c.Logf("%s", out)
59                 c.Skip("test requires python virtualenv with libraries installed")
60         }
61 }
62
63 func (s *glmSuite) TestFitDivergeFromPython(c *check.C) {
64         checkVirtualenv(c)
65         c.Skip("slow test")
66         data0 := [][]statmodel.Dtype{
67                 {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1},
68                 {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 2, 1, 4, 3, 6, 5, 7, 8, 9},
69         }
70         data1 := [][]statmodel.Dtype{
71                 {1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 1, 1, 1, 0, 0, 1, 0, 0, 1, 1, 1, 1, 0, 1, 0, 0, 1, 1, 1, 1, 0, 1, 0, 0, 1, 0, 1, 0, 0, 1, 1, 1, 0, 0, 1, 0, 0, 0, 1, 1, 1, 0, 1, 1, 0, 0, 1, 1, 1, 0, 0, 1, 1, 1, 1, 0, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 1, 0, 0, 1, 1, 1, 0, 0, 1, 0, 1, 0, 0, 1, 0, 0, 1, 1, 0, 1, 1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0, 0, 1, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 1, 1, 0, 1, 1, 0, 0, 0, 1, 0, 1, 0, 1, 1, 1, 0, 1, 1, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 1, 0, 1, 1, 0, 1, 0, 0, 0, 0, 1, 1, 0, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 0, 1, 1, 0, 1, 1, 0, 0, 1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 1, 0, 1, 1, 0, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1, 0, 1, 1, 1, 1, 0, 0, 0, 1, 1, 1, 1, 0, 1, 0, 1, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 0, 1, 0, 1, 1, 1, 1, 1, 0, 1, 1, 0, 1, 0, 1, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 1, 0, 1, 1, 1, 1, 1, 0, 0, 1, 0, 1, 0, 1, 1, 1, 1, 1, 0, 1, 1, 0, 1, 0, 1, 0, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 1},
72                 {17.99, 20.57, 19.69, 11.42, 20.29, 12.45, 18.25, 13.71, 13.0, 12.46, 16.02, 15.78, 19.17, 15.85, 13.73, 14.54, 14.68, 16.13, 19.81, 13.54, 13.08, 9.504, 15.34, 21.16, 16.65, 17.14, 14.58, 18.61, 15.3, 17.57, 18.63, 11.84, 17.02, 19.27, 16.13, 16.74, 14.25, 13.03, 14.99, 13.48, 13.44, 10.95, 19.07, 13.28, 13.17, 18.65, 8.196, 13.17, 12.05, 13.49, 11.76, 13.64, 11.94, 18.22, 15.1, 11.52, 19.21, 14.71, 13.05, 8.618, 10.17, 8.598, 14.25, 9.173, 12.68, 14.78, 9.465, 11.31, 9.029, 12.78, 18.94, 8.888, 17.2, 13.8, 12.31, 16.07, 13.53, 18.05, 20.18, 12.86, 11.45, 13.34, 25.22, 19.1, 12.0, 18.46, 14.48, 19.02, 12.36, 14.64, 14.62, 15.37, 13.27, 13.45, 15.06, 20.26, 12.18, 9.787, 11.6, 14.42, 13.61, 6.981, 12.18, 9.876, 10.49, 13.11, 11.64, 12.36, 22.27, 11.34, 9.777, 12.63, 14.26, 10.51, 8.726, 11.93, 8.95, 14.87, 15.78, 17.95, 11.41, 18.66, 24.25, 14.5, 13.37, 13.85, 13.61, 19.0, 15.1, 19.79, 12.19, 15.46, 16.16, 15.71, 18.45, 12.77, 11.71, 11.43, 14.95, 11.28, 9.738, 16.11, 11.43, 12.9, 10.75, 11.9, 11.8, 14.95, 14.44, 13.74, 13.0, 8.219, 9.731, 11.15, 13.15, 12.25, 17.68, 16.84, 12.06, 10.9, 11.75, 19.19, 19.59, 12.34, 23.27, 14.97, 10.8, 16.78, 17.47, 14.97, 12.32, 13.43, 15.46, 11.08, 10.66, 8.671, 9.904, 16.46, 13.01, 12.81, 27.22, 21.09, 15.7, 11.41, 15.28, 10.08, 18.31, 11.71, 11.81, 12.3, 14.22, 12.77, 9.72, 12.34, 14.86, 12.91, 13.77, 18.08, 19.18, 14.45, 12.23, 17.54, 23.29, 13.81, 12.47, 15.12, 9.876, 17.01, 13.11, 15.27, 20.58, 11.84, 28.11, 17.42, 14.19, 13.86, 11.89, 10.2, 19.8, 19.53, 13.65, 13.56, 10.18, 15.75, 13.27, 14.34, 10.44, 15.0, 12.62, 12.83, 17.05, 11.32, 11.22, 20.51, 9.567, 14.03, 23.21, 20.48, 14.22, 17.46, 13.64, 12.42, 11.3, 13.75, 19.4, 10.48, 13.2, 12.89, 10.65, 11.52, 20.94, 11.5, 19.73, 17.3, 19.45, 13.96, 19.55, 15.32, 15.66, 15.53, 20.31, 17.35, 17.29, 15.61, 17.19, 20.73, 10.6, 13.59, 12.87, 10.71, 14.29, 11.29, 21.75, 9.742, 17.93, 11.89, 11.33, 18.81, 13.59, 13.85, 19.16, 11.74, 19.4, 16.24, 12.89, 12.58, 11.94, 12.89, 11.26, 11.37, 14.41, 14.96, 12.95, 11.85, 12.72, 13.77, 10.91, 11.76, 14.26, 10.51, 19.53, 12.46, 20.09, 10.49, 11.46, 11.6, 13.2, 9.0, 13.5, 13.05, 11.7, 14.61, 12.76, 11.54, 8.597, 12.49, 12.18, 18.22, 9.042, 12.43, 10.25, 20.16, 12.86, 20.34, 12.2, 12.67, 14.11, 12.03, 16.27, 16.26, 16.03, 12.98, 11.22, 11.25, 12.3, 17.06, 12.99, 18.77, 10.05, 23.51, 14.42, 9.606, 11.06, 19.68, 11.71, 10.26, 12.06, 14.76, 11.47, 11.95, 11.66, 15.75, 25.73, 15.08, 11.14, 12.56, 13.05, 13.87, 8.878, 9.436, 12.54, 13.3, 12.76, 16.5, 13.4, 20.44, 20.2, 12.21, 21.71, 22.01, 16.35, 15.19, 21.37, 20.64, 13.69, 16.17, 10.57, 13.46, 13.66, 11.08, 11.27, 11.04, 12.05, 12.39, 13.28, 14.6, 12.21, 13.88, 11.27, 19.55, 10.26, 8.734, 15.49, 21.61, 12.1, 14.06, 13.51, 12.8, 11.06, 11.8, 17.91, 11.93, 12.96, 12.94, 12.34, 10.94, 16.14, 12.85, 17.99, 12.27, 11.36, 11.04, 9.397, 14.99, 15.13, 11.89, 9.405, 15.5, 12.7, 11.16, 11.57, 14.69, 11.61, 13.66, 9.742, 10.03, 10.48, 10.8, 11.13, 12.72, 14.9, 12.4, 20.18, 18.82, 14.86, 13.98, 12.87, 14.04, 13.85, 14.02, 10.97, 17.27, 13.78, 10.57, 18.03, 11.99, 17.75, 14.8, 14.53, 21.1, 11.87, 19.59, 12.0, 14.53, 12.62, 13.38, 11.63, 13.21, 13.0, 9.755, 17.08, 27.42, 14.4, 11.6, 13.17, 13.24, 13.14, 9.668, 17.6, 11.62, 9.667, 12.04, 14.92, 12.27, 10.88, 12.83, 14.2, 13.9, 11.49, 16.25, 12.16, 13.9, 13.47, 13.7, 15.73, 12.45, 14.64, 19.44, 11.68, 16.69, 12.25, 17.85, 18.01, 12.46, 13.16, 14.87, 12.65, 12.47, 18.49, 20.59, 15.04, 13.82, 12.54, 23.09, 9.268, 9.676, 12.22, 11.06, 16.3, 15.46, 11.74, 14.81, 13.4, 14.58, 15.05, 11.34, 18.31, 19.89, 12.88, 12.75, 9.295, 24.63, 11.26, 13.71, 9.847, 8.571, 13.46, 12.34, 13.94, 12.07, 11.75, 11.67, 13.68, 20.47, 10.96, 20.55, 14.27, 11.69, 7.729, 7.691, 11.54, 14.47, 14.74, 13.21, 13.87, 13.62, 10.32, 10.26, 9.683, 10.82, 10.86, 11.13, 12.77, 9.333, 12.88, 10.29, 10.16, 9.423, 14.59, 11.51, 14.05, 11.2, 15.22, 20.92, 21.56, 20.13, 16.6, 20.6, 7.76},
73         }
74         for i := 0; i <= len(data1[0]); i++ {
75                 c.Logf("================== %d", i)
76                 data := [][]statmodel.Dtype{
77                         append([]statmodel.Dtype(nil), data0[0]...),
78                         append([]statmodel.Dtype(nil), data0[1]...),
79                 }
80                 for j := 0; j < i; j++ {
81                         if len(data[0]) <= j {
82                                 data[0] = append(data[0], data1[0][j])
83                                 data[1] = append(data[1], data1[1][j])
84                         } else {
85                                 data[0][j] = data1[0][j]
86                                 data[1][j] = data1[1][j]
87                         }
88                 }
89                 constants := make([]statmodel.Dtype, len(data[0]))
90                 for i := range constants {
91                         constants[i] = 1
92                 }
93                 data = append(data, constants)
94                 c.Logf("%v", data)
95
96                 dataset := statmodel.NewDataset(data, []string{"y", "x1", "C"})
97                 model, err := glm.NewGLM(dataset, "y", []string{"x1", "C"}, glmConfig)
98                 c.Assert(err, check.IsNil)
99                 result := model.Fit()
100                 c.Logf("%s", result.Summary())
101                 c.Logf("%v", result.LogLike())
102
103                 pydata := "["
104                 for row, values := range data[:2] {
105                         if row > 0 {
106                                 pydata += ","
107                         }
108                         pydata += "\n    ["
109                         for col, v := range values {
110                                 if col > 0 {
111                                         pydata += ", "
112                                 }
113                                 pydata += fmt.Sprintf("%v", v)
114                         }
115                         pydata += "]"
116                 }
117                 pydata += "]"
118                 py := pyImports + `
119 data = np.array(` + pydata + `)
120 df = pd.DataFrame(data.T, columns=['y','x1'])
121 fit = smf.glm('y ~ x1', family=sm.families.Binomial(), data=df).fit()
122 print(fit.summary())
123 print(fit.llf)
124 `
125                 cmd := exec.Command("python3", "-")
126                 cmd.Stdin = strings.NewReader(py)
127                 out, err := cmd.CombinedOutput()
128                 c.Logf("%s", out)
129                 c.Assert(err, check.IsNil)
130                 outlines := bytes.Split(out, []byte{'\n'})
131                 llf, err := strconv.ParseFloat(string(outlines[len(outlines)-2]), 64)
132                 c.Assert(err, check.IsNil)
133                 c.Assert(math.Abs(result.LogLike()-llf) < 0.000000000001, check.Equals, true)
134         }
135 }
136
137 func (s *glmSuite) TestPvalueRealDataVsPython(c *check.C) {
138         checkVirtualenv(c)
139         samples, err := loadSampleInfo("glm_test_samples.csv")
140         if err != nil {
141                 c.Skip("test requires glm_test_samples.csv (not included)")
142         }
143         c.Logf("Nsamples = %d", len(samples))
144         nPCA := 5
145         // data series: y, rand, pca1, ..., pcaN
146         data := [][]statmodel.Dtype{nil, nil}
147         for i := 0; i < nPCA; i++ {
148                 data = append(data, []statmodel.Dtype(nil))
149         }
150         onehot := []bool{}
151         for _, si := range samples {
152                 if !si.isTraining {
153                         continue
154                 }
155                 if si.isCase {
156                         data[0] = append(data[0], 1)
157                 } else {
158                         data[0] = append(data[0], 0)
159                 }
160                 r := rand.Int()&1 == 1
161                 if rand.Int()&0x1f == 0 {
162                         // 1/32 samples have onehot==outcome, the rest
163                         // are random
164                         r = si.isCase
165                 }
166                 onehot = append(onehot, r)
167                 if r {
168                         data[1] = append(data[1], 1)
169                 } else {
170                         data[1] = append(data[1], 0)
171                 }
172                 for i := 0; i < nPCA; i++ {
173                         data[i+2] = append(data[i+2], si.pcaComponents[i])
174                 }
175         }
176
177         pGo := glmPvalueFunc(samples, nPCA)(onehot)
178         c.Logf("pGo = %g", pGo)
179
180         var pydata bytes.Buffer
181         pydata.WriteString("[")
182         for row, values := range data {
183                 if row > 0 {
184                         pydata.WriteString(",")
185                 }
186                 pydata.WriteString("\n    [")
187                 for col, v := range values {
188                         if col > 0 {
189                                 pydata.WriteString(", ")
190                         }
191                         fmt.Fprintf(&pydata, "%v", v)
192                 }
193                 pydata.WriteString("]")
194         }
195         pydata.WriteString("]")
196         py := pyImports + `
197 data = np.array(` + pydata.String() + `)
198 columns = ['y','onehot']
199 formula = ''
200 for i in range(` + fmt.Sprintf("%d", nPCA) + `):
201     columns.append('x'+str(i+1))
202     if len(formula) > 0:
203         formula += ' + '
204     formula += 'x'+str(i+1)
205 df = pd.DataFrame(data.T, columns=columns)
206
207 mod1 = smf.glm('y ~ '+formula, family=sm.families.Binomial(), data=df).fit()
208 # print(mod1.summary())
209 print('mod1.llf = ', mod1.llf)
210
211 mod2 = smf.glm('y ~ onehot + '+formula, family=sm.families.Binomial(), data=df).fit()
212 # print(mod2.summary())
213 print('mod2.llf = ', mod2.llf)
214
215 df = 1
216 p = 1 - scipy.stats.chi2.cdf(-2 * (mod1.llf - mod2.llf), df)
217 print(p)
218 `
219         c.Logf("python...")
220         cmd := exec.Command("python3", "-")
221         cmd.Stdin = strings.NewReader(py)
222         out, err := cmd.CombinedOutput()
223         c.Logf("%s", out)
224         c.Assert(err, check.IsNil)
225         outlines := bytes.Split(out, []byte{'\n'})
226         pPy, err := strconv.ParseFloat(string(outlines[len(outlines)-2]), 64)
227         c.Assert(err, check.IsNil)
228         c.Logf("pPy = %g", pPy)
229         c.Assert(math.Abs(pGo-pPy) < 0.000001, check.Equals, true)
230 }
231
232 func (s *glmSuite) TestPvalue(c *check.C) {
233         // csv: casecontrol,onehot,pca1,pca2,...
234         csv2test := func(csv string) (samples []sampleInfo, onehot []bool, npca int) {
235                 for _, line := range strings.Split(csv, "\n") {
236                         if len(line) == 0 || line[0] == '#' {
237                                 continue
238                         }
239                         fields := strings.Split(line, ",")
240                         var pca []float64
241                         for _, s := range fields[2:] {
242                                 f, err := strconv.ParseFloat(strings.TrimSpace(s), 64)
243                                 c.Assert(err, check.IsNil)
244                                 pca = append(pca, f)
245                         }
246                         isCase := strings.TrimSpace(fields[0]) == "1"
247                         samples = append(samples, sampleInfo{
248                                 isCase:        isCase,
249                                 isControl:     !isCase,
250                                 isTraining:    true,
251                                 pcaComponents: pca,
252                         })
253                         onehot = append(onehot, strings.TrimSpace(fields[1]) == "1")
254                         if rand.Int()%5 == 0 {
255                                 samples = append(samples, sampleInfo{
256                                         isCase:        rand.Int()%2 == 0,
257                                         isValidation:  true,
258                                         pcaComponents: pca,
259                                 })
260                         }
261                 }
262                 npca = len(samples[0].pcaComponents)
263                 return
264         }
265
266         samples, onehot, npca := csv2test(`
267 # case=1, onehot=1, pca1, pca2, pca3
268 0, 0, 1, 1.21, 2.37
269 0, 0, 2, 1.22, 2.38
270 0, 0, 3, 1.23, 2.39
271 0, 0, 1, 1.24, 2.33
272 0, 0, 2, 1.25, 2.34
273 1, 1, 3, 1.26, 2.35
274 1, 1, 1, 1.23, 2.36
275 1, 1, 2, 1.22, 2.32
276 1, 1, 3, 1.21, 2.31
277 `)
278         c.Check(glmPvalueFunc(samples, npca)(onehot), check.Equals, 0.002789665435066107)
279
280         samples, onehot, npca = csv2test(`
281 # case=1, onehot=1, pca1, pca2, pca3
282 0, 1, 1, 1.21, 2.37
283 0, 1, 2, 1.22, 2.38
284 0, 1, 3, 1.23, 2.39
285 0, 1, 1, 1.24, 2.33
286 0, 1, 2, 1.25, 2.34
287 1, 1, 3, 1.26, 2.35
288 1, 1, 1, 1.23, 2.36
289 1, 1, 2, 1.22, 2.32
290 1, 1, 3, 1.21, 2.31
291 `)
292         c.Check(math.IsNaN(glmPvalueFunc(samples, npca)(onehot)), check.Equals, true)
293 }
294
295 var benchSamples, benchOnehot = func() ([]sampleInfo, []bool) {
296         pcaComponents := 10
297         samples := []sampleInfo{}
298         onehot := []bool{}
299         r := make([]float64, pcaComponents)
300         for j := 0; j < 10000; j++ {
301                 for i := 0; i < len(r); i++ {
302                         r[i] = rand.Float64()
303                 }
304                 samples = append(samples, sampleInfo{
305                         id:            fmt.Sprintf("sample%d", j),
306                         isCase:        j%2 == 0 && j > 200,
307                         isControl:     j%2 == 1 || j <= 200,
308                         isTraining:    true,
309                         pcaComponents: append([]float64(nil), r...),
310                 })
311                 onehot = append(onehot, j%2 == 0)
312         }
313         return samples, onehot
314 }()
315
316 func (s *glmSuite) BenchmarkPvalue(c *check.C) {
317         for i := 0; i < c.N; i++ {
318                 p := glmPvalueFunc(benchSamples, len(benchSamples[0].pcaComponents))(benchOnehot)
319                 c.Check(p, check.Equals, 0.0)
320         }
321 }