Fix some tests.
[lightning.git] / chisquare.go
index 31142a1710686d2763f794cff54d0a27debcd3fa..842d0bb13f9fe1b67a30ef629c33892954dfbbc5 100644 (file)
@@ -5,41 +5,39 @@
 package lightning
 
 import (
-       "golang.org/x/exp/rand"
        "gonum.org/v1/gonum/stat/distuv"
 )
 
-var chisquared = distuv.ChiSquared{K: 1, Src: rand.NewSource(rand.Uint64())}
+var chisquared = distuv.ChiSquared{K: 1}
 
-func pvalue(a, b []bool) float64 {
-       //     !b        b
-       // !a  tab[0]    tab[1]
-       // a   tab[2]    tab[3]
-       tab := make([]int, 4)
-       for ai, aval := range []bool{false, true} {
-               for bi, bval := range []bool{false, true} {
-                       obs := 0
-                       for i := range a {
-                               if a[i] == aval && b[i] == bval {
-                                       obs++
-                               }
+func pvalue(x, y []bool) float64 {
+       var (
+               obs, exp [2]float64
+               sum      float64
+               sz       = float64(len(y))
+       )
+       for i, yi := range y {
+               if x[i] {
+                       if yi {
+                               obs[0]++
+                       } else {
+                               obs[1]++
                        }
-                       tab[ai*2+bi] = obs
                }
-       }
-       var sum float64
-       for ai := 0; ai < 2; ai++ {
-               for bi := 0; bi < 2; bi++ {
-                       rowtotal := tab[ai*2] + tab[ai*2+1]
-                       coltotal := tab[bi] + tab[2+bi]
-                       if rowtotal == 0 || coltotal == 0 {
-                               return 1
-                       }
-                       exp := float64(rowtotal) * float64(coltotal) / float64(len(a))
-                       obs := tab[ai*2+bi]
-                       d := float64(obs) - exp
-                       sum += (d * d) / exp
+               if yi {
+                       exp[0]++
+               } else {
+                       exp[1]++
                }
        }
-       return 1 - chisquared.CDF(sum)
+       if exp[0] == 0 || exp[1] == 0 || obs[0]+obs[1] == 0 {
+               return 1
+       }
+       exp[0] = (obs[0] + obs[1]) * exp[0] / sz
+       exp[1] = (obs[0] + obs[1]) * exp[1] / sz
+       for i := range exp {
+               d := obs[i] - exp[i]
+               sum += d * d / exp[i]
+       }
+       return chisquared.Survival(sum)
 }