1 """For storing and retrieving data (contingency tables)
2
3 @var _version: Version of this module
4 @type _version: String
5 """
6
7 _version = '$Id: Data.py,v 1.2 2008/10/07 08:57:06 jc Exp jc $'
8
9
10 from Variables import SubDomain, extdiv, Domain
11 from math import log
12 from Parameters import Factor, CPT
13 import operator
14
15 -class Data2(SubDomain):
16 """Attempt to implement ADTrees as a single flat dictionary"""
17
def __init__(self,rawdata,domain=None,rmin=200):
    """Initialise a L{Data2} object from C{rawdata}

    @param rawdata: Sequence whose elements after the first are, in order,
    C{new_domain_variables}, C{variables} and C{records} (exactly three
    elements are unpacked from C{rawdata[1:]}).
    @param domain: Passed through to C{SubDomain.__init__}.
    @param rmin: Passed to L{_make_dict}; record lists shorter than this
    are stored without further splitting.
    """
    new_domain_variables, variables, records = rawdata[1:]
    SubDomain.__init__(self,variables,domain,new_domain_variables,check=True)
    # The whole structure lives in this one flat dictionary.
    self._data = {}
    self._make_dict(0,variables,records,rmin,self._data)
25
26
def _make_dict(self,i,variables,records,rmin,data):
    """Recursively fill C{data} with the branches below depth C{i}

    Each key is of form eg (None,1,None,2) indicates a
    branch corresponding to variable1=1 and variable3=2
    and all others summed out

    Each value is a tuple of tuples. Each individual tuple is
    the instantiation of the remaining variables and finally
    the count

    Note that each split has a None branch, corresponding
    to all values for the variable given by that depth

    @param i: Depth, i.e. index of the variable to split on
    @param variables: Variable names, one per record position
    @param records: Indexable sequence of tuples; the last element of each
    tuple is the count. All records share the same first C{i} elements.
    @param rmin: Record lists shorter than this are stored as a leaf
    @param data: Dictionary which is updated in place
    """
    # Nothing to store for an empty record list. Without this guard the
    # leaf branch would read an unbound loop variable and the split branch
    # would index records[0].
    if not records:
        return

    if i >= len(variables) or len(records) < rmin:
        # Leaf: key is the shared prefix, value is the remaining columns.
        data[records[-1][:i]] = tuple([record[i:] for record in records])
        return

    # One bucket per value of variable i ...
    split = [[] for val in self._domain[variables[i]]]
    # ... plus the counts obtained by summing variable i out.
    i_summed_out = {}
    for record in records:
        split[record[i]].append(record)
        rest = record[i+1:-1]
        i_summed_out[rest] = i_summed_out.get(rest, 0) + record[-1]

    common_prefix = records[0][:i]+(None,)
    marginal_records = [common_prefix+key+(value,)
                        for key, value in i_summed_out.items()]

    # Most common value: first bucket of maximal size. Its branch is not
    # stored explicitly.
    sizes = [len(list_of_records) for list_of_records in split]
    mcv = sizes.index(max(sizes))

    for j, list_of_records in enumerate(split):
        if j != mcv and list_of_records:
            self._make_dict(i+1,variables,list_of_records,rmin,data)

    self._make_dict(i+1,variables,marginal_records,rmin,data)
86
87 - def _test(self,variables):
88 variables = frozenset(variables)
89 lv = len(variables)
90 indx = []
91 for i, v in enumerate(sorted(self._variables)):
92 if v in variables:
93 indx.append((i,self._numvals[v]))
94 if len(indx) == lv:
95 break
96
97 template = [None] * len(self._variables)
98 for inst in self.insts_indices(sorted(variables)):
99
100 for j in range(len(inst)):
101 temp = template[:]
102 for i, ins in enumerate(inst):
103 if i != j:
104 temp[indx[i][0]] = ins
105 query = tuple(temp)
106 print query,
107 for l in range(len(template)+1):
108 if query[:l] in self._data:
109 print query[:l], 'OK'
110 break
111
112
114 """Return dict mapping non-zero insts (as indices) to values
115 """
116 variables = frozenset(variables)
117 lv = len(variables)
118 indx = []
119 for i, v in enumerate(sorted(self._variables)):
120 if v in variables:
121 indx.append((i,self._numvals[v]))
122 if len(indx) == lv:
123 break
124 dkt = {}
125 self._marginal((),indx,dkt)
126 return dkt
127
129 """Return a dictionary mapping insts of C{variables} to
130 their values (where non-zero)
131
132 counts are conditional on the inst C{branch}
133 """
134
135
136
137
138 branch_length = len(branch)
139 nvars = len(self._variables)
140 if not vindices:
141 branch += tuple([None]*(nvars-branch_length))
142 branch_length = len(branch)
143 assert branch_length == nvars
144 for i in range(branch_length+1):
145 if branch[:i] in self._data:
146 if not vindices:
147
148 branch = branch[:i]
149
150
151 vis = [v[0]-branch_length for v in vindices]
152 tmp = {}
153 for row in self._data[branch[:i]]:
154 key = tuple(row[k] for k in vis)
155 try:
156 tmp[key] += row[-1]
157 except KeyError:
158 tmp[key] = row[-1]
159 template = [None]*(len(row)-1)
160 for key, val in tmp.items():
161 for p, k in enumerate(vis):
162 template[k] = key[p]
163 assert len(branch+tuple(template)) == nvars
164 dkt[branch+tuple(template)] = val
165 print dkt
166 return
167
168
169 vi, numvals = vindices[0]
170
171 branch += tuple([None]*(vi-len(branch)))
172
173
174
175
176
177 dkts = []
178 for i in range(numvals):
179 tmp = {}
180 new_branch = branch+(i,)
181 check_branch = new_branch
182 while len(check_branch) < nvars:
183 if check_branch in self._data:
184
185 self._marginal(new_branch,vindices[1:],tmp)
186 break
187 else:
188 check_branch += (None,)
189 else:
190
191 mcv = i
192 self._marginal(branch+(None,),vindices[1:],tmp)
193 dkts.append(tmp)
194
195
196 tmp_mcv = dkts[mcv]
197 endbit = vi+1
198 for i, tmp in enumerate(dkts):
199 if i != mcv:
200 for key, val in tmp.items():
201 tmp_mcv[branch+(None,)+key[endbit:]] -= val
202
203 for key, val in tmp_mcv.items():
204 tmp_mcv[branch+(mcv,)+key[endbit:]] = val
205 del tmp_mcv[key]
206
207 for tmp in dkts:
208 dkt.update(tmp)
209
210
211 -class Data(SubDomain):
212 """Factors whose data is stored in a table in an sqlite database
213
214 There is one row in the table for each non-zero value. These are meant to
215 be used for factors with many variables, but not too many non-zero values:
216 contingency tables for example. This object is a sensible choice when most of the
217 information required from the data can be gleaned in a few passes over it.
218
219 At present all database are stored in RAM.
220
221 @cvar db: The common database for objects of this class
222 @type db: sqlite3.Connection object
223 @ivar table: The name of table containing the object's data. This is read-only
224 (it is defined by a 'property') and is always equal to: 'table%d' % id(self)
225 @type table: String
226 @ivar n: Number of datapoints in the data
227 @type n: Integer
228
229 """
# Class-level setup: one in-memory sqlite database and one cursor are
# created when the class body runs and stored as class attributes.
try:
    from pysqlite2 import dbapi2 as sqlite
    db = sqlite.connect(':memory:')
    cursor = db.cursor()
except ImportError:
    # Without pysqlite2 the attributes above are never defined.
    print "Can't use Data class"
236
237
238
239
240
def __init__(self,data=None,variables=(),
             domain=None,new_domain_variables=None,
             must_be_new=False,check=False,convert=False):
    """Initialise a L{Data} object

    A table named C{self.table} is created with one INT column per
    variable (in sorted order) plus a C{value} column, and is filled from
    C{data} when C{data} is given.

    @param data: If C{None}, then an empty L{Data} object is created.
    If a tuple, it is used as is and assumed to be like one returned by
    L{IO.read_csv}.
    If a string, it is taken as the name of a CSV file (opened with
    C{gzip} when the name ends in '.gz') and read with L{IO.read_csv}.
    If a file object, it is read with L{IO.read_csv}.

    Unless C{data} is None any values for C{variables} and
    C{new_domain_variables} are ignored: they are replaced by what is
    found in the data.
    @type data: Tuple, String, File or None
    @param variables: Variables in the data
    @type variables: Sequence
    @param domain: A domain for the model.
    If None the internal default domain is used.
    @type domain: L{Variables.Domain} or None
    @param new_domain_variables: A dictionary containing a mapping from any new
    variables to their values.
    @type new_domain_variables: Dict or None
    @param must_be_new: Whether domain variables in C{new_domain_variables} have
    to be new
    @type must_be_new: Boolean
    @param check: Passed on to C{SubDomain.__init__}.
    @type check: Boolean
    @param convert: Not referenced in this method.
    @type convert: Boolean
    @raise TypeError: If C{data} is not a tuple, string, file or C{None}.
    """
    from gPy.IO import read_csv
    has_data = data is not None
    if has_data:
        kind = type(data)
        if kind == tuple:
            rawdata = data
        elif kind == str:
            if data.endswith('.gz'):
                import gzip
                source = gzip.open(data)
            else:
                source = open(data)
            rawdata = read_csv(source)
        elif kind == file:
            rawdata = read_csv(data)
        else:
            raise TypeError('data should be file, string, tuple or None')
        # Whatever was passed for these is overridden by the data itself.
        new_domain_variables, variables, records = rawdata[1:]
    SubDomain.__init__(self,variables,domain,new_domain_variables,must_be_new,check)
    column_defs = [v + ' INT' for v in sorted(self._variables)]
    column_defs.append('value INT')
    self.cursor.execute('CREATE TABLE %s (%s)' % (self.table,','.join(column_defs)))
    if has_data:
        self.populate(records)
    self._n = self.total_count()
    self._cached_tables = []
305
307 self.cursor.execute('DROP TABLE IF EXISTS %s' % self.table)
308
310 dkt = self.__dict__.copy()
311 self.cursor.execute('SELECT * FROM %s' % self.table)
312
313 dkt['_data'] = self.cursor.fetchall()
314 return dkt
315
317 """Iterates over those joint instantiations of the data which have
318 non-zero counts associated with them
319
320 On each iteration a tuple of non-negative integers is returned. The final
321 number is the count, the preceding numbers encode the joint instantiation.
322 For example, (1,0,2,23) states that the instantiation (1,0,2) has occurred 23 times.
323 (1,0,2) is the first variable with its instantiation 1 (its 2nd), the second with
324 instantiation 0 (its 1st) and the third with instantiation 2 (its 3rd). Variables and
325 values are ordered lexicographically.
326
327 @return: Iterator over non-zero count instantiations
328 @rtype: Iterator
329 """
330 self.cursor.execute('SELECT * FROM %s' % self.table)
331 for row in self.cursor:
332 yield row
333
335 self.__dict__ = state
336 cols = ','.join(sorted(self._variables)+['value'])
337 self.cursor.execute('CREATE TABLE %s (%s)' % (self.table, cols))
338 sql = 'INSERT INTO %s ( %s ) VALUES ( %s )' % (
339 self.table,cols,','.join(['?']*(len(self._variables)+1)))
340 self.cursor.executemany(sql, self._data)
341 del self._data
342
344 """Return the L{gPyC.lgh} score for entire data set
345
346 @param precision: BDe-like precision
347 @type precision: Float
348 @return: L{gPyC.lgh} score
349 @rtype: Float
350 """
351 from gPyC import lgh
352 self.cursor.execute('SELECT value FROM %s' % self.table)
353 data = [val[0] for val in self.cursor]
354 return lgh(data,precision/self.table_size())
355
357 """Return the marginal dataset containing only C{variables}
358
359 Returned Data object will have same domain as C{self}
360
361 Does not alter C{self}
362 @param variables: Variables in returned marginal table
363 @type variables: Iterable, e.g. list, tuple, set
364 @return: New marginal dataset
365 @rtype: L{Data} object
366 """
367 marginal = Data(None,variables,domain=self)
368 if variables:
369 cols = ','.join(sorted(marginal._variables))
370 sql = ('INSERT INTO %s SELECT %s, sum(value) FROM %s GROUP BY %s' %
371 (marginal.table,cols,self.table,cols))
372 else:
373 sql = ('INSERT INTO %s SELECT sum(value) FROM %s' %
374 (marginal.table,self.table))
375 marginal.cursor.execute(sql)
376 return marginal
377
378
380 """
381 Return the conditional entropy H(x|y) for variable sets C{x} and C{y}
382 using the empirical distribution given by the data
383 """
384 return self.entropy(frozenset(x)|frozenset(y)) - self.entropy(y)
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
425 """Return the entropy of the marginal empirical distribution given by C{x}
426 and the data
427 """
428
429 try:
430 return self.entropy_cache[frozenset(x)]
431 except AttributeError:
432 pass
433 except KeyError:
434 pass
435
436
437
438 if not x:
439 return 0.0
440 sql = 'SELECT sum(value) FROM %s GROUP BY %s' % (self.table,','.join(sorted(x)))
441 self.cursor.execute(sql)
442 rows = self.cursor.fetchall()
443
444
445
446 if len(rows) == 1:
447 try:
448 self.entropy_cache[frozenset(x)] = 0.0
449 except AttributeError:
450 pass
451 return 0.0
452 h = sum(count * log(count) for count in (val[0] for val in rows))
453 n = self._n
454 entropy = log(n) - h/n
455 try:
456 self.entropy_cache[frozenset(x)] = entropy
457 except AttributeError:
458 pass
459 return entropy
460
def _bic_search2(self,child,n,child_df,old_parents,further_parents,store,pa_lim,highest_llh):
    """Recursive worker which decides, for each variable in
    C{further_parents} in turn, whether to add it to C{old_parents}

    C{store} is a heap of C{(-bic_score, parentset)} pairs and is updated
    in place. A parent set is scored once C{further_parents} is used up,
    and is kept only if no stored proper subset of it scores higher.
    A candidate with two or more parents is not explored further if some
    stored proper subset of it scores higher than C{highest_llh} minus
    the candidate's complexity penalty.

    Nothing is done if C{old_parents} has more than C{pa_lim} members.
    """
    if len(old_parents) > pa_lim:
        return
    from heapq import heappush, heappop

    def beaten_by_subset(parentset, threshold):
        # Walk a copy of the heap from the best score downwards, stopping
        # at the first entry which does not exceed the threshold.
        queue = store[:]
        neg_score, candidate = heappop(queue)
        while -neg_score > threshold:
            if candidate < parentset:
                return True
            neg_score, candidate = heappop(queue)
        return False

    if not further_parents:
        # All decisions made: score old_parents.
        dim = child_df * self.table_size(old_parents)
        complexity_penalty = (log(n) * dim) / 2
        bic_score = (-n * self.conditional_entropy([child],old_parents)
                     - complexity_penalty)
        if not store:
            store.append((-bic_score,old_parents))
        elif not beaten_by_subset(old_parents, bic_score):
            heappush(store,(-bic_score,old_parents))
        return

    new_parent = further_parents[0]
    remaining = further_parents[1:]
    for new_parents in (old_parents, old_parents|frozenset([new_parent])):
        if len(new_parents) >= 2:
            lowest_penalty = log(n) * (child_df * self.table_size(new_parents)) / 2
            if beaten_by_subset(new_parents, highest_llh - lowest_penalty):
                continue
        self._bic_search2(child,n,child_df,new_parents,remaining,store,pa_lim,highest_llh)
507
def _bic_search(self,child,n,child_df,lower,bic_lower,upper,upper_bound,store):
    """
    Visit parent sets for C{child} which are proper supersets of C{lower}
    and subsets of the union of C{lower} and C{upper}, adding to the
    dictionary C{store} the BIC score of each visited set which beats the
    bound inherited from the set it was extended from.

    C{upper_bound} is a bound on the log-likelihood (modulo an additive
    constant) of any possible parent set; an extension whose
    C{upper_bound} minus complexity penalty does not exceed C{bic_lower}
    is pruned together with all its supersets on this branch.
    C{bic_lower} is the score to beat when extending C{lower}.
    """
    candidates = []
    for v in upper:
        parentset = lower | frozenset([v])
        dim = child_df * self.table_size(parentset)
        complexity_penalty = (log(n) * dim) / 2
        if upper_bound - complexity_penalty <= bic_lower:
            print('Pruned %s' % (parentset,))
            continue
        bic_score = (-n * self.conditional_entropy([child],parentset)
                     - complexity_penalty)
        if bic_score > bic_lower:
            store[parentset] = bic_score
            print('Stored %s %s' % (parentset, bic_score))
            # Supersets now have to beat this better score.
            candidates.append((bic_score,parentset,v))
        else:
            candidates.append((bic_lower,parentset,v))
    candidates.sort()
    remaining = upper
    for bound, parentset, v in candidates:
        # Each variable is dropped before recursing, so that no set is
        # reached along two different paths.
        remaining = remaining - frozenset([v])
        self._bic_search(child,n,child_df,parentset,bound,remaining,upper_bound,store)
538
539
540
542 """Branch and bound search for all parent sets for C{child}
543 which do not have a higher scoring subset
544
545 TODO: use a random graph as an input
546 """
547 n = self._n
548 child_df = self._numvals[child] - 1
549 old_parents = frozenset()
550 further_parents = tuple(self._variables - frozenset([child]))
551 store = []
552 highest_llh = -n * self.conditional_entropy([child],further_parents)
553 self._bic_search2(child,n,child_df,old_parents,further_parents,store,pa_lim,highest_llh)
554 return dict((s[1],-s[0]) for s in store)
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
617 """The log-likelihood of C{adg} with MLE parameters (up to an
618 additive constant)
619
620 The missing constant is the log of the multinomial
621 coefficient which is the same for all adgs.
622
623 """
624 dkt = {}
625 for child in adg.vertices():
626 parents = frozenset(adg.parents(child))
627 family = parents | frozenset([child])
628
629 try:
630 dkt[parents] += 1
631 except KeyError:
632 dkt[parents] = 1
633
634 try:
635 dkt[family] -= 1
636 except KeyError:
637 dkt[family] = -1
638
639 for variableset, mutliplier in dkt.items():
640 if multiplier != 0:
641 score += multiplier * self.entropy(variableset)
642 return self.n * score
643
647
665
666
668 """Return the number of datapoints in the data
669
670 @rtype: The number of datapoints in the data
671 @return: Integer
672 """
673 self.cursor.execute('SELECT sum(value) FROM %s' % self.table)
674 return self.cursor.fetchone()[0]
675
676
def populate(self,records,variables=None):
    """Insert C{records} into the object's table and refresh the count

    Each record is a tuple of integers. Each integer, except the last,
    corresponds to a value. The last value is the count.
    NOTE(review): the records are inserted as given; nothing here merges
    or replaces a joint instantiation which occurs more than once.

    Assumes lexicographic order of variables if C{variables} is None,
    otherwise the order given by C{variables}.

    @raise ValueError: If C{variables} is given but is not, as a set,
    equal to the object's variables.
    """
    if variables is None:
        column_names = sorted(self._variables)
    else:
        if frozenset(variables) != self._variables:
            raise ValueError('Got sent these variables: %s, expected these: %s'
                             % (variables, self._variables))
        column_names = list(variables)
    column_names.append('value')
    placeholders = ','.join(['?']*(len(self._variables)+1))
    sql = 'INSERT INTO %s ( %s ) VALUES ( %s )' % (
        self.table,','.join(column_names),placeholders)
    self.cursor.executemany(sql, records)
    self._n = self.total_count()
698
700 self.cursor.execute('SELECT value FROM %s' % self.table)
701 data = sorted([val[0] for val in self.cursor.fetchall()])
702 hs = [0] * data[-1]
703 m = len(data)
704 j = 0
705 for k, count in enumerate(data):
706 for i in range(j,count):
707 hs[i] = m - k
708 j = count
709 return hs
710
def ub(self,qpa,alpha,ri):
    """The upper bound C{self} provides on the score of a smaller
    parent set

    Computed from the list returned by C{self.qhs()} and scaled by
    C{qpa} divided by C{self.table_size()}.

    @param qpa: Size of contingency table for smaller parent set
    @param alpha: Effective sample size
    @param ri: Number of values of the child
    """
    cells_above = self.qhs()
    per_cell = alpha/float(qpa)
    # Accumulated term by term so the result does not depend on how
    # sum() adds floats.
    acc = 0.0
    for h, qh in enumerate(cells_above):
        numerator = h+per_cell/ri
        denominator = h+per_cell
        acc += qh*log(numerator/denominator)
    scale = float(qpa)/self.table_size()
    return scale*acc
726
727
728
730 """Return the number of instantiations (ie cells) having a value
731 greater than C{h}
732
733 @param h: Threshold
734 @type h: Integer
735 """
736 self.cursor.execute('SELECT value FROM %s WHERE value > %d' % (self.table,h))
737 return len(self.cursor.fetchall())
738
740 """
741 Make scores for all parent sets for all variables where (1) the size
742 of the parent set is at most C{pa_size_lim}. No pruning!
743 """
744
745 raise RuntimeError("This method is buggy, don't use!")
746
747 from gPy.Parameters import Factor
748 from gPy.Utils import subsetn_batch
749 from array import array
750 from gPyC import lgh
751
752 precision = float(precision)
753 batch_size = min(batch_size,65536)
754
755 marginal_size_lim = pa_size_lim+1
756 cursor = self.db.cursor()
757
758 h_score = {}
759 pa_scores = {}
760 for v in self._variables:
761 pa_scores[v] = {}
762
763 if marginal_size_lim > len(self._variables)/2:
764 raise ValueError('pa_size %d too big (see subseteqn_batch docs)' % pa_size)
765
766 faset_batch_all = []
767 def tmpfn(x): return len(x) == marginal_size_lim
768
769 variable_indices = range(len(self._variables))
770 last_variable_index = variable_indices[-1]
771
772 for faset_batch_all_tmp in subsetn_batch(variable_indices,marginal_size_lim,batch_size):
773
774
775
776
777 faset_batch_all.extend(faset_batch_all_tmp)
778
779
780
781 faset_batch = filter(tmpfn,faset_batch_all)
782
783
784
785
786
787
788 fasets_i_including = [array('H') for var in self._variables]
789 mults = [{} for var in self._variables]
790 current = ['0'] * len(self._variables)
791 num_fasets = len(faset_batch)
792 count = [None] * num_fasets
793
794 indx = array('I',[0] * num_fasets)
795
796
797
798 for faset_i, faset in enumerate(faset_batch):
799 size = 1
800 for j in sorted(faset,reverse=True):
801 fasets_i_including[j].append(faset_i)
802 mults[j][faset_i] = size
803 size *= self._numvals[self._sortedvariables[j]]
804
805 count[faset_i] = array('H',[0] * size)
806
807
808
809
810
811 cursor.execute('SELECT * FROM %s' % self.table)
812 for val, valcount in cursor.fetchall():
813
814
815
816 previous = current
817 current = val.split(',')
818 for j, inst_j in enumerate(current):
819 if inst_j == previous[j]:
820 continue
821 mult_j = mults[j]
822 diff_j = int(inst_j) - int(previous[j])
823 for faset_i in fasets_i_including[j]:
824 indx[faset_i] += (diff_j * mult_j[faset_i])
825
826
827
828 for faset_i, count_faset in enumerate(count):
829 count_faset[indx[faset_i]] += valcount
830
831
832
833
834
835
836
837
838
839 k = 0
840 for faset_i, faset in enumerate(faset_batch):
841 subsets = []
842 while True:
843 subset = faset_batch_all[k]
844 k += 1
845 subsets.append(subset)
846 if subset == faset:
847 break
848
849
850 if subsets[-1][-1] == last_variable_index:
851 biggest = frozenset(subsets[-1])
852 while k < len(faset_batch_all):
853 next_subset = faset_batch_all[k]
854 if not biggest.issuperset(next_subset):
855 break
856 subsets.append(next_subset)
857 k += 1
858
859 factor_data = list(count[faset_i])
860 factor_variables = [self._sortedvariables[j] for j in faset]
861 factor = Factor(factor_variables,factor_data,self)
862 for marginal in subsets:
863 marginal_variables = frozenset([self._sortedvariables[j] for j in marginal])
864 marginal_data = (factor._data_marginalise(
865 factor_data,
866 factor_variables,
867 factor._variables - marginal_variables))
868
869
870 h_score[marginal_variables] = lgh(marginal_data,precision/len(marginal_data))
871
872 faset_batch_all = faset_batch_all[k:]
873
874 for family, family_score in h_score.items():
875 for child in family:
876 parents = family-frozenset([child])
877 pa_scores[child][parents] = family_score - h_score[parents]
878
879
880
881
882
883
884
885
886 return pa_scores
887
889 """Yield counts for all marginals with C{n} variables in blocks of C{block}
890
891 Marginals are ordered according to how they are generated by the generator
892 L{Utils.subseteqn}.
893
894 @param n: The number of variables in the marginals
895 @type n: Int
896 @return: C{counts} where C{counts[i][idx]} is the count for the C{idx}th instantiation of
897 the C{i}th marginal
898 @rtype: List
899 """
900
901 from gPy.Utils import subseteqn
902
903 marginals_including = [[] for var in self._variables]
904 count = [None] * block
905 mults = [{} for var in self._variables]
906 for i, marginal in enumerate(subseteqn(range(len(self._variables)),n)):
907 modi = i % block
908 size = 1
909 for j in sorted(marginal,reverse=True):
910 marginals_including[j].append(modi)
911 mults[j][modi] = size
912 size *= self._numvals[self._sortedvariables[j]]
913 count[modi] = [0] * size
914 if block - modi == 1:
915 print 'OK i is', i
916 yield self._countsfromdata(count,mults,marginals_including)
917 marginals_including = [[] for var in self._variables]
918 count = [None] * block
919 mults = [{} for var in self._variables]
920
921 count = count[:modi+1]
922 yield self._countsfromdata(count,mults,marginals_including)
923
925 indx_template = [0] * len(count)
926 cursor = self._data.cursor()
927 cursor.execute('SELECT * FROM %s' % self.table)
928 for val, valcount in cursor.fetchall():
929 indx = indx_template[:]
930 for j, inst_j in enumerate(eval(val)):
931 mult_j = mults[j]
932 for i in marginals_including[j]:
933 indx[i] += (inst_j * mult_j[i])
934 for i, idx in enumerate(indx):
935 count[i][idx] += valcount
936 return count
937
939 """Return counts for all marginals with C{n} variables
940
941 Marginals are ordered according to how they are generated by the generator
942 L{Utils.subseteqn}.
943
944 @param n: The number of variables in the marginals
945 @type n: Int
946 @return: C{counts} where C{counts[i][idx]} is the count for the C{idx}th instantiation of
947 the C{i}th marginal
948 @rtype: List
949 """
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967 from gPy.Utils import subseteqn
968
969 marginals_including = [[] for var in self._variables]
970 mults = [{} for var in self._variables]
971
972 count = []
973 for i, marginal in enumerate(subseteqn(range(len(self._variables)),n)):
974 size = 1
975 for j in sorted(marginal,reverse=True):
976 marginals_including[j].append(i)
977 mults[j][i] = size
978 size *= self._numvals[self._sortedvariables[j]]
979 print i, size
980 count.append([0] * size)
981
982 print 'this long', len(count)
983 indx_template = [0] * len(count)
984
985 cursor = self._data.cursor()
986 cursor.execute('SELECT * FROM %s' % self.table)
987 for val, valcount in cursor.fetchall():
988 indx = indx_template[:]
989 for (inst_j,mult_j,interval_j,marginals_including_j) in zip(
990 eval(val),mults,self._intervals,marginals_including):
991 for i in marginals_including_j:
992 indx[i] += (inst_j * mult_j[i])
993 for i, idx in enumerate(indx):
994 count[i][idx] += valcount
995 return count
996
999
1000
1009
1016
def makeCPT(self, child, parents, force_cpt=False, check=False, prior=0):
    """Return a L{CPT} for C{child} built from the factor over C{child}
    and C{parents}, with C{prior} added to that factor

    @param child: The child variable of the CPT
    @param parents: The parents of C{child}
    @param force_cpt: Passed to L{CPT} as C{cpt_force}
    @param check: Passed to L{CPT} as C{cpt_check}
    @param prior: the Dirichlet prior parameter (the same parameter value
    is used for all instances!) Note there may be some problems with
    this method: a B{different} prior is used by the BDeu score. However,
    in practice, for parameter estimation, this prior method seems to be ok.
    I was lazy and it was simple to implement (cb). If prior is zero, then
    the parameters are the maximum likelihood estimation solutions.
    """
    family = set(parents)
    family.add(child)
    family_counts = self.makeFactor(family)
    return CPT(family_counts+prior, child, cpt_check=check, cpt_force=force_cpt)
1029
1030
1032 """Simple way to get a factor
1033
1034 @param variables: The variables in the required factor.
1035 If C{None}, then all of C{self}'s variables are used: which
1036 could produce a very large object!
1037 @type variables: Iterable
1038 @return: Marginal table
1039 @rtype: L{Factor} object
1040 """
1041 from gPy.Parameters import Factor
1042
1043 if variables is None:
1044 marginal = self
1045 else:
1046 marginal = self.marginal(variables)
1047 marginal.cursor.execute('SELECT * FROM %s' % marginal.table)
1048
1049
1050 st = 1
1051 step = []
1052 for v in sorted(variables,reverse=True):
1053 step.append(st)
1054 st *= self._numvals[v]
1055 step.reverse()
1056
1057 data = [0] * marginal.table_size()
1058 for row in self.cursor.fetchall():
1059 i = 0
1060 for j, val in enumerate(row[:-1]):
1061 i += val*step[j]
1062 data[i] = row[-1]
1063 return Factor(variables,data,domain=self)
1064
1065 @property
1067 """Return the current number of datapoints stored
1068
1069 @return: The current number of datapoints stored
1070 @rtype: Integer
1071 """
1072 return self._n
1073
1074
1075 @property
1077 """Return the name of the table storing C{self}'s data
1078
1079 @return: The name of the table storing C{self}'s data
1080 @rtype: String
1081 """
1082 return 'table%d' % id(self)
1083
1084
1085 __getitem__ = makeFactor
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
1161
1162
1163
1164
1165
1166
1167
1168
1169
1170
1171
1172
1173
1174
1175
1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
1188
1189
1190
1191 -class CompactFactor(SubDomain):
1192 """Factors whose data is not explicity represented (because there's
1193 too much of it).
1194
1195 Generally used to store datasets. Implemented using ADTrees
1196 @ivar _tree: ADTree holding the data
1197 @type _tree: L{_ADTree} object
1198 @ivar _records: Each record is a tuple of integers. For each tuple the M{j}th element
1199 is the index of the value of the M{j}th variable found in that record.
1200 The final integer is a count of how often the record appeared.
1201 @type _records: List
1202 @ivar _var_index: TODO
1203 """
def __init__(self,data,domain=None,rmin=200):
    """Initialise a L{CompactFactor} from C{data}

    @param data: A tuple like one returned by L{IO.read_csv}, or the name
    of a CSV file (opened with C{gzip} when the name ends in '.gz'), or a
    file object. See L{Data} documentation.
    @type data: Tuple, String or File
    @param domain: A domain for C{self}'s variables. If None
    the internal default domain is used.
    @type domain: Dictionary or None
    @param rmin: Controls the space/time tradeoff for using L{CompactFactor} objects.
    Subsets of records smaller than C{rmin} are stored as a tuple of (pointers to) the records
    concerned.
    @type rmin: int
    @raise TypeError: if C{data} is not a tuple, string or file
    @raise ValueError: if C{data} is C{None}, or if the data does not
    have exactly 3 elements after its first
    """
    from gPy.IO import read_csv
    if data is None:
        raise ValueError('Need some data actually!')
    data_type = type(data)
    if data_type == tuple:
        rawdata = data
    elif data_type == str:
        if data.endswith('.gz'):
            import gzip
            source = gzip.open(data)
        else:
            source = open(data)
        rawdata = read_csv(source)
    elif data_type == file:
        rawdata = read_csv(data)
    else:
        raise TypeError('data should be file, string, tuple or None')
    new_domain_variables, variables, records = rawdata[1:]
    SubDomain.__init__(self,variables,domain,new_domain_variables,check=True)
    # Position of each variable in a record, and its number of values,
    # both in the order the variables appear in the data.
    var_index = dict((variable, i) for i, variable in enumerate(variables))
    numvals = [self._numvals[variable] for variable in variables]
    # The last element of every record is its count.
    n = sum(record[-1] for record in records)
    tree = _ADTree(numvals,records,0,rmin)
    self._tree = tree
    self._records = records
    self._var_index = var_index
    self._n = n
1252
1254 """Branch and bound search for all parent sets for C{child}
1255 which do not have a higher scoring subset
1256
1257 TODO: use a random graph as an input
1258 """
1259 n = self._n
1260 child_df = self._numvals[child] - 1
1261 old_parents = frozenset()
1262 further_parents = tuple(self._variables - frozenset([child]))
1263 store = []
1264 self._bic_search2(child,n,child_df,old_parents,further_parents,store)
1265 return dict((s[1],-s[0]) for s in store)
1266
def _bic_search2(self,child,n,child_df,old_parents,further_parents,store):
    """Recursive worker which decides, for each variable in
    C{further_parents} in turn, whether to add it to C{old_parents}

    C{store} is a heap of C{(-bic_score, parentset)} pairs and is updated
    in place. A parent set is scored once C{further_parents} is used up,
    and is kept only if no stored proper subset of it scores higher.
    A candidate with two or more parents is not explored further if some
    stored proper subset of it scores higher than the best BIC score any
    of the candidate's reachable supersets could get.
    """
    from heapq import heappush, heappop

    def beaten_by_subset(parentset, threshold):
        # Walk a copy of the heap from the best score downwards, stopping
        # at the first entry which does not exceed the threshold.
        queue = store[:]
        neg_score, candidate = heappop(queue)
        while -neg_score > threshold:
            if candidate < parentset:
                return True
            neg_score, candidate = heappop(queue)
        return False

    if not further_parents:
        # All decisions made: score old_parents.
        dim = child_df * self.table_size(old_parents)
        complexity_penalty = (log(n) * dim) / 2
        bic_score = (-n * self.conditional_entropy([child],old_parents)
                     - complexity_penalty)
        if not store:
            store.append((-bic_score,old_parents))
        elif not beaten_by_subset(old_parents, bic_score):
            heappush(store,(-bic_score,old_parents))
        return

    new_parent = further_parents[0]
    remaining = further_parents[1:]
    for new_parents in (old_parents, old_parents|frozenset([new_parent])):
        if len(new_parents) >= 2:
            # Likelihood term when every remaining variable is added too,
            # paired with the penalty of the candidate itself.
            highest_llh = -n * self.conditional_entropy([child],new_parents.union(remaining))
            lowest_penalty = log(n) * (child_df * self.table_size(new_parents)) / 2
            if beaten_by_subset(new_parents, highest_llh - lowest_penalty):
                continue
        self._bic_search2(child,n,child_df,new_parents,remaining,store)
1311
1313 """
1314 Return the conditional entropy H(x|y) for variable sets C{x} and C{y}
1315 using the empirical distribution given by the data
1316 """
1317 return self.entropy(frozenset(x)|frozenset(y)) - self.entropy(y)
1318
1320 """Return the entropy of the marginal empirical distribution given by C{x}
1321 and the data
1322 """
1323 try:
1324 return self.entropy_cache[frozenset(x)]
1325 except AttributeError:
1326 pass
1327 except KeyError:
1328 pass
1329
1330
1331 if not x:
1332 return 0.0
1333
1334 h = sum(count * log(count) for count in self.get_nonzerocounts(x))
1335 n = self._n
1336 entropy = log(n) - h/n
1337 try:
1338 self.entropy_cache[frozenset(x)] = entropy
1339 except AttributeError:
1340 pass
1341 return entropy
1342
1343
1344
1346 return "Variables: %s\nTree:\n %s\n" % (
1347 sorted(self._variables), self._tree)
1348
def makeCPT(self, child, parents, force_cpt=False, check=False, prior=0):
    """Return a L{CPT} for C{child} built from the counts in the data

    The marginal factor over C{child} and C{parents} is obtained with
    C{makeFactor}, C{prior} is added to it and the result is passed to
    the L{CPT} constructor.

    @param child: The child variable of the CPT
    @param parents: The parents of C{child}
    @type parents: Iterable
    @param force_cpt: Passed to L{CPT} as C{cpt_force}
    @param check: Passed to L{CPT} as C{cpt_check}
    @param prior: the Dirichlet prior parameter (the same parameter value
    is used for all instances!) Note there may be some problems with
    this method: a B{different} prior is used by the BDeu score. However,
    in practice, for parameter estimation, this prior method seems to be ok.
    I was lazy and it was simple to implement (cb). If prior is zero, then
    the parameters are the maximum likelihood estimation solutions.
    @return: The result of the L{CPT} constructor
    """
    family = set(parents)
    family.add(child)
    family_counts = self.makeFactor(family)
    return CPT(family_counts + prior, child,
               cpt_check=check, cpt_force=force_cpt)
1361
1362
1364 variables_info = [
1365 [self._var_index[var],self._numvals[var]]
1366 for var in variables]
1367 variables_info.sort()
1368
1369 size = 1
1370 for variable_info in reversed(variables_info):
1371 variable_info.append(size)
1372 size *= variable_info[1]
1373 return self._tree._flatten(variables_info,0,no_zeroes=True)
1374
1376 """Return a marginal factor with C{variables}
1377
1378 Unless C{variables} is empty in which case return
1379 the sum of C{self}'s data
1380 @param variables: Variables to project onto
1381 @type variables: Iterable
1382 @return: The marginal factor
1383 @rtype: L{Factor} object
1384 @raise KeyError: if C{variables} contains a variable not contained in the
1385 L{CompactFactor}.
1386 """
1387 variables_info = [
1388 [self._var_index[var],self._numvals[var]]
1389 for var in variables]
1390 variables_info.sort()
1391
1392 size = 1
1393 for variable_info in reversed(variables_info):
1394 variable_info.append(size)
1395 size *= variable_info[1]
1396 return Factor(variables,self._tree._flatten(variables_info,0),self)
1397
1398 __getitem__ = makeFactor
1399
1401 """Return the number of nodes in the underlying ADTree
1402
1403 @return: The number of nodes in the tree
1404 @rtype: int
1405 """
1406 return self._tree.size()
1407
1410 """ADTree implementation
1411
1412 Ref::
1413
1414 @Article{moore98:_cached_suffic_statis_effic_machin,
1415 author = {Andrew Moore and Mary Soon Lee},
1416 title = {Cached Sufficient Statistics for Efficient Machine Learning with Large Datasets},
1417 journal = {Journal of Artificial Intelligence Research},
1418 year = 1998,
1419 volume = 8,
1420 pages = {67--91},
1421 url = {http://www.jair.org/media/453/live-453-1678-jair.pdf}
1422 }
1423
1424 @ivar _count: A count of the number of records 'in' the tree
1425 @type _count: int
1426 @ivar _data: may be
1427 .1 a tuple each element of which is either an _ADTree object or None,
1428 There is one element for each value of the variable corresponding to
1429 the top node of the tree. A None value indicates 0 records for the corresponding
1430 value
1431 .2 a tuple of records
1432 .3 a single _ADTree object. This is a space saving mechanism used when there
1433 is only one value of the variable with any records associated with it
1434 @type _data: Various
1435 @ivar _mcvindex: In cases 1) and 3) this is an integer stating which value has the most records.
1436 To this value *all* records are associated rather than just those with the appropriate
1437 value. In case 2) this is None
1438 @type _mcvindex: Various
1439 """
1440
1441 __slots__ = ('_data','_count','_mcvindex')
1442
1445
1448
1449
def __init__(self, numvals, records, depth, rmin):
    """Initialise an L{_ADTree} object

    @param numvals: The ith element of this list is the number of values for the ith variable
    in the tree
    @type numvals: list
    @param records: The data as a list. Each element is a tuple of value indices, one for each variable,
    plus an extra count field as the final element.
    @type records: list
    @param depth: The depth of this tree within its containing CompactFactor. Also the index of the variable
    associated with the top node of this tree
    @type depth: int
    @param rmin: If the number of records is below C{rmin} then tree growing stops and C{records} is stored.
    Note that a single record may represent many datapoints since all records have an extra 'count' field.
    An empty C{records} is always stored as it is, whatever C{rmin} is.
    @type rmin: int
    """
    self._count = sum(record[-1] for record in records)

    if depth == len(numvals):
        # Every variable has been split on: only the count matters.
        self._data = ()
        self._mcvindex = None
        return

    if len(records) == 0 or len(records) < rmin:
        # Too few records to be worth splitting: keep them as they are.
        # The explicit emptiness test stops a non-positive rmin from
        # trying to split nothing.
        self._data = records
        self._mcvindex = None
        return

    # Distribute the records over the values of this node's variable.
    per_value = [[] for _ in range(numvals[depth])]
    counts = [0] * numvals[depth]
    for record in records:
        valindex = record[depth]
        per_value[valindex].append(record)
        counts[valindex] += record[-1]

    # Find the most common value. Starting below zero guarantees that an
    # index is chosen even when every count is zero (the same convention
    # as _IncrementalADTree._make_new_node); ties go to the lowest index.
    mcvindex, mcvcount = 0, -1
    for valindex, count in enumerate(counts):
        if count > mcvcount:
            mcvindex, mcvcount = valindex, count
    self._mcvindex = mcvindex

    # The most common value's branch holds *all* the records; its own
    # counts are recovered by subtraction in _flatten.
    per_value[mcvindex] = records

    subtrees = []
    for value_records in per_value:
        if len(value_records) == 0:
            subtrees.append(None)
        else:
            subtrees.append(_ADTree(numvals, value_records, depth + 1, rmin))

    non_empty = [subtree for subtree in subtrees if subtree is not None]
    if len(non_empty) == 1:
        # Only the most common value has records: store its subtree
        # directly instead of a tuple padded with None.
        self._data = subtrees[mcvindex]
    else:
        self._data = tuple(subtrees)
1529
def _flatten(self, variables_info, depth, no_zeroes=False):
    """Return the data for a L{Factor}

    The L{Factor}'s variables will generally be a subset of those
    for which data is stored in the tree

    @param variables_info: Contains the necessary information on the variables
    sought without naming them. The ith element of C{variables_info} contains information on
    the ith variable sought. Each element of C{variables_info} is a 3 element list:
    variables_info[i][0] is the depth in the L{_ADTree} which deals with the ith variable sought.
    variables_info[i][1] is the number of values of the ith variable sought.
    variables_info[i][2] is the number of data values in the eventual factor which correspond
    to each value of the ith variable sought.
    (Clearly this depends on the number of values of 'later' variables.)
    @type variables_info: list
    @param depth: The depth of this tree, i.e. the index of the variable
    its top node splits on
    @type depth: int
    @param no_zeroes: If true, zero counts are dropped from the returned
    list, so the result no longer lines up with factor positions. Only
    the outermost call filters; recursive calls always return full tables.
    @type no_zeroes: Boolean
    @return: Values for the factor
    @rtype: list
    """
    if variables_info == []:
        # Nothing left to project onto: everything is summed out.
        data = [self._count]
    elif self._mcvindex is None:
        # Leaf holding raw records: add each record's count to the cell
        # addressed by its values for the variables sought.
        numvals = variables_info[0][1]
        size = variables_info[0][2]
        data = [0] * (numvals * size)
        for record in self._data:
            cell = 0
            for variable_info in variables_info:
                cell += record[variable_info[0]] * variable_info[2]
            data[cell] += record[-1]
    elif variables_info[0][0] == depth:
        # This node splits on the first variable sought.
        numvals = variables_info[0][1]
        size = variables_info[0][2]
        mcvstart = self._mcvindex * size
        mcvend = mcvstart + size
        if isinstance(self._data, tuple):
            data = []
            others = [0] * size
            for i, branch in enumerate(self._data):
                if branch is None:
                    data.extend([0] * size)
                else:
                    this_data = branch._flatten(variables_info[1:], depth + 1)
                    data.extend(this_data)
                    if i != self._mcvindex:
                        others = [a + b for a, b in zip(others, this_data)]
            # The most common value's branch holds the counts for *all*
            # values, so the other branches have to be subtracted out.
            data[mcvstart:mcvend] = [
                total - other
                for total, other in zip(data[mcvstart:mcvend], others)]
        else:
            # Singleton: only the most common value has any records.
            data = [0] * (size * numvals)
            data[mcvstart:mcvend] = self._data._flatten(
                variables_info[1:], depth + 1)
    else:
        # This node's variable is summed out: the most common value's
        # branch already holds all the records.
        if isinstance(self._data, tuple):
            branch = self._data[self._mcvindex]
        else:
            branch = self._data
        data = branch._flatten(variables_info, depth + 1)

    if no_zeroes:
        data = [count for count in data if count != 0]
    return data
1595
1596
1597
1598
1599
1600
1601
1602
1603
1604
1605
1606
1607
1608
1609
1610
1611
1612
1613
1614
1615
1616
1617
1618
1619
1620
1621
1622
1623
1624
1625
1626
1627
1628
1629
1630
1631
1632
1633
1634
1635
1636
1637
1638
1639
1640
1641
1642
1643
1644
1645
1646
1647
1648
1649
1650
1651
1652
1653
1654
1655
1656
1657
1658
1659
1660
1661
1662
1663
1664
1665
1666
1668 """Return the number of nodes in the tree
1669
1670 @return: The number of nodes in the tree
1671 @rtype: int
1672 """
1673 if self._mcvindex is None:
1674 return 1
1675 count = 0
1676 if isinstance(self._data,tuple):
1677 for branch in self._data:
1678 if branch is None:
1679 count += 1
1680 else:
1681 count += branch.size()
1682 else:
1683 count += self._data.size()
1684 return count
1685
1691
1694 """Given two sets of records, combing them according to the function C{how}
1695 @param destination: the target set of records where the result is stored
1696 and the left hand side of C{how}
1697 @type destination: list of tuples whose last element is a positive integer
1698 @param source: the right hand side of C{how}
1699 @type source: list of tuples whose last element is a positive integer
1700 @param how: operator to combine C{destination} and C{source}
1701 @type how: binary function
1702 """
1703
1704 destination = list(destination)
1705
1706 for src_rec in source:
1707 src_prefix = src_rec[:-1]
1708 for i, dst_rec in enumerate(destination):
1709 if src_prefix == dst_rec[:-1]:
1710 destination[i] = destination[i][:-1] + (how(destination[i][-1], src_rec[-1]),)
1711 break
1712 else:
1713 destination.append(src_rec)
1714 return destination
1715
1717 """For each child at depth C{depth} generate the list of
1718 records (C{children_records}) and the number of instances (C{counts}).
1719 @param numvals: The ith element of this list is the number of values for
1720 the ith variable in the tree
1721 @type numvals: list
1722 @param records: The data as a list. Each element is a tuple of value
1723 indices, one for each variable, plus an extra count field as the final
1724 element.
1725 @type records: iterable
1726 @param depth: The depth of this tree within its containing CompactFactor.
1727 Also the index of the variable associated with the top node of this tree
1728 @type depth: int
1729 @return: A 2-tuple: C{children_records} and C{counts}, as described above.
1730 @rtype: tuple
1731 """
1732 children_records = [[] for i in xrange(numvals)]
1733 counts = [0] * numvals
1734 for record in records:
1735 valindex = record[depth]
1736 children_records[valindex].append(record)
1737 counts[valindex] += record[-1]
1738
1739 return children_records, counts
1740
1742
1743 __slots__ = ('_data','_count','_mcvindex')
1744
1745 - def __init__(self, numvals, records, rmin, depth=0):
1753
1776
def str(self, numvals, rmin, n=0, mcv=True):
    """Return a multi-line textual dump of this tree

    Every line produced for this node starts with C{n} followed by C{n}
    spaces. The branch for the most common value is marked with a C{*}.

    @param numvals: Passed on to C{_build_mcv} and to the children
    @param rmin: Passed on to C{_build_mcv} and to the children
    @param n: The nesting level of this node; children are printed with C{n+1}
    @type n: int
    @param mcv: If true, the branch of the most common value is replaced
    by the result of C{self._build_mcv(numvals, rmin, n+1)} before being
    printed; otherwise the stored branch is printed
    @type mcv: Boolean
    @return: The textual representation
    @rtype: String
    """
    prefix = '%s%s' % (n, ' ' * n)
    pieces = ['%scount: %s, mcv:%s\n' % (prefix, self._count, self._mcvindex)]

    def render(subtree):
        # Text for one branch, which may be absent.
        if subtree is None:
            return prefix + ' ' + 'empty\n'
        return subtree.str(numvals, rmin, n + 1, mcv)

    if self._mcvindex is None:
        # Leaf: show the stored records.
        pieces.append(prefix + 'records\n')
        pieces.append('%s%s\n' % (prefix, self._data))
    elif isinstance(self._data, tuple):
        # One branch per value.
        for i, subtree in enumerate(self._data):
            if i == self._mcvindex:
                pieces.append('%s*%s:\n' % (prefix, i))
                if mcv:
                    subtree = self._build_mcv(numvals, rmin, n + 1)
            else:
                pieces.append('%s %s:\n' % (prefix, i))
            pieces.append(render(subtree))
    else:
        # Singleton: only the most common value's branch is stored.
        pieces.append('%s*%s: singleton\n' % (prefix, self._mcvindex))
        if mcv:
            subtree = self._build_mcv(numvals, rmin, n + 1)
        else:
            subtree = self._data
        pieces.append(render(subtree))

    return ''.join(pieces)
1809
1830
1832 """Where possible convert self._data from a tuple (or list) into
1833 a singleton. other_than_mcv is a hint. If other than None, it indicates
1834 whether some child other than the MCV has a non-zero count."""
1835 if self._mcvindex is None:
1836 return
1837
1838 if not isinstance(self._data,tuple) and not isinstance(self._data,list):
1839 return
1840
1841 if other_than_mcv is None:
1842 for i, child in enumerate(self._data):
1843 if child is None or i == self._mcvindex:
1844 continue
1845 other_than_mcv = True
1846 break
1847 else:
1848 other_than_mcv = False
1849 elif not other_than_mcv:
1850 for i, child in enumerate(self._data):
1851 if i == self._mcvindex:
1852 continue
1853 assert child is None
1854
1855 if not other_than_mcv:
1856 assert isinstance(self._data[self._mcvindex],_IncrementalADTree)
1857 self._data = self._data[self._mcvindex]
1858 elif isinstance(self._data,list):
1859 self._data = tuple(self._data)
1860
1861
1864
def _update(self, numvals, records, rmin, depth=0):
    """Add data to an L{_ADTree} object

    Nothing happens if the counts of C{records} sum to zero. Otherwise
    the node's count is increased. A leaf is rebuilt from its old
    records merged with the new ones. An inner node passes each value's
    new records on to the matching child, creating it if necessary; the
    most common value's child receives I{all} the new records. The node
    is then compacted, its most common value corrected, and verified.

    @param numvals: The ith element of this list is the number of values
    for the ith variable in the tree
    @type numvals: list
    @param records: The data as a list. Each element is a tuple of value
    indices, one for each variable, plus an extra count field as the final
    element.
    @type records: list
    @param rmin: If the number of records is below C{rmin} then tree
    growing stops and C{records} is stored. Note that a single record may
    represent many datapoints since all records have an extra 'count'
    field.
    @type rmin: int
    @param depth: The depth of this tree within its containing
    CompactFactor. Also the index of the variable associated with the top
    node of this tree
    @type depth: int
    """
    # Sanity checks on the shape of the incoming records.
    for record in records:
        assert len(record) - 1 == len(numvals)
        assert isinstance(record, tuple)
        assert all(isinstance(element, int) for element in record)

    added = sum(record[-1] for record in records)
    if added == 0:
        return

    self._count += added

    if depth == len(numvals):
        # Below the last variable only the count is kept.
        self._data = []
        self._mcvindex = None
        return

    if self._mcvindex is None:
        # Leaf: fold the new records into the stored ones and rebuild,
        # which may or may not turn this node into a split.
        self._make_new_node(
            numvals,
            _merge_records(records, self._data, operator.add),
            rmin, depth)
        return

    mcv = self._mcvindex
    children_records, counts = _distribute_records(numvals[depth], depth, records)
    # The most common value's child carries all records, not just its own.
    counts[mcv] = added
    children_records[mcv] = records

    if self._have_compact_children():
        touches_other_value = any(
            count > 0 for i, count in enumerate(counts) if i != mcv)
        if not touches_other_value:
            # Still only one value present: stay compact.
            self._data._update(numvals, records, rmin, depth + 1)
            return
        # A second value has turned up: switch to one slot per value.
        expanded = [None] * len(counts)
        expanded[mcv] = self._data
        self._data = expanded
    else:
        self._data = list(self._data)

    for i, child_records in enumerate(children_records):
        child = self._data[i]
        if counts[i] == 0:
            continue
        if child is None:
            self._data[i] = _IncrementalADTree(numvals, child_records, rmin, depth + 1)
        else:
            child._update(numvals, child_records, rmin, depth + 1)

    self._compact_children()
    if self._have_compact_children():
        assert self._count == self._data._count
    else:
        assert self._count == self._data[self._mcvindex]._count

    # The new data may have made another value the most common one.
    self._correct_mcv(numvals, rmin, depth)
    self._verify(numvals, rmin, depth)
1971
def _make_new_node(self, numvals, records, rmin, depth, force_expand=False):
    """Turn C{self} into a fresh node from records

    Whatever C{self} held before is discarded: count, data and most
    common value are all recomputed from C{records}.

    @param numvals: The ith element of this list is the number of values
    for the ith variable in the tree
    @type numvals: list
    @param records: The data as a list. Each element is a tuple of value
    indices, one for each variable, plus an extra count field as the final
    element.
    @type records: list
    @param rmin: If the number of records is below C{rmin} then tree growing stops and C{records} is stored.
    Note that a single record may represent many datapoints since all records have an extra 'count' field.
    @type rmin: int
    @param depth: The depth of this tree within its containing
    CompactFactor. Also the index of the variable associated with the top
    node of this tree
    @type depth: int
    @param force_expand: If true, this node is split even when there are
    fewer than C{rmin} records. The children are built with the ordinary
    constructor, so C{rmin} applies to them again.
    @type force_expand: Boolean
    """
    # Sanity checks on the shape of the incoming records.
    for record in records:
        assert len(record) - 1 == len(numvals)
        assert isinstance(record, tuple)
        assert all(isinstance(element, int) for element in record)

    self._count = sum(record[-1] for record in records)

    if depth == len(numvals):
        # Below the last variable only the count is kept.
        self._data = []
        self._mcvindex = None
        return

    if len(records) < rmin and not force_expand:
        # Too few records to split: keep them as they are.
        self._data = records
        self._mcvindex = None
        return

    children_records, counts = _distribute_records(numvals[depth], depth, records)

    # Most common value; starting from -1 means an index is chosen even
    # if every count is zero, and ties go to the lowest index.
    best_count = -1
    for valindex, value_count in enumerate(counts):
        if value_count > best_count:
            self._mcvindex = valindex
            best_count = value_count

    # The most common value's child holds all the records.
    children_records[self._mcvindex] = records
    counts[self._mcvindex] = self._count

    self._data = []
    other_than_mcv = False
    for i, child_records in enumerate(children_records):
        if counts[i] == 0:
            self._data.append(None)
        else:
            if i != self._mcvindex:
                other_than_mcv = True
            self._data.append(
                _IncrementalADTree(numvals, child_records, rmin, depth + 1))

    self._compact_children(other_than_mcv)
    if self._have_compact_children():
        assert self._count == self._data._count
    else:
        assert self._count == self._data[self._mcvindex]._count
    self._verify(numvals, rmin, depth)
2045
2047 """self is a record carrying node but we want to treat it like
2048 something bigger. Expand just this level. This is a special
2049 case of _make_new_node."""
2050 assert self._mcvindex is None
2051
2052 self._make_new_node(numvals, self._data, rmin, depth, force_expand=True)
2053
2055 if self._mcvindex is None:
2056 return
2057
2058 if self._have_compact_children():
2059 self._data._correct_mcv(numvals,rmin,depth+1)
2060 return
2061
2062 assert self._count == self._data[self._mcvindex]._count
2063
2064
2065 cur_mcv_count = self._count
2066 max_mcv = self._mcvindex
2067 max_mcv_count = 0
2068 for i, child in enumerate(self._data):
2069 if child is None or i == self._mcvindex:
2070 continue
2071 cur_mcv_count -= child._count
2072 if child._count > max_mcv_count:
2073 max_mcv, max_mcv_count = i, child._count
2074
2075
2076 if cur_mcv_count < max_mcv_count:
2077
2078
2079
2080
2081
2082
2083 data = list(self._data)
2084 data[self._mcvindex] = self._build_mcv(numvals, rmin, depth+1)
2085 assert (cur_mcv_count == 0 and data[self._mcvindex] is None) or data[self._mcvindex]._count == cur_mcv_count
2086 data[max_mcv] = self._data[self._mcvindex]
2087 self._data = data
2088 data = None
2089 self._mcvindex = max_mcv
2090 assert self._count == self._data[self._mcvindex]._count
2091
2092 self._compact_children(cur_mcv_count != 0 or self._count > max_mcv_count)
2093 if self._have_compact_children():
2094 self._data._correct_mcv(numvals,rmin,depth+1)
2095 return
2096
2097
2098 for child in self._data:
2099 if child is not None:
2100 child._correct_mcv(numvals, rmin, depth+1)
2101
2141
def _subtract(self, numvals, rmin, depth, lesser):
    """Remove the data held in C{lesser} from C{self}, in place

    C{lesser} is expected to describe a subset of the data in C{self}
    (its count may not exceed that of C{self}) and to sit at the same
    depth. Leaves are subtracted record by record; inner nodes are
    subtracted child by child, after bringing both trees into
    compatible shapes.

    @param numvals: The ith element of this list is the number of values
    for the ith variable in the tree
    @type numvals: list
    @param rmin: Passed on to the helpers which build or expand nodes
    @type rmin: int
    @param depth: The depth of this tree; also the index of the variable
    associated with its top node
    @type depth: int
    @param lesser: The tree to subtract, or C{None} for nothing. It is
    copied before being expanded, but C{_correct_mcv} may be called on
    it directly.
    """
    if lesser is None:
        return

    assert self._count >= lesser._count

    # A leaf cannot be compared with a split node: split self as well.
    if self._mcvindex is None and lesser._mcvindex is not None:
        self._expand_one_level(numvals, rmin, depth)

    self._count -= lesser._count

    if depth == len(numvals):
        # Below the last variable there is nothing but the count.
        return

    if self._mcvindex is None and lesser._mcvindex is None:
        # Two leaves: subtract matching records.
        self._data = _merge_records(self._data, lesser._data, operator.sub)
        return

    if lesser._mcvindex is None:
        # self is split but lesser is a leaf: work on an expanded copy
        # so that the caller's tree keeps its records.
        lesser = lesser.copy()
        lesser._expand_one_level(numvals, rmin, depth)

    self_single = not isinstance(self._data, tuple)
    less_single = not isinstance(lesser._data, tuple)

    if self_single and not less_single:
        # NOTE(review): presumably this lets lesser collapse to a
        # singleton as well; the assertion below relies on it.
        lesser._correct_mcv(numvals, rmin, depth)
        less_single = not isinstance(lesser._data, tuple)

    assert not self_single or less_single

    if self_single and less_single:
        assert self._mcvindex == lesser._mcvindex
        self_vary = self._data
        less_vary = lesser._data
        assert self_vary is None or self_vary._count == self._count + lesser._count
        assert less_vary is None or less_vary._count == lesser._count
    elif less_single:
        self_vary = self._data[self._mcvindex]
        less_vary = lesser._data
        assert self_vary is None or self_vary._count == self._count + lesser._count
        assert less_vary is None or less_vary._count == lesser._count
        if self._mcvindex != lesser._mcvindex:
            # All of lesser's data has one value, which is not self's
            # most common value, so self's child for that value has to
            # be reduced too.
            b = lesser._build_mcv(numvals, rmin, depth + 1)
            # Was 'or', which dereferenced None whenever b was None.
            if b is not None and b._count != 0:
                assert self._data[lesser._mcvindex] is not None
                assert self._data[lesser._mcvindex]._count != 0
                if self._data[lesser._mcvindex]._count == b._count:
                    # Nothing would be left: drop the child altogether.
                    self._data = list(self._data)
                    self._data[lesser._mcvindex] = None
                    self._compact_children()
                else:
                    self._data[lesser._mcvindex]._subtract(numvals, rmin, depth + 1, b)
    else:
        self_vary = self._data[self._mcvindex]
        less_vary = lesser._data[lesser._mcvindex]
        assert self_vary is None or self_vary._count == self._count + lesser._count
        assert less_vary is None or less_vary._count == lesser._count

        self._data = list(self._data)
        for i, (a, b) in enumerate(zip(self._data, lesser._data)):
            if i == self._mcvindex:
                # Handled below via self_vary / less_vary.
                continue

            if i == lesser._mcvindex:
                # lesser's stored child here holds all its records;
                # get the part that really has this value.
                b = lesser._build_mcv(numvals, rmin, depth + 1)

            if b is None or b._count == 0:
                continue

            assert a is not None and a._count != 0
            if a._count == b._count:
                self._data[i] = None
            else:
                a._subtract(numvals, rmin, depth + 1, b)

        self._compact_children()

    assert self_vary is None or self_vary._count == self._count + lesser._count
    assert less_vary is None or less_vary._count == lesser._count

    # Both "all records" children shrink by the whole of lesser.
    if self_vary is not None:
        self_vary._subtract(numvals, rmin, depth + 1, less_vary)
    else:
        assert less_vary is None or less_vary._count == 0

    assert self_vary is None or self_vary._count == self._count
    assert less_vary is None or less_vary._count == lesser._count
2244
2245 - def _verify(self, numvals, rmin, depth=0):
2247
2248
2249
2250
2251
2252
2253
2254
2255
2256
2257
2258
2259
2260
2261
2262
2263
2264
2265
2266
2267
2268
2269
2270
2271
2272
2273
2274
2275
2276
2277
2278
2279
2280
2281
2282
2283
2284
2285
2286
2287
2288
2289
2290
2291
2292
2293
2294
2295
2296
2297
2298
2299
2300
2301
2302
2303
2304
2305
2306
2307
2308
2309
2310
2311
2312
2313
2314
2315
2316
2317
2318
2319
2320
2321
2322
2323
2324
2325
2326 -class IncrementalCompactFactor(SubDomain):
2327
def __init__(self, rawdata, domain=None, rmin=200):
    """Build an incremental ADTree from C{rawdata}

    @param rawdata: A sequence whose elements after the first are, in
    order, the new domain variables, the variables and the records
    @param domain: Passed on to C{SubDomain.__init__}
    @param rmin: Passed to L{_IncrementalADTree}, and kept for later calls
    @type rmin: int
    """
    new_domain_variables, variables, records = rawdata[1:]
    SubDomain.__init__(self, variables, domain, new_domain_variables, check=True)

    self._rmin = rmin

    # Position of each variable in a record, and the number of values
    # of the variable found at each position.
    ordering = list(variables)
    self._varnumvals = [self._numvals[variable] for variable in ordering]
    self._tree = _IncrementalADTree(self._varnumvals, records, self._rmin)
    self._var_index = dict(
        (variable, position) for position, variable in enumerate(ordering))
    self._tree._verify(self._varnumvals, self._rmin)
2345
2346
2348 cpy = copy(self)
2349 cpy._tree = cpy._tree.copy()
2350 return cpy
2351
2354
2356 records = rawdata[3]
2357 self._tree._update(self._varnumvals, records, self._rmin)
2358 self._tree._verify(self._varnumvals, self._rmin)
2359
2361 return "Variables: %s\nTree:\n%s\n" % (
2362 sorted(self._variables), self._tree.str(self._varnumvals, self._rmin))
2363
2365 """Return a marginal factor with C{variables}
2366
2367 Unless C{variables} is empty in which case return
2368 the sum of C{self}'s data
2369 @param variables: Variables to project onto
2370 @type variables: Iterable
2371 @param check: Whether to bother with an initial check that every member
2372 of C{variables} is in the L{CompactFactor}. (If this check is omitted
2373 and there I{is} an extra variable, then a L{Factor} with the wrong
2374 number of values will be created.)
2375 @type check: Boolean
2376 @return: The marginal factor
2377 @rtype: L{Factor} object
2378 @raise KeyError: if C{variables} contains a variable not contained in the
2379 L{CompactFactor}.
2380 """
2381 variables_info = [
2382 [self._var_index[var],self._numvals[var]]
2383 for var in variables]
2384 variables_info.sort()
2385
2386 size = 1
2387 for variable_info in reversed(variables_info):
2388 variable_info.append(size)
2389 size *= variable_info[1]
2390 return Factor(variables,self._tree._flatten(variables_info,0),SubDomain.copy(self))
2391
def makeCPT(self, child, parents, force_cpt=False, check=False, prior=0):
    """Return a L{CPT} for C{child} built from the counts in the data

    The marginal factor over C{child} and C{parents} is obtained with
    C{makeFactor}, C{prior} is added to it and the result is passed to
    the L{CPT} constructor.

    @param child: The child variable of the CPT
    @param parents: The parents of C{child}
    @type parents: Iterable
    @param force_cpt: Passed to L{CPT} as C{cpt_force}
    @param check: Passed to L{CPT} as C{cpt_check}
    @param prior: the Dirichlet prior parameter (the same parameter value
    is used for all instances!) Note there may be some problems with
    this method: a B{different} prior is used by the BDeu score. However,
    in practice, for parameter estimation, this prior method seems to be ok.
    I was lazy and it was simple to implement (cb). If prior is zero, then
    the parameters are the maximum likelihood estimation solutions.
    @return: The result of the L{CPT} constructor
    """
    family = set(parents)
    family.add(child)
    family_counts = self.makeFactor(family)
    return CPT(family_counts + prior, child,
               cpt_check=check, cpt_force=force_cpt)
2404
2407
2409 """Return the number of nodes in the underlying ADTree
2410
2411 @return: The number of nodes in the tree
2412 @rtype: int
2413 """
2414 return self._tree._count
2415
2416
2417
2418
2419
2420
2421
2422
2423
2424
2425
2426
2427
2428
2429
2430
2431
2432
2433
2434
2435
2436
2437
2438
2439
2440
2441
2442
2443
2444
2445
2446
2447
2448
2449
2450
2451
2452
2453
2454
2455
2456
2457
2458
2459
2460
2461
2462
2463
2464
2465
2466
2467
2468
2469
2470
2471
2472
2473
2474
2475
2476
2477
2478
2479
2480
2481
2482
2483
2484
2485
2486
2487
2488
2489
2490
2491
2492