Package gPy :: Module Data
[hide private]
[frames] | no frames]

Source Code for Module gPy.Data

   1  """For storing and retrieving data (contingency tables) 
   2   
   3  @var _version: Version of this module 
   4  @type _version: String 
   5  """ 
   6   
   7  _version = '$Id: Data.py,v 1.2 2008/10/07 08:57:06 jc Exp jc $' 
   8   
   9   
  10  from Variables import SubDomain, extdiv, Domain 
  11  from math import log 
  12  from Parameters import Factor, CPT 
  13  import operator 
class Data2(SubDomain):
    """Attempt to implement ADTrees as a single flat dictionary

    The dictionary is stored in C{self._data}.  It is filled in by
    L{_make_dict} when the object is created and read back by L{marginal}.
    """
def __init__(self, rawdata, domain=None, rmin=200):
    """Initialise a L{Data2} object from C{rawdata}

    @param rawdata: Sequence whose elements after the first are, in order,
    the new domain variables, the variables and the records.
    @type rawdata: Sequence
    @param domain: Passed on unchanged to C{SubDomain.__init__}
    @param rmin: Passed on unchanged to L{_make_dict}: record lists shorter
    than this are stored without any further splitting.
    @type rmin: Integer
    """
    new_domain_variables, variables, records = rawdata[1:]
    SubDomain.__init__(self, variables, domain, new_domain_variables, check=True)
    flat_tree = {}
    self._data = flat_tree
    self._make_dict(0, variables, records, rmin, flat_tree)
25 26
def _make_dict(self, i, variables, records, rmin, data):
    """Recursively flatten C{records} into the dictionary C{data}

    Each key is of form eg (None,1,None,2) indicates a
    branch corresponding to variable1=1 and variable3=2
    and all others summed out

    Each value is a tuple of tuples. Each individual tuple is
    the instantiation of the remaining variables and finally
    the count

    Note that each split has a None branch, corresponding
    to all values for the variable given by that depth

    @param i: Index of the variable to split on (the recursion depth)
    @type i: Integer
    @param variables: The variables, in the order used by the records
    @type variables: Sequence
    @param records: Records sharing the same first C{i} entries; the last
    entry of each record is its count
    @type records: Sequence of tuples
    @param rmin: Record lists shorter than this are stored as they are
    @type rmin: Integer
    @param data: The dictionary to add to (altered in place)
    @type data: Dict
    """
    # nothing to store for an empty list of records; returning here also
    # protects the indexing of records[0] / records[-1] below
    if not records:
        return

    # if not too many records, then don't sum any counts
    # just store each count with appropriate key
    if i >= len(variables) or len(records) < rmin:
        # record[:i] constant for all records
        data[records[-1][:i]] = tuple(record[i:] for record in records)
        return

    # split records according to ith value
    # and sum out ith value on the way
    split = [[] for val in self._domain[variables[i]]]
    i_summed_out = {}
    for record in records:
        split[record[i]].append(record)
        key = record[i+1:-1]
        i_summed_out[key] = i_summed_out.get(key, 0) + record[-1]

    common_prefix = records[0][:i] + (None,)
    marginal_records = [common_prefix + key + (value,)
                        for key, value in i_summed_out.items()]

    # find most common value (the first one found wins a tie)
    mcv = max(range(len(split)), key=lambda j: len(split[j]))

    # for all values apart from most common
    # send associated records to recursive call for processing
    for j, list_of_records in enumerate(split):
        if j != mcv and list_of_records:
            self._make_dict(i+1, variables, list_of_records, rmin, data)

    self._make_dict(i+1, variables, marginal_records, rmin, data)
86
def _test(self,variables):
    """Debugging aid: print which stored key answers each query on C{variables}

    For every instantiation returned by C{self.insts_indices} and every
    position C{j} in it, a query tuple is built with that position left as
    C{None}.  The query is printed, followed by the shortest prefix of it
    that is a key of C{self._data} (and 'OK') if there is one.

    Nothing is returned; all output goes to standard output.

    @param variables: The variables to build queries for
    @type variables: Iterable
    """
    variables = frozenset(variables)
    lv = len(variables)
    # indx collects (position in sorted variable order, number of values)
    # for each requested variable
    indx = []
    for i, v in enumerate(sorted(self._variables)):
        if v in variables:
            indx.append((i,self._numvals[v]))
            # stop scanning once every requested variable has been found
            if len(indx) == lv:
                break
    # find all matching keys:
    template = [None] * len(self._variables)
    for inst in self.insts_indices(sorted(variables)):
        # make query
        for j in range(len(inst)):
            temp = template[:]
            for i, ins in enumerate(inst):
                # position j is left as None in the query
                if i != j:
                    temp[indx[i][0]] = ins
            query = tuple(temp)
            print query,
            # look for the shortest prefix of the query stored as a key
            for l in range(len(template)+1):
                if query[:l] in self._data:
                    print query[:l], 'OK'
                    break
111 112
def marginal(self, variables):
    """Return dict mapping non-zero insts (as indices) to values

    The requested variables are located in the sorted order of
    C{self._variables}; the actual work is done by L{_marginal}.

    @param variables: The variables of the required marginal
    @type variables: Iterable
    @return: The dictionary filled in by L{_marginal}
    @rtype: Dict
    """
    wanted = frozenset(variables)
    # (position, number of values) for each wanted variable, in sorted order
    positions = [(pos, self._numvals[name])
                 for pos, name in enumerate(sorted(self._variables))
                 if name in wanted]
    result = {}
    self._marginal((), positions, result)
    return result
127
def _marginal(self,branch,vindices,dkt):
    """Return a dictionary mapping insts of C{variables} to
    their values (where non-zero)

    counts are conditional on the inst C{branch}

    The result is not returned but added to C{dkt}.

    @param branch: Partial instantiation (C{None} means summed out) of the
    leading variables
    @type branch: Tuple
    @param vindices: C{(index, number of values)} pairs for the variables
    still to be handled, as built by L{marginal}
    @type vindices: List
    @param dkt: Dictionary which receives the result (altered in place)
    @type dkt: Dict
    """


    # first check to see if this branch has been pruned
    # (or exists exactly)
    branch_length = len(branch)
    nvars = len(self._variables)
    if not vindices:
        # no variables left: pad the branch with None up to full length
        branch += tuple([None]*(nvars-branch_length))
        branch_length = len(branch)
        assert branch_length == nvars
    for i in range(branch_length+1):
        if branch[:i] in self._data:
            if not vindices:
                # previously over-extended
                branch = branch[:i]

            # make marginal dictionary
            # positions of the wanted variables relative to the end of branch
            vis = [v[0]-branch_length for v in vindices]
            tmp = {}
            for row in self._data[branch[:i]]:
                key = tuple(row[k] for k in vis)
                try:
                    tmp[key] += row[-1]
                except KeyError:
                    tmp[key] = row[-1]
            # NOTE(review): 'row' is the last row of the loop above, so this
            # line relies on the stored tuple being non-empty
            template = [None]*(len(row)-1)
            for key, val in tmp.items():
                for p, k in enumerate(vis):
                    template[k] = key[p]
                assert len(branch+tuple(template)) == nvars
                dkt[branch+tuple(template)] = val
            # NOTE(review): prints the result dictionary to standard output
            # on every hit; looks like leftover debugging output
            print dkt
            return


    vi, numvals = vindices[0]
    # summing out all variables up to vi
    branch += tuple([None]*(vi-len(branch)))


    # for each i have to artificially extend
    # new_branch to check it's there, due
    # to artificial nature of flattened tree
    dkts = []
    for i in range(numvals):
        tmp = {}
        new_branch = branch+(i,)
        check_branch = new_branch
        while len(check_branch) < nvars:
            if check_branch in self._data:
                # so not the mcv, so do normal recursive call
                self._marginal(new_branch,vindices[1:],tmp)
                break
            else:
                check_branch += (None,)
        else:
            # missing branch for i
            # NOTE(review): mcv is only ever assigned here; if every value
            # has a stored branch it is unbound below, and if several values
            # are missing the last one wins -- confirm this is intended
            mcv = i
            self._marginal(branch+(None,),vindices[1:],tmp)
        dkts.append(tmp)

    #correct dkts[mcv] by subtraction
    tmp_mcv = dkts[mcv]
    endbit = vi+1
    for i, tmp in enumerate(dkts):
        if i != mcv:
            for key, val in tmp.items():
                tmp_mcv[branch+(None,)+key[endbit:]] -= val
    # fix keys in dkts[mcv]
    # NOTE(review): the dictionary is altered while items() is being looped
    # over; this is safe only where items() returns a list (Python 2)
    for key, val in tmp_mcv.items():
        tmp_mcv[branch+(mcv,)+key[endbit:]] = val
        del tmp_mcv[key]
    # put stuff in main dictionary
    for tmp in dkts:
        dkt.update(tmp)
209
class Data(SubDomain):
    """Factors whose data is stored in a table in an sqlite database

    There is one row in the table for each non-zero value. These are meant to
    be used for factors with many variables, but not too many non-zero values:
    contingency tables for example. This object is a sensible choice when most of the
    information required from the data can be gleaned in a few passes over it.

    At present all databases are stored in RAM.

    @cvar db: The common database for objects of this class
    @type db: sqlite3.Connection object
    @cvar cursor: A cursor on C{db}, shared by all objects of this class
    @ivar table: The name of table containing the object's data. This is read-only
    (it is defined by a 'property') and is always equal to: 'table%d' % id(self)
    @type table: String
    @ivar n: Number of datapoints in the data
    @type n: Integer

    """
    # Prefer the external pysqlite2 binding (as before), but fall back on the
    # sqlite3 module of the standard library, which offers the same DB-API
    # interface, when pysqlite2 is not installed.
    try:
        from pysqlite2 import dbapi2 as sqlite
    except ImportError:
        try:
            import sqlite3 as sqlite
        except ImportError:
            sqlite = None
    if sqlite is None:
        # no sqlite binding at all: the class has neither 'db' nor 'cursor'
        print("Can't use Data class")
    else:
        db = sqlite.connect(':memory:')
        cursor = db.cursor()
def __init__(self, data=None, variables=(),
             domain=None, new_domain_variables=None,
             must_be_new=False, check=False, convert=False):
    """Initialise a L{Data} object

    @param data: If C{None}, then an empty L{Data} object is created.
    If a file-like object (anything with a C{read} attribute), then assumed
    to be connected to a CSV file in the format that L{IO.read_csv} can read.
    If a string, assumed to be the name of a CSV file in the format
    that L{IO.read_csv} can read (opened with C{gzip} if the name ends
    with '.gz').
    If a tuple, assumed to be like one returned by L{IO.read_csv}.

    Unless C{data} is None any values for C{variables} and
    C{new_domain_variables} are ignored.
    @type data: Tuple, String, file object or None
    @param variables: Variables in the data
    @type variables: Sequence
    @param new_domain_variables: A dictionary containing a mapping from any new
    variables to their values.
    @type new_domain_variables: Dict or None
    @param domain: A domain for the model.
    If None the internal default domain is used.
    @type domain: L{Variables.Domain} or None
    @param must_be_new: Whether domain variables in C{new_domain_variables} have
    to be new
    @type must_be_new: Boolean
    @param check: Whether to check that
    (1) C{variables} is of the right form, and (2) that each variable
    has an associated set of values and (3) that C{data} is the right size and type.
    @type check: Boolean
    @param convert: Accepted for compatibility; not used by this method.
    @type convert: Boolean
    @raise TypeError: If C{data} is not a file object, string, tuple or None.
    @raise VariableError: If C{check} is set and there is a variable in
    C{variables} which does not have
    associated values. Or If a variable in C{new_domain_variables}
    already exists with values different from
    its values in C{new_domain_variables};
    Or if C{must_be_new} is set and the variable already exists.
    """
    from gPy.IO import read_csv
    # a file opened here from a file name; it is closed again below, but a
    # file object handed in by the caller is left open for the caller
    opened = None
    records = None
    try:
        if data is not None:
            if isinstance(data, tuple):
                rawdata = data
            elif isinstance(data, str):
                if data.endswith('.gz'):
                    import gzip
                    opened = gzip.open(data)
                else:
                    opened = open(data)
                rawdata = read_csv(opened)
            elif hasattr(data, 'read'):
                rawdata = read_csv(data)
            else:
                raise TypeError('data should be file, string, tuple or None')
            new_domain_variables, variables, records = rawdata[1:]
        SubDomain.__init__(self, variables, domain, new_domain_variables, must_be_new, check)
        cols = ','.join([v + ' INT' for v in sorted(self._variables)] + ['value INT'])
        self.cursor.execute('CREATE TABLE %s (%s)' % (self.table, cols))
        if data is not None:
            self.populate(records)
    finally:
        # closed only after populate(), in case the records are read lazily
        if opened is not None:
            opened.close()
    self._n = self.total_count()
    self._cached_tables = []
305
def __del__(self):
    """Drop the table holding this object's data from the shared database

    Does nothing if there is no C{cursor} attribute, which is the case when
    no sqlite binding could be imported when the class was defined.
    """
    cursor = getattr(self, 'cursor', None)
    if cursor is None:
        return
    cursor.execute('DROP TABLE IF EXISTS %s' % self.table)
308
def __getstate__(self):
    """Return the state to pickle: the instance dictionary plus the data

    The rows of the object's table are added to a copy of the instance
    dictionary under the key '_data'; C{self} is not altered.

    @return: Copy of the instance dictionary with the extra '_data' entry
    @rtype: Dict
    """
    state = dict(self.__dict__)
    cur = self.cursor
    cur.execute('SELECT * FROM %s' % self.table)
    # add extra attribute for the data
    state['_data'] = cur.fetchall()
    return state
315
def __iter__(self):
    """Iterates over those joint instantiations of the data which have
    non-zero counts associated with them

    On each iteration a tuple of non-negative integers is returned. The final
    number is the count, the preceding numbers encode the joint instantiation.
    For example, (1,0,2,23) states that the instantiation (1,0,2) has occurred 23 times.
    (1,0,2) is the first variable with its instantiation 1 (its 2nd), the second with
    instantiation 0 (its 1st) and the third with instantiation 2 (its 3rd). Variables and
    values are ordered lexicographically.

    The query is only run when the first row is asked for.

    @return: Iterator over non-zero count instantiations
    @rtype: Iterator
    """
    cur = self.cursor
    cur.execute('SELECT * FROM %s' % self.table)
    for record in cur:
        yield record
333
def __setstate__(self, state):
    """Restore an object from the state produced by L{__getstate__}

    C{state} becomes the instance dictionary. A table is then created for
    the object, filled with the rows stored under '_data', and the '_data'
    entry is removed again.

    @param state: State as returned by L{__getstate__}
    @type state: Dict
    """
    self.__dict__ = state
    names = sorted(self._variables) + ['value']
    cols = ','.join(names)
    placeholders = ','.join(['?'] * len(names))
    self.cursor.execute('CREATE TABLE %s (%s)' % (self.table, cols))
    insert = 'INSERT INTO %s ( %s ) VALUES ( %s )' % (self.table, cols, placeholders)
    self.cursor.executemany(insert, self._data)
    del self._data
342
def h_score(self, precision=1.0):
    """Return the L{gPyC.lgh} score for entire data set

    @param precision: BDe-like precision
    @type precision: Float
    @return: L{gPyC.lgh} score
    @rtype: Float
    """
    from gPyC import lgh
    self.cursor.execute('SELECT value FROM %s' % self.table)
    data = [val[0] for val in self.cursor]
    # float() so that an integer precision is not truncated to zero by
    # integer division
    return lgh(data, float(precision)/self.table_size())
355
def marginal(self, variables):
    """Return the marginal dataset containing only C{variables}

    Returned Data object will have same domain as C{self}

    Does not alter C{self}
    @param variables: Variables in returned marginal table
    @type variables: Iterable, e.g. list, tuple, set
    @return: New marginal dataset
    @rtype: L{Data} object
    """
    # materialise once: 'variables' is used twice below and may be an
    # iterator which can only be consumed once
    variables = tuple(variables)
    marginal = Data(None, variables, domain=self)
    if variables:
        cols = ','.join(sorted(marginal._variables))
        sql = ('INSERT INTO %s SELECT %s, sum(value) FROM %s GROUP BY %s' %
               (marginal.table, cols, self.table, cols))
    else:
        sql = ('INSERT INTO %s SELECT sum(value) FROM %s' %
               (marginal.table, self.table))
    marginal.cursor.execute(sql)
    # the count stored by __init__ was computed on the still empty table
    marginal._n = marginal.total_count()
    return marginal
377 378
def conditional_entropy(self, x, y):
    """
    Return the conditional entropy H(x|y) for variable sets C{x} and C{y}
    using the empirical distribution given by the data

    Computed as the entropy of the union of C{x} and C{y} minus the entropy
    of C{y}, both obtained from L{entropy}.

    @param x: Variable set
    @type x: Iterable
    @param y: Variable set to condition on
    @type y: Iterable
    @return: The conditional entropy
    @rtype: Float
    """
    joint = frozenset(x) | frozenset(y)
    joint_entropy = self.entropy(joint)
    given_entropy = self.entropy(y)
    return joint_entropy - given_entropy
385 386 387 ## def get_rows(self,x): 388 ## """Equivalent to: 389 390 ## sql = 'SELECT sum(value) FROM %s GROUP BY %s' % (self.table,','.join(sorted(x))) 391 ## self.cursor.execute(sql) 392 ## return self.cursor.fetchall() 393 ## """ 394 ## xset = frozenset(x) 395 ## # smaller tables always before bigger ones 396 ## for cached_table in self._cached_tables: 397 ## if cached_table == xset: 398 ## self.cursor.execute('SELECT value FROM table%s' % '_'.join(cached_table)) 399 ## return self.cursor.fetchall() 400 ## if cached_table > xset: 401 ## cols = ','.join(sorted(xset)) 402 ## table_name = 'table%s' % '_'.join(xset) 403 ## sql.cursor.execute('CREATE TABLE %s (%s,value)' % (table_name cols)) 404 ## sql.cursor.execute('INSERT INTO %s SELECT %s, sum(value) FROM table%s GROUP BY %s' % 405 ## (table_name, cols, '_'.join(cached_table), cols)) 406 407 408 ## sql = 'SELECT sum(value) FROM table%s GROUP BY %s' % ('_'.join(cached_table),','.join(sorted(x))) 409 ## break 410 ## else: 411 ## sql = 'SELECT sum(value) FROM %s GROUP BY %s' % (self.table,','.join(sorted(x))) 412 ## l = len(xset) 413 ## for i, cached_table in enumerate(self._cached_tables): 414 ## if len(cached_table) >= l: 415 ## self._cached_tables.insert(i,xset) 416 ## break 417 ## else: 418 ## self._cached_tables.append(xset) 419 ## print 'CREATE TABLE table%s AS %s' % ('_'.join(xset),sql) 420 ## self.cursor.execute('CREATE TABLE table%s AS %s' % ('_'.join(xset),sql)) 421 ## self.cursor.execute('SELECT value FROM table%s' % '_'.join(xset)) 422 ## return self.cursor.fetchall() 423
def entropy(self, x):
    """Return the entropy of the marginal empirical distribution given by C{x}
    and the data

    If the object has an C{entropy_cache} attribute it is used as a
    dictionary, keyed by C{frozenset(x)}, to look up and to store results.

    @param x: Variable set
    @type x: Iterable
    @return: The entropy
    @rtype: Float
    """
    missing = object()
    cache = getattr(self, 'entropy_cache', missing)
    key = None
    if cache is not missing:
        key = frozenset(x)
        try:
            return cache[key]
        except KeyError:
            pass

    # empty set of variables has zero entropy
    if not x:
        return 0.0

    self.cursor.execute('SELECT sum(value) FROM %s GROUP BY %s'
                        % (self.table, ','.join(sorted(x))))
    rows = self.cursor.fetchall()
    if len(rows) == 1:
        # if only one instantiation then entropy is zero
        # this shortcut avoids numerical problems
        result = 0.0
    else:
        h = sum(count * log(count) for count in (val[0] for val in rows))
        n = self._n
        result = log(n) - h/n
    if cache is not missing:
        cache[key] = result
    return result
460
def _bic_search2(self, child, n, child_df, old_parents, further_parents, store, pa_lim, highest_llh):
    """Depth-first search over parent sets for C{child}, used by L{bic_search}

    Each variable in C{further_parents} is in turn left out of, and then
    added to, C{old_parents}. Once C{further_parents} is used up the BIC
    score of C{old_parents} is computed and pushed onto C{store}, unless
    C{store} already holds a proper subset with a higher score.

    @param child: The child variable
    @param n: Number of datapoints
    @param child_df: Number of values of C{child} minus one
    @param old_parents: The parents chosen so far
    @type old_parents: Frozenset
    @param further_parents: The variables still to be decided on
    @type further_parents: Tuple
    @param store: Heap of C{(-score, parent set)} pairs (altered in place)
    @type store: List
    @param pa_lim: Parent sets bigger than this are not considered
    @type pa_lim: Integer
    @param highest_llh: Bound on the log-likelihood part of any score
    @type highest_llh: Float
    """
    if len(old_parents) > pa_lim:
        return
    from heapq import heappush, heappop

    def beaten_by_subset(bound, parents):
        # Go through a copy of the heap from the best score downwards for as
        # long as the scores exceed 'bound', looking for a proper subset of
        # 'parents'. (negation of the score is stored in the heap)
        pending = store[:]
        neg_score, scored_parents = heappop(pending)
        while -neg_score > bound:
            if scored_parents < parents:
                return True
            neg_score, scored_parents = heappop(pending)
        return False

    if not further_parents:
        # nothing left to decide: score this parent set
        dim = child_df * self.table_size(old_parents)
        complexity_penalty = (log(n) * dim) / 2
        bic_score = (-n * self.conditional_entropy([child], old_parents)
                     - complexity_penalty)
        if not store:
            store.append((-bic_score, old_parents))
        elif not beaten_by_subset(bic_score, old_parents):
            heappush(store, (-bic_score, old_parents))
        return

    new_parent = further_parents[0]
    remaining = further_parents[1:]
    for new_parents in (old_parents, old_parents | frozenset([new_parent])):
        # for small parent sets don't bother trying to prune
        if len(new_parents) >= 2:
            # is it worth continuing the branch without/with the new parent?
            lowest_penalty = log(n) * (child_df * self.table_size(new_parents)) / 2
            best_possible = highest_llh - lowest_penalty
            if beaten_by_subset(best_possible, new_parents):
                continue
        self._bic_search2(child, n, child_df, new_parents, remaining, store, pa_lim, highest_llh)
507
def _bic_search(self, child, n, child_df, lower, bic_lower, upper, upper_bound, store):
    """
    Compute the BIC score for parent sets for C{child} which are proper
    supersets of C{lower} and subsets of the union of C{lower} and C{upper},
    adding each score which beats the one inherited from its subset to the
    dictionary C{store}.

    A message is printed for every parent set stored or pruned.

    @param child: The child variable
    @param n: Number of datapoints
    @param child_df: Number of values of C{child} minus one
    @param lower: The parent set to extend
    @type lower: Frozenset
    @param bic_lower: The score which extensions of C{lower} have to beat
    @type bic_lower: Float
    @param upper: The variables which may be added to C{lower}
    @type upper: Frozenset
    @param upper_bound: Bound on the log-likelihood (modulo an additive
    constant) of any possible parent set
    @type upper_bound: Float
    @param store: Maps parent sets to scores (altered in place)
    @type store: Dict
    """
    # compute open list
    open_list = []
    for v in upper:
        parentset = lower | frozenset([v])
        dim = child_df * self.table_size(parentset)
        complexity_penalty = (log(n) * dim) / 2
        if not (upper_bound - complexity_penalty > bic_lower):
            # even the best possible score cannot beat bic_lower
            print(' '.join(['Pruned', str(parentset)]))
            continue
        bic_score = (-n * self.conditional_entropy([child], parentset)
                     - complexity_penalty)
        if bic_score > bic_lower:
            store[parentset] = bic_score
            print(' '.join(['Stored', str(parentset), str(bic_score)]))
            to_beat = bic_score
        else:
            to_beat = bic_lower
        open_list.append((to_beat, parentset, v))
    open_list.sort()
    for bic_score, parentset, v in open_list:
        upper = upper - frozenset([v])
        self._bic_search(child, n, child_df, parentset, bic_score, upper, upper_bound, store)
538 539 540
def bic_search(self, child, pa_lim):
    """Branch and bound search for all parent sets for C{child}
    which do not have a higher scoring subset

    TODO: use a random graph as an input

    @param child: The child variable
    @param pa_lim: Limit on the size of parent sets, passed on to
    L{_bic_search2}
    @type pa_lim: Integer
    @return: Maps each parent set kept by L{_bic_search2} to its score
    @rtype: Dict
    """
    n = self._n
    child_df = self._numvals[child] - 1
    candidates = tuple(self._variables - frozenset([child]))
    highest_llh = -n * self.conditional_entropy([child], candidates)
    heap = []
    self._bic_search2(child, n, child_df, frozenset(), candidates, heap, pa_lim, highest_llh)
    # the heap holds (negated score, parent set) pairs
    scores = {}
    for entry in heap:
        scores[entry[1]] = -entry[0]
    return scores
555 556 ## lower = frozenset() 557 ## child_singleton = frozenset([child]) 558 ## n = self._n 559 ## child_df = self._numvals[child] - 1 560 ## bic_lower = -n * self.entropy(child_singleton)-(0.5 * log(n) * child_df) 561 ## upper = frozenset(self._variables - child_singleton) 562 ## upper_bound = -n * self.conditional_entropy(child_singleton,upper) 563 ## print 'Upper bound', upper_bound 564 ## store = {lower:bic_lower} 565 ## self._bic_search(child,n,child_df,lower,bic_lower,upper,upper_bound,store) 566 ## return store 567 568 569 ## from gPy.Utils import subsets_ascending 570 ## potential_parents = self._variables - frozenset([child]) 571 ## n = self._n 572 ## upper_bound = -n * self.conditional_entropy([child],potential_parents) 573 ## #print 'Upper bound on log-likelihood component is', upper_bound 574 ## child_df = self.numvals(child) - 1 575 ## pruned = set() 576 ## bic_scores = {frozenset():(-n * self.entropy([child]) - (0.5 * log(n) * child_df))} 577 578 ## for parentset in subsets_ascending(potential_parents): 579 ## print 'Considering', parentset 580 581 ## # already done empty set 582 ## if not parentset: 583 ## continue 584 585 ## # if a superset of an already pruned parentset 586 ## # then don't consider 587 ## ok = True 588 ## for pruned_parentset in pruned: 589 ## if parentset > pruned_parentset: 590 ## ok = False 591 ## break 592 ## if not ok: 593 ## print 'Pruned' 594 ## continue 595 596 ## # 'size' method returns number of joint instantiations 597 ## # perhaps it needs a better name 598 ## dim = child_df * self.size(parentset) 599 ## best_possible = upper_bound - (0.5 * log(n) * dim) 600 ## best_subset_score_parentset = None 601 ## for done_parentset, score in bic_scores.items(): 602 ## if done_parentset < parentset: 603 ## if best_possible < score: 604 ## pruned.add(parentset) 605 ## break 606 ## elif best_subset_score_parentset is None or best_subset_score_parentset < score: 607 ## best_subset_score_parentset = score 608 ## else: 609 ## 
bic_score = (-n * self.conditional_entropy([child],parentset) 610 ## - (0.5 * log(n) * dim)) 611 ## if bic_score > best_subset_score_parentset: 612 ## bic_scores[parentset] = bic_score 613 ## print parentset, bic_score 614 ## return bic_scores 615
def loglikelihood(self, adg):
    """The log-likelihood of C{adg} with MLE parameters (up to an
    additive constant)

    The missing constant is the log of the multinomial
    coefficient which is the same for all adgs.

    @param adg: The graph; its C{vertices} and C{parents} methods are used
    @return: C{self.n} times the sum over the children of the entropy of
    the parents minus the entropy of the family
    @rtype: Float
    """
    # dkt maps a variable set to the multiplier of its entropy: +1 each time
    # it is a parent set, -1 each time it is a family. A set occurring as
    # both cancels out and its entropy need not be computed at all.
    dkt = {}
    for child in adg.vertices():
        parents = frozenset(adg.parents(child))
        family = parents | frozenset([child])
        dkt[parents] = dkt.get(parents, 0) + 1
        dkt[family] = dkt.get(family, 0) - 1

    score = 0.0
    for variableset, multiplier in dkt.items():
        if multiplier != 0:
            score += multiplier * self.entropy(variableset)
    return self.n * score
643
def bic_complexity_penalty(self, adg):
    """Return the BIC complexity penalty for C{adg}

    This is C{log(n)/2} times the number of free parameters, the same
    penalty per family as is used in L{_bic_search2} and L{_bic_search}.

    @param adg: The graph; its C{vertices} and C{parents} methods are used
    @return: The penalty
    @rtype: Float
    """
    dim = sum((self.numvals(child) - 1) * self.table_size(adg.parents(child))
              for child in adg.vertices())
    return (log(self.n) / 2.0) * dim
647
def mutual_information(self, x, y):
    """Return the mutual information between the variable sets
    C{x} and C{y} in the empirical distribution
    determined by the data

    Computed as H(x) + H(y) - H(x union y) using L{entropy}; zero is
    returned at once if either set is empty. No check is made that C{x}
    and C{y} are disjoint.

    @param x: Variable set
    @type x: Iterable, frozenset most efficient
    @param y: Variable set
    @type y: Iterable, frozenset most efficient
    @return: The mutual information
    @rtype: Float
    """
    if not (x and y):
        return 0.0
    entropy_x = self.entropy(x)
    entropy_y = self.entropy(y)
    entropy_joint = self.entropy(frozenset(x) | frozenset(y))
    return entropy_x + entropy_y - entropy_joint
665 666
def total_count(self):
    """Return the number of datapoints in the data

    This is the sum of the 'value' column of the object's table.

    @return: The number of datapoints in the data
    @rtype: Integer
    """
    cur = self.cursor
    cur.execute('SELECT sum(value) FROM %s' % self.table)
    row = cur.fetchone()
    return row[0]
675 676
def populate(self, records, variables=None):
    """Simply inserts the records into the database

    Each record is a tuple of integers. Each integer, except the last, corresponds to a value.
    The last value is the count. Every record is inserted as a row of its
    own; afterwards the stored number of datapoints is recomputed with
    L{total_count}.

    Assumes lexicographic order of variables if C{variables} is None, otherwise the
    order given by C{variables}.

    @param records: The records to insert
    @type records: Iterable of tuples
    @param variables: The order of the variables in each record
    @type variables: Sequence or None
    @raise ValueError: If C{variables} is given but is not the same set as
    the object's variables
    """
    if variables is None:
        variables = sorted(self._variables)
    elif frozenset(variables) != self._variables:
        raise ValueError('Got sent these variables: %s, expected these: %s'
                         % (variables, self._variables))
    column_names = list(variables)
    column_names.append('value')
    placeholders = ['?'] * (len(self._variables) + 1)
    sql = 'INSERT INTO %s ( %s ) VALUES ( %s )' % (
        self.table, ','.join(column_names), ','.join(placeholders))
    self.cursor.executemany(sql, records)
    self._n = self.total_count()
698
def qhs(self):
    """Return, for each threshold, how many cells have a bigger count

    @return: List C{hs} where C{hs[h]} is the number of rows of the table
    whose value is greater than C{h}. Its length is the biggest value in
    the table; it is empty if the table has no rows.
    @rtype: List
    """
    self.cursor.execute('SELECT value FROM %s' % self.table)
    data = sorted(val[0] for val in self.cursor.fetchall())
    if not data:
        # no cells at all, so no cell exceeds any threshold
        return []
    m = len(data)
    hs = []
    for k, count in enumerate(data):
        # data is sorted, so the m-k cells from position k onwards are the
        # ones exceeding every threshold from len(hs) up to count-1
        hs.extend([m - k] * (count - len(hs)))
    return hs
710
def ub(self, qpa, alpha, ri):
    """The upper bound self provides on the score of a smaller
    parent set

    @param qpa: Size of contingency table for smaller parent set
    @param alpha: Effective sample size
    @param ri: Number of values of the child
    @return: The bound, computed from the list returned by L{qhs}
    @rtype: Float
    """
    cells_above = self.qhs()
    a = alpha / float(qpa)
    total = sum((qh * log((h + a/ri) / (h + a))
                 for h, qh in enumerate(cells_above)), 0.0)
    scale = float(qpa) / self.table_size()
    return scale * total
726 727 728
def qh(self, h=0):
    """Return the number of instantiations (ie cells) having a value
    greater than C{h}

    @param h: Threshold
    @type h: Integer
    @return: Number of rows of the table with a value greater than C{h}
    @rtype: Integer
    """
    query = 'SELECT value FROM %s WHERE value > %d' % (self.table, h)
    self.cursor.execute(query)
    matching = self.cursor.fetchall()
    return len(matching)
738
def make_family_scores_naively(self,pa_size_lim=4,precision=10.0,batch_size=65536):
    """
    Make scores for all parent sets for all variables where (1) the size
    of the parent set is at most C{pa_size_lim}. No pruning!

    B{Disabled}: the first statement raises C{RuntimeError}, so nothing
    after it is ever executed. The rest of the body is kept for reference.

    @param pa_size_lim: Limit on the size of parent sets
    @type pa_size_lim: Integer
    @param precision: Precision handed (divided by table size) to C{lgh}
    @type precision: Float
    @param batch_size: Batch size handed to C{subsetn_batch}; capped at 65536
    @type batch_size: Integer
    @raise RuntimeError: Always
    """

    raise RuntimeError("This method is buggy, don't use!")

    # ---- everything below is unreachable ----

    from gPy.Parameters import Factor
    from gPy.Utils import subsetn_batch
    from array import array
    from gPyC import lgh

    precision = float(precision)
    batch_size = min(batch_size,65536)

    marginal_size_lim = pa_size_lim+1
    cursor = self.db.cursor()

    h_score = {}
    pa_scores = {} # where it all ends up
    for v in self._variables:
        pa_scores[v] = {}

    if marginal_size_lim > len(self._variables)/2:
        # NOTE(review): 'pa_size' is not defined anywhere in this method;
        # presumably 'pa_size_lim' was meant
        raise ValueError('pa_size %d too big (see subseteqn_batch docs)' % pa_size)

    faset_batch_all = []
    def tmpfn(x): return len(x) == marginal_size_lim

    variable_indices = range(len(self._variables))
    last_variable_index = variable_indices[-1] # to cope with 'missing' subsets

    for faset_batch_all_tmp in subsetn_batch(variable_indices,marginal_size_lim,batch_size):

        # associate each variable set of size marginal_size_lim with 'its' subsets
        # subsets are attached to only one maximal sized set

        faset_batch_all.extend(faset_batch_all_tmp)

        #print len(faset_batch_all)

        faset_batch = filter(tmpfn,faset_batch_all)

        #print 'Done with subsets', len(faset_batch)
        # initialisation

        # with array('H') can store up to 65535

        fasets_i_including = [array('H') for var in self._variables]
        mults = [{} for var in self._variables]
        current = ['0'] * len(self._variables)
        num_fasets = len(faset_batch)
        count = [None] * num_fasets

        indx = array('I',[0] * num_fasets)

        # compute stuff for each family set = marginal

        for faset_i, faset in enumerate(faset_batch):
            size = 1
            for j in sorted(faset,reverse=True): # the jth variable, j is an int
                fasets_i_including[j].append(faset_i)
                mults[j][faset_i] = size
                size *= self._numvals[self._sortedvariables[j]]
            # dataset always smaller than 65535
            count[faset_i] = array('H',[0] * size)

        # processing data ...

        #print 'About to process data'

        cursor.execute('SELECT * FROM %s' % self.table)
        # NOTE(review): rows are unpacked as (comma separated string, count)
        # pairs, which does not match the one-column-per-variable table
        # created in __init__ -- possibly one of the bugs referred to above
        for val, valcount in cursor.fetchall():

            # compute indices (sparse matrix-vector computation)

            previous = current
            current = val.split(',')
            for j, inst_j in enumerate(current):
                if inst_j == previous[j]:
                    continue
                mult_j = mults[j]
                diff_j = int(inst_j) - int(previous[j])
                for faset_i in fasets_i_including[j]:
                    indx[faset_i] += (diff_j * mult_j[faset_i])

            # increment counts

            for faset_i, count_faset in enumerate(count):
                count_faset[indx[faset_i]] += valcount

        # compute scores

        #print 'About to compute scores'


        # faset_batch contains subsets of marginal_size_lim
        # faset_batch_all contains all subsets *up to* marginal_size_lim

        k = 0
        for faset_i, faset in enumerate(faset_batch):
            subsets = []
            while True:
                subset = faset_batch_all[k]
                k += 1
                subsets.append(subset)
                if subset == faset:
                    break
            # to cope with some non-maximal size subsets coming just after
            # their containing maximal size subset
            if subsets[-1][-1] == last_variable_index:
                biggest = frozenset(subsets[-1])
                while k < len(faset_batch_all):
                    next_subset = faset_batch_all[k]
                    if not biggest.issuperset(next_subset):
                        break
                    subsets.append(next_subset)
                    k += 1
            #print subsets
            factor_data = list(count[faset_i])
            factor_variables = [self._sortedvariables[j] for j in faset]
            factor = Factor(factor_variables,factor_data,self)
            for marginal in subsets: # marginal a set of variable indices
                marginal_variables = frozenset([self._sortedvariables[j] for j in marginal])
                marginal_data = (factor._data_marginalise(
                    factor_data,
                    factor_variables,
                    factor._variables - marginal_variables))
                #print '\n^^^^^^^^^^^'
                #print marginal_data
                h_score[marginal_variables] = lgh(marginal_data,precision/len(marginal_data))
                #print marginal_variables, h_score[marginal_variables]
        faset_batch_all = faset_batch_all[k:]

    # score of a parent set = score of the family minus score of the parents
    for family, family_score in h_score.items():
        for child in family:
            parents = family-frozenset([child])
            pa_scores[child][parents] = family_score - h_score[parents]

    # for testing
    # for child, parentdict in pa_scores.items():
    #     for parent, score in parentdict.items():
    #         print child, parent, score,
    #         factor = self.makeFactor(parent | set([child]))
    #         print CPT(factor,child).bdeu_score(precision)

    return pa_scores
887
def makeFactorsn(self, n, block=1000000):
    """Yield counts for all marginals with C{n} variables in blocks of C{block}

    Marginals are ordered according to how they are generated by the generator
    L{Utils.subseteqn}.

    Each block is counted by one call of L{_countsfromdata}. Nothing is
    yielded for an empty final block, so the number of marginals may be an
    exact multiple of C{block} (or zero).

    @param n: The number of variables in the marginals
    @type n: Int
    @param block: The number of marginals counted together
    @type block: Int
    @return: Generator of lists C{counts} where C{counts[i][idx]} is the count
    for the C{idx}th instantiation of the C{i}th marginal of the block
    @rtype: Generator
    """

    from gPy.Utils import subseteqn

    nvars = len(self._variables)

    # marginals_including[j] contains (indexes within the block of) all
    # marginals containing the jth variable; mults[j] their multipliers
    marginals_including = [[] for var in range(nvars)]
    mults = [{} for var in range(nvars)]
    count = []
    for marginal in subseteqn(range(nvars), n):
        modi = len(count)
        size = 1
        for j in sorted(marginal, reverse=True): # the jth variable, j is an int
            marginals_including[j].append(modi)
            mults[j][modi] = size
            size *= self._numvals[self._sortedvariables[j]]
        count.append([0] * size) # so 'count' in synch with 'marginals'
        if len(count) == block:
            yield self._countsfromdata(count, mults, marginals_including)
            marginals_including = [[] for var in range(nvars)]
            mults = [{} for var in range(nvars)]
            count = []
    # for the tail end
    if count:
        yield self._countsfromdata(count, mults, marginals_including)
923
def _countsfromdata(self, count, mults, marginals_including):
    """Add the counts in the object's table to the marginal tables C{count}

    Each row of the table is a joint instantiation, one integer for each
    variable in lexicographic order, followed by its count.

    @param count: C{count[i]} is the table of counts of the C{i}th marginal
    (altered in place)
    @type count: List of lists
    @param mults: C{mults[j][i]} is the multiplier of the C{j}th variable in
    the C{i}th marginal
    @type mults: List of dicts
    @param marginals_including: C{marginals_including[j]} contains the
    indexes of the marginals containing the C{j}th variable
    @type marginals_including: List of lists
    @return: C{count}
    @rtype: List of lists
    """
    indx_template = [0] * len(count)
    cursor = self.cursor
    cursor.execute('SELECT * FROM %s' % self.table)
    for row in cursor.fetchall():
        valcount = row[-1]
        # indx[i] becomes the position of this row's instantiation in the
        # ith marginal
        indx = indx_template[:]
        for j, inst_j in enumerate(row[:-1]):
            mult_j = mults[j]
            for i in marginals_including[j]:
                indx[i] += (inst_j * mult_j[i])
        for i, idx in enumerate(indx):
            count[i][idx] += valcount
    return count
937
def makeFactorsn_old(self, n):
    """Return counts for all marginals with C{n} variables

    Marginals are ordered according to how they are generated by the generator
    L{Utils.subseteqn}.

    Each row of the table is taken to be a joint instantiation, one integer
    for each variable in lexicographic order, followed by its count.

    @param n: The number of variables in the marginals
    @type n: Int
    @return: C{counts} where C{counts[i][idx]} is the count for the C{idx}th instantiation of
    the C{i}th marginal
    @rtype: List
    """

    # each marginal represented by an ordered tuple of *indices*
    # 'id' of each marginal is just its index in the list 'count'

    # marginals_including[j] contains (indexes of) all marginals containing jth variable

    # mults[j] contains multiplers for the jth variable for each
    # marginal containing it.

    # count[i] will contain the counts for marginal i

    from gPy.Utils import subseteqn

    nvars = len(self._variables)
    marginals_including = [[] for var in range(nvars)]
    mults = [{} for var in range(nvars)]

    count = []
    for i, marginal in enumerate(subseteqn(range(nvars), n)):
        size = 1
        for j in sorted(marginal, reverse=True): # the jth variable, j is an int
            marginals_including[j].append(i)
            mults[j][i] = size
            size *= self._numvals[self._sortedvariables[j]]
        count.append([0] * size) # so 'count' in synch with 'marginals'

    indx_template = [0] * len(count)

    cursor = self.cursor
    cursor.execute('SELECT * FROM %s' % self.table)
    for row in cursor.fetchall():
        valcount = row[-1]
        indx = indx_template[:]
        for (inst_j, mult_j, marginals_including_j) in zip(
                row[:-1], mults, marginals_including):
            for i in marginals_including_j:
                indx[i] += (inst_j * mult_j[i])
        for i, idx in enumerate(indx):
            count[i][idx] += valcount
    return count
996
def family_score(self, child, parents, precision=1.0):
    """Return the BDeu score of C{child} given C{parents}

    The score is read off the (unforced, unchecked) CPT that
    C{self.makeCPT} builds for the family.

    @param child: The child variable
    @param parents: The parents of C{child}
    @type parents: Iterable
    @param precision: Passed unchanged to the CPT's C{bdeu_score} method
    @return: Whatever C{bdeu_score} returns for the family's CPT
    """
    family_cpt = self.makeCPT(child, parents, force_cpt=False, check=False)
    return family_cpt.bdeu_score(precision)
999 1000
def score_adg(self,adg,precision=1.0):
    """Get Bdeu score for an adg

    The result is the sum, over every vertex of C{adg}, of the BDeu score
    of the CPT built from the marginal counts of the vertex and its parents.

    @param adg: Graph object providing C{vertices()} and C{parents(vertex)}
    @param precision: Passed unchanged to each CPT's C{bdeu_score} method
    @return: The summed score (C{0.0} for a graph without vertices)
    @rtype: Float
    """
    family_scores = (
        CPT(self.makeFactor(adg.parents(child) | set([child])),
            child).bdeu_score(precision)
        for child in adg.vertices())
    # start from 0.0 so that the additions happen in the same order and
    # with the same float start value as an explicit accumulation loop
    return sum(family_scores, 0.0)
1009
def h_scores(self,precision=1.0,textfun=str):
    """Print one score line for every subset of C{self}'s variables

    Each line holds the size of the subset, the value returned by
    C{gPyC.lgh} for the subset's marginal counts, and the subset itself
    rendered with C{textfun}, separated by single spaces.

    @param precision: Divided by the number of cells of each marginal
    before being handed to C{lgh}
    @param textfun: Function turning a variable subset into text
    @type textfun: Function
    """
    from gPy.Utils import all_subsets
    from gPyC import lgh
    for varset in all_subsets(list(self._variables)):
        cells = self.makeFactor(varset)._data
        value = lgh(cells, precision/len(cells))
        # a single formatted string prints identically whether print is a
        # statement or a function
        print('%s %s %s' % (len(varset), value, textfun(varset)))
1016
def makeCPT(self, child, parents, force_cpt=False, check=False, prior=0):
    """Return a CPT for C{child} given C{parents} built from the data

    @param child: The child variable
    @param parents: The parents of C{child}
    @type parents: Iterable
    @param force_cpt: Passed to L{CPT} as C{cpt_force}
    @param check: Passed to L{CPT} as C{cpt_check}
    @param prior: the Dirichlet prior parameter (the same parameter value
    is used for all instances!) It is simply added to the family's
    marginal factor. Note that this is a B{different} prior from the one used
    by the BDeu score; in practice it seems to be adequate for parameter
    estimation. If prior is zero, then the parameters are the maximum
    likelihood estimation solutions.
    @return: The conditional probability table
    @rtype: L{CPT} object
    """
    family = set(parents)
    family.add(child)
    family_counts = self.makeFactor(family)
    return CPT(family_counts+prior, child, cpt_check=check, cpt_force=force_cpt)
1029 1030
def makeFactor(self,variables=None):
    """Simple way to get a factor

    @param variables: The variables in the required factor.
    If C{None}, then all of C{self}'s variables are used: which
    could produce a very large object!
    @type variables: Iterable
    @return: Marginal table
    @rtype: L{Factor} object
    """
    from gPy.Parameters import Factor

    if variables is None:
        # Without this substitution the documented default reached
        # sorted(None) below and raised TypeError.
        # assumes self._variables holds the names of all of self's
        # variables -- TODO confirm
        variables = self._variables
        marginal = self
    else:
        marginal = self.marginal(variables)

    # Run the query and fetch the rows on the *same* cursor; previously the
    # rows were fetched from self.cursor although the query had been
    # executed on marginal.cursor.
    cursor = marginal.cursor
    cursor.execute('SELECT * FROM %s' % marginal.table)

    # compute step sizes for each variable: the last variable (in sorted
    # order) has step 1, each earlier one the product of the numbers of
    # values of the variables after it
    st = 1
    step = []
    for v in sorted(variables,reverse=True):
        step.append(st)
        st *= self._numvals[v]
    step.reverse()

    data = [0] * marginal.table_size()
    for row in cursor.fetchall():
        # all but the last column are value indices, the last is the count
        i = 0
        for j, val in enumerate(row[:-1]):
            i += val*step[j]
        data[i] = row[-1]
    return Factor(variables,data,domain=self)
@property
def n(self):
    """The number of datapoints currently stored

    Read-only view of the C{_n} attribute.

    @return: The current number of datapoints stored
    @rtype: Integer
    """
    datapoints = self._n
    return datapoints
@property
def table(self):
    """The name of the table storing C{self}'s data

    The name is derived from C{id(self)}, so it is specific to this
    object.

    @return: The name of the table storing C{self}'s data
    @rtype: String
    """
    identity = id(self)
    return 'table%d' % (identity,)
1083 1084 1085 __getitem__ = makeFactor
1086
1087 # def makeFactors(self,variablesets): 1088 # """Simple way to get factors""" 1089 # mults = [] 1090 # indicess = [] 1091 # datas = [] 1092 # for variables in variablesets: 1093 # size = 1 1094 # mult = {} 1095 # mults.append(mult) 1096 # indices = [] 1097 # indicess.append(indices) 1098 # for v in sorted(variables,reverse=True): # the jth variable, j is an int 1099 # mult[self._dkt[v]] = size 1100 # indices.append(self._dkt[v]) 1101 # size *= self._numvals[v] 1102 # datas.append([0] * size) 1103 1104 # cursor = self._data.cursor() 1105 # cursor.execute('SELECT * FROM %s' % self.table) 1106 # for val, valcount in cursor.fetchall(): 1107 # inst = [int(x) for x in val.split(',')] 1108 # for i, variables in enumerate(variablesets): 1109 # indx = 0 1110 # for j in indicess[i]: 1111 # indx += inst[j]*mults[i][j] 1112 # datas[i][indx] += valcount 1113 1114 # factors = [] 1115 # for i, variables in enumerate(variablesets): 1116 # factors.append(Factor(variables,datas[i],domain=self)) 1117 # return factors 1118 1119 1120 1121 1122 # def score_all(self,n,precision=1.0): 1123 # """Compute the 'lgamma' score associated with all marginal tables involving 1124 # C{n} variables. 1125 1126 # Needs a better name! 1127 # """ 1128 # from gPy.Utils import subseteqn 1129 # from array import array 1130 # indices_s = [x for x in subseteqn(range(len(self._variables)),n)] 1131 # tmp_numvals = [] 1132 # for v in sorted(self._variables): 1133 # tmp_numvals.append(self._numvals[v]) 1134 # dkts = [] 1135 # intervals = [] 1136 # for indices in indices_s: 1137 # dkts.append([0]*reduce(operator.mul,[tmp_numvals[j] for j in indices],1)) 1138 # intervals.append([1]*10) ## WRONG!! 
1139 # cur = self._con.cursor() 1140 # cur.execute('SELECT * FROM %s' % (self._table)) 1141 # for row in cur.fetchall(): 1142 # val = row[-1] 1143 # for i, indices in enumerate(indices_s): 1144 # indx = 0 1145 # for k, j in enumerate(indices): 1146 # indx += row[j]*intervals[i][k] 1147 # dkts[i][indx] += val 1148 1149 # # perhaps make dkt a list and compute an integer index from [row[j] for j in indices] 1150 # #for j in indices: 1151 # # dkt = dkt[row[j]] 1152 # #dkt += val 1153 # #key = tuple([row[j] for j in indices]) 1154 # #try: 1155 # # dkt[key] += val 1156 # #except KeyError: 1157 # # dkt[key] = val 1158 1159 # scores = [0.0] * len(dkts) 1160 # for i, dkt in enumerate(dkts): 1161 # scores[i] = gPyC.lgh(dkt, 1162 # reduce(operator.mul,[tmp_numvals[j] for j in indices_s[i]],1)) 1163 # return scores, indices_s 1164 1165 # def score(self,variables,precision=1.0): 1166 # """Compute the 'lgamma' score associated with the marginal table involving 1167 # variables C{variables}. 1168 1169 # Needs a better name! 1170 1171 # """ 1172 # cur = self._con.cursor() 1173 # if variables: 1174 # cur.execute('SELECT sum(gpyval) FROM %s GROUP BY %s' % ( 1175 # self._table, ','.join(variables))) 1176 # else: 1177 # cur.execute('SELECT sum(gpyval) FROM %s' % self._table) 1178 # return gPyC.lgh( 1179 # [x[0] for x in cur.fetchall()], 1180 # precision/reduce(operator.mul,[self._numvals[v] for v in variables],1)) 1181 1182 # def bdeu_score(self,child,parents,precision=1.0): 1183 # """parents is a list""" 1184 # denom = self.score(parents,precision) 1185 # parents.append(child) 1186 # numer = self.score(parents,precision) 1187 # return numer - denom 1188 1189 1190 1191 -class CompactFactor(SubDomain):
1192 """Factors whose data is not explicity represented (because there's 1193 too much of it). 1194 1195 Generally used to store datasets. Implemented using ADTrees 1196 @ivar _tree: ADTree holding the data 1197 @type _tree: L{_ADTree} object 1198 @ivar _records: Each record is a tuple of integers. For each tuple the M{j}th element 1199 is the index of the value of the M{j}th variable found in that record. 1200 The final integer is a count of how often the record appeared. 1201 @type _records: List 1202 @ivar _var_index: TODO 1203 """
def __init__(self,data,domain=None,rmin=200):
    """Initialise a L{CompactFactor} from C{data}

    @param data: Either the raw data itself (a tuple), the name of a CSV
    file (names ending in C{.gz} are opened with C{gzip}), or an open
    file-like object. Files opened here are closed again once read.
    See L{Data} documentation for the raw data format.
    @type data: Tuple, string or file-like object
    @param domain: A domain for C{self}'s variables. If None
    the internal default domain is used.
    @type domain: Dictionary or None
    @param rmin: Controls the space/time tradeoff for using L{CompactFactor} objects.
    Subsets of records smaller than C{rmin} are stored as a tuple of (pointers to) the records
    concerned.
    @type rmin: int
    @raise KeyError: if the data has a variable that has not been previously declared
    @raise TypeError: if C{data} is not one of the accepted types
    @raise ValueError: if C{data} is C{None}, or if the raw data does not
    have exactly 3 elements after its first one
    """
    from gPy.IO import read_csv
    if data is None:
        raise ValueError('Need some data actually!')

    if isinstance(data, tuple):
        rawdata = data
    elif isinstance(data, str):
        if data.endswith('.gz'):
            import gzip
            source = gzip.open(data)
        else:
            source = open(data)
        # the file was opened here, so it is closed here as well -- even
        # if parsing fails
        try:
            rawdata = read_csv(source)
        finally:
            source.close()
    elif hasattr(data, 'read'):
        # any open file-like object; the caller keeps ownership of it
        rawdata = read_csv(data)
    else:
        raise TypeError('data should be file, string, tuple or None')

    # the first element of the raw data is not used here
    new_domain_variables, variables, records = rawdata[1:]
    SubDomain.__init__(self,variables,domain,new_domain_variables,check=True)

    # position of each variable within a record, and the number of values
    # of each variable in that same order
    var_index = {}
    numvals = []
    for i, variable in enumerate(variables):
        var_index[variable] = i
        numvals.append(self._numvals[variable])

    # the last field of every record is its count
    n = 0
    for record in records:
        n += record[-1]

    self._tree = _ADTree(numvals,records,0,rmin)
    self._records = records
    self._var_index = var_index
    self._n = n
1252
def bic_search(self,child):
    """Branch and bound search for all parent sets for C{child}
    which do not have a higher scoring subset

    TODO: use a random graph as an input

    @param child: The variable whose parent sets are sought
    @return: Maps each retained parent set to its score
    @rtype: Dictionary
    """
    found = []
    candidates = tuple(self._variables - frozenset([child]))
    self._bic_search2(
        child, self._n, self._numvals[child] - 1,
        frozenset(), candidates, found)
    # the search stores negated scores; undo the negation for the caller
    scores = {}
    for negated_score, parent_set in found:
        scores[parent_set] = -negated_score
    return scores
1266
def _bic_search2(self,child,n,child_df,old_parents,further_parents,store):
    """Recursive worker for L{bic_search}

    Each call either decides about the next candidate parent (when
    C{further_parents} is not empty) or scores the finished parent set
    C{old_parents} and possibly adds it to C{store}.

    @param child: The variable whose parent sets are being scored
    @param n: The number of datapoints
    @param child_df: The number of values of C{child} minus one
    @param old_parents: The parents chosen so far
    @type old_parents: Frozenset
    @param further_parents: Candidate parents not yet decided upon
    @type further_parents: Tuple
    @param store: Heap of C{(-score, parent set)} pairs, extended in place
    @type store: List
    """
    from heapq import heappush, heappop
    if further_parents:
        new_parent = further_parents[0]
        further_parents = further_parents[1:]

        # first branch: leave new_parent out; second branch: include it
        for new_parents in old_parents, old_parents|frozenset([new_parent]):

            # for small parent sets don't bother trying to prune
            if len(new_parents) < 2:
                self._bic_search2(child,n,child_df,new_parents,further_parents,store)
            else:
                # is it worth continuing the branch without/with the new parent?
                # note that a new pruning opportunity may arise even if new_parents == old_parents
                # since the reduction in further_parents may decrease highest_llh enough
                highest_llh = -n * self.conditional_entropy([child],new_parents.union(further_parents))
                lowest_penalty = log(n) * (child_df * self.table_size(new_parents)) / 2
                best_possible = highest_llh - lowest_penalty
                # work on a copy so that popping does not disturb the store
                tmp = store[:]
                # negation of the score is stored in the heap
                score, scored_parents = heappop(tmp)
                # visit stored sets from best score downwards while they
                # beat the bound for this branch
                # NOTE(review): relies on the loop ending before tmp is
                # exhausted (heappop on an empty list raises IndexError);
                # presumably the store always holds a set that stops it --
                # TODO confirm
                while -score > best_possible:
                    if scored_parents < new_parents:
                        # a proper subset already beats the bound: prune
                        # (break skips the else clause below)
                        break
                    score, scored_parents = heappop(tmp)
                else:
                    self._bic_search2(child,n,child_df,new_parents,further_parents,store)
    else:
        # no candidates left: old_parents is a complete parent set
        dim = child_df * self.table_size(old_parents)
        complexity_penalty = (log(n) * dim) / 2
        bic_score = (-n * self.conditional_entropy([child],old_parents)
                     - complexity_penalty)
        if not store:
            # the very first parent set is always kept
            store.append((-bic_score,old_parents))
            return
        tmp = store[:]
        # negation of the score is stored in the heap
        score, scored_parents = heappop(tmp)
        # keep old_parents only if no stored proper subset scores higher
        while -score > bic_score:
            if scored_parents < old_parents:
                break
            score, scored_parents = heappop(tmp)
        else:
            heappush(store,(-bic_score,old_parents))
1311
def conditional_entropy(self,x,y):
    """
    Return the conditional entropy H(x|y) for variable sets C{x} and C{y}
    using the empirical distribution given by the data

    Computed as the entropy of the union of C{x} and C{y} minus the
    entropy of C{y}.

    @param x: The variables whose conditional entropy is sought
    @type x: Iterable
    @param y: The variables conditioned on
    @type y: Iterable
    """
    joint = frozenset(x) | frozenset(y)
    joint_entropy = self.entropy(joint)
    return joint_entropy - self.entropy(y)
1318
def entropy(self,x):
    """Return the entropy of the marginal empirical distribution given by C{x}
    and the data

    If C{self} has an C{entropy_cache} attribute it is consulted first and
    updated with every newly computed value; without that attribute
    nothing is cached.

    @param x: The variables of the marginal distribution
    @type x: Iterable
    @return: The entropy (C{0.0} for an empty C{x})
    @rtype: Float
    """
    key = frozenset(x)
    try:
        return self.entropy_cache[key]
    except (AttributeError, KeyError):
        # no cache at all, or no entry for these variables yet
        pass

    # empty set of variables has zero entropy
    if not x:
        return 0.0

    weighted_logs = 0
    for count in self.get_nonzerocounts(x):
        weighted_logs += count * log(count)
    total = self._n
    result = log(total) - weighted_logs/total

    try:
        self.entropy_cache[key] = result
    except AttributeError:
        pass
    return result
1342 1343 1344
def __str__(self):
    """Return the sorted variables followed by the text of the ADTree"""
    variable_names = sorted(self._variables)
    return "Variables: %s\nTree:\n %s\n" % (variable_names, self._tree)
1348
def makeCPT(self, child, parents, force_cpt=False, check=False, prior=0):
    """Return a CPT for C{child} given C{parents} built from the data

    @param child: The child variable
    @param parents: The parents of C{child}
    @type parents: Iterable
    @param force_cpt: Passed to L{CPT} as C{cpt_force}
    @param check: Passed to L{CPT} as C{cpt_check}
    @param prior: the Dirichlet prior parameter (the same parameter value
    is used for all instances!) It is simply added to the family's
    marginal factor. Note that this is a B{different} prior from the one used
    by the BDeu score; in practice it seems to be adequate for parameter
    estimation. If prior is zero, then the parameters are the maximum
    likelihood estimation solutions.
    @return: The conditional probability table
    @rtype: L{CPT} object
    """
    family = set(parents)
    family.add(child)
    family_counts = self.makeFactor(family)
    return CPT(family_counts+prior, child, cpt_check=check, cpt_force=force_cpt)
1361 1362
def get_nonzerocounts(self,variables):
    """Return the non-zero counts of the marginal table for C{variables}

    @param variables: Variables to project onto
    @type variables: Iterable
    @return: The counts of the marginal table, in table order, with all
    zero entries left out
    @rtype: List
    @raise KeyError: if C{variables} contains a variable not contained in the
    L{CompactFactor}.
    """
    # one [tree depth, number of values] entry per variable, in tree order
    variables_info = [
        [self._var_index[var],self._numvals[var]]
        for var in variables]
    variables_info.sort()
    #size is the length of the list under each value of the corresponding variable
    size = 1
    for variable_info in reversed(variables_info):
        variable_info.append(size)
        size *= variable_info[1]
    # _ADTree._flatten takes no 'no_zeroes' argument (passing one raised
    # TypeError), so the zero entries are dropped here instead
    return [count
            for count in self._tree._flatten(variables_info,0)
            if count]
1374
def makeFactor(self,variables):
    """Return a marginal factor with C{variables}

    If C{variables} is empty the factor's data is the single total count
    of C{self}'s data.

    @param variables: Variables to project onto
    @type variables: Iterable
    @return: The marginal factor
    @rtype: L{Factor} object
    @raise KeyError: if C{variables} contains a variable not contained in the
    L{CompactFactor}.
    """
    # [tree depth, number of values] per variable, ordered by tree depth
    info = sorted(
        [self._var_index[var],self._numvals[var]] for var in variables)
    # walking backwards, give every entry the number of table cells lying
    # under each of its values
    cells = 1
    for entry in info[::-1]:
        entry.append(cells)
        cells *= entry[1]
    flat_counts = self._tree._flatten(info,0)
    return Factor(variables,flat_counts,self)
1397 1398 __getitem__ = makeFactor 1399
def tree_size(self):
    """Return the number of nodes in the underlying ADTree

    Delegates to the C{size} method of the tree.

    @return: The number of nodes in the tree
    @rtype: int
    """
    tree = self._tree
    return tree.size()
1407
class _ADTree(object):
    """ADTree implementation

    Ref::

     @Article{moore98:_cached_suffic_statis_effic_machin,
       author =       {Andrew Moore and Mary Soon Lee},
       title =        {Cached Sufficient Statistics for Efficient Machine Learning with Large Datasets},
       journal =      {Journal of Artificial Intelligence Research},
       year =         1998,
       volume =       8,
       pages =        {67--91},
       url = {http://www.jair.org/media/453/live-453-1678-jair.pdf}
     }

    @ivar _count: A count of the number of records 'in' the tree
    @type _count: int
    @ivar _data: may be
     .1 a tuple each element of which is either an _ADTree object or None,
     There is one element for each value of the variable corresponding to
     the top node of the tree. A None value indicates 0 records for the corresponding
     value
     .2 the records themselves, exactly as handed to the constructor (an
     empty tuple once every variable has been processed)
     .3 a single _ADTree object. This is a space saving mechanism used when there
     is only one value of the variable with any records associated with it
    @type _data: Various
    @ivar _mcvindex: In cases 1) and 3) this is an integer stating which value has the most records.
    To this value *all* records are associated rather than just those with the appropriate
    value. In case 2) this is None
    @type _mcvindex: Various
    """

    __slots__ = ('_data','_count','_mcvindex')

    def __getstate__(self):
        """Return the state to pickle (needed because of C{__slots__})"""
        return self._data, self._count, self._mcvindex

    def __setstate__(self,state):
        """Restore the state produced by L{__getstate__}"""
        self._data, self._count, self._mcvindex = state

    def __init__(self,numvals,records,depth,rmin):
        """Initialise an L{_ADTree} object

        @param numvals: The ith element of this list is the number of values for the ith variable
        in the tree
        @type numvals: list
        @param records: The data as a list. Each element is a tuple of value indices, one for each variable,
        plus an extra count field as the final element.
        @type records: list
        @param depth: The depth of this tree within its containing CompactFactor. Also the index of the variable
        associated with the top node of this tree
        @type depth: int
        @param rmin: If the number of records is below C{rmin} then tree growing stops and C{records} is stored.
        Note that a single record may represent many datapoints since all records have an extra 'count' field.
        @type rmin: int
        """
        # all nodes have a count;
        # each record has an extra 'count' field at the end
        self._count = sum(record[-1] for record in records)

        if depth == len(numvals):
            # no more variables to process
            self._data = ()
            self._mcvindex = None
            return

        if len(records) < rmin or len(records) == 0:
            # Case 2) above. An empty record set is never split either:
            # there would be no non-empty branch to descend into.
            self._data = records
            self._mcvindex = None
            return

        # distribute records over the values of the variable at this depth
        branch_records = [[] for _ in range(numvals[depth])]
        counts = [0] * numvals[depth]
        for record in records:
            valindex = record[depth]
            branch_records[valindex].append(record)
            counts[valindex] += record[-1]

        # find most common value (the first one wins ties). Starting
        # below zero guarantees that a value is chosen even when every
        # count is zero; starting at zero left _mcvindex unset then.
        mcvindex, mcvcount = None, -1
        for valindex, value_count in enumerate(counts):
            if value_count > mcvcount:
                mcvindex, mcvcount = valindex, value_count
        self._mcvindex = mcvindex
        # mcv branch actually has all records
        # not just those for which first=mcv
        branch_records[mcvindex] = records

        # make branches
        data = []
        not_empty = 0
        for value_records in branch_records:
            if len(value_records) == 0:
                data.append(None)
            else:
                data.append(_ADTree(numvals,value_records,depth+1,rmin))
                not_empty += 1

        if not_empty == 1:
            # Case 3) above: the only non-empty branch is the mcv branch,
            # since that one holds all the records
            self._data = data[mcvindex]
        else:
            # Case 1) above
            self._data = tuple(data)

    def _flatten(self,variables_info,depth):
        """Return the data for a L{Factor}

        The L{Factor}'s variables will generally be a subset of those
        for which data is stored in the tree

        @param variables_info: Contains the necessary information on the variables
        sought without naming them. The ith element of C{variables_info} contains information on
        the ith variable sought. Each element of C{variables_info} is a 3 element list:
        variables_info[i][0] is the depth in the L{_ADTree} which deals with the ith variable sought.
        variables_info[i][1] is the number of values of the ith variable sought.
        variables_info[i][2] is the number of data values in the eventual factor which correspond
        to each value of the ith variable sought.
        (Clearly this depends on the number of values of 'later' variables.)
        @type variables_info: list
        @param depth: The depth of C{self} in the complete tree
        @type depth: int
        @return: Values for the factor
        @rtype: list
        """
        # NOTE: a variant returning only the non-zero counts was once
        # sketched here but never completed; callers wanting that have to
        # filter the list returned by this method.
        if not variables_info:
            # no need to descend tree further
            return [self._count]

        if self._mcvindex is None:
            # Case 2) above
            # self._data is a list of records
            # no need to descend tree further
            data = [0] * (variables_info[0][1] * variables_info[0][2])
            for record in self._data:
                i = 0
                for variable_info in variables_info:
                    i += record[variable_info[0]] * variable_info[2]
                data[i] += record[-1]
            return data

        if variables_info[0][0] != depth:
            # Keep looking for the depth for first variable ..
            # the mcv branch holds all the records of this node
            if isinstance(self._data,tuple):
                branch = self._data[self._mcvindex]
            else:
                branch = self._data
            return branch._flatten(variables_info,depth+1)

        # have reached the depth for the first variable
        numvals = variables_info[0][1]
        size = variables_info[0][2]
        mcvstart = self._mcvindex * size
        mcvend = mcvstart + size
        remaining = variables_info[1:]

        if not isinstance(self._data,tuple):
            # Case 3) above: every record has the most common value
            data = [0] * (size * numvals)
            data[mcvstart:mcvend] = self._data._flatten(remaining,depth+1)
            return data

        # Case 1) above
        data = []
        acc = [0] * size
        for i, branch in enumerate(self._data):
            if branch is None:
                data.extend([0] * size)
                continue
            this_data = branch._flatten(remaining,depth+1)
            data.extend(this_data)
            if i != self._mcvindex:
                acc = [a + b for a, b in zip(acc, this_data)]
        # 'correct' data values for mcvindex: that branch was built from
        # all records, so take away what belongs to the other values
        data[mcvstart:mcvend] = [
            everything - others
            for everything, others in zip(data[mcvstart:mcvend], acc)]
        return data

    def size(self):
        """Return the number of nodes in the tree

        Nodes holding records count as one, and so does every empty
        (C{None}) branch.

        @return: The number of nodes in the tree
        @rtype: int
        """
        if self._mcvindex is None:
            return 1
        if not isinstance(self._data,tuple):
            return self._data.size()
        count = 0
        for branch in self._data:
            if branch is None:
                count += 1
            else:
                count += branch.size()
        return count

    def __str__(self):
        return "(%s,%s,[%s])\n" % (
            self._count,
            self._mcvindex,
            self._data)
1691
def _merge_records(destination,source,how):
    """Given two sets of records, combine them according to the function C{how}

    Records agree when all their fields apart from the final count are
    equal. A source record agreeing with a record already in the result
    has its count combined into the first such record; any other source
    record is appended. C{destination} itself is left untouched.

    @param destination: the records the result starts from, which supply
    the left hand side of C{how}
    @type destination: list of tuples whose last element is a positive integer
    @param source: the right hand side of C{how}
    @type source: list of tuples whose last element is a positive integer
    @param how: operator to combine C{destination} and C{source}
    @type how: binary function
    @return: The merged records
    @rtype: list
    """
    # copy-on-write
    merged = list(destination)
    # position of the first record with each prefix, so that every source
    # record is placed with a single lookup rather than a scan
    first_position = {}
    for position, record in enumerate(merged):
        first_position.setdefault(record[:-1], position)

    for src_rec in source:
        prefix = src_rec[:-1]
        if prefix in first_position:
            position = first_position[prefix]
            dst_rec = merged[position]
            merged[position] = dst_rec[:-1] + (how(dst_rec[-1], src_rec[-1]),)
        else:
            first_position[prefix] = len(merged)
            merged.append(src_rec)
    return merged
1715
def _distribute_records(numvals, depth, records):
    """For each child at depth C{depth} generate the list of
    records (C{children_records}) and the number of instances (C{counts}).

    @param numvals: The number of values of the variable at depth C{depth}
    (a single number, not the list of numbers for all the variables)
    @type numvals: int
    @param depth: The index, within each record, of the variable to split on
    @type depth: int
    @param records: The data. Each element is a tuple of value
    indices, one for each variable, plus an extra count field as the final
    element.
    @type records: iterable
    @return: A 2-tuple: C{children_records} and C{counts}, as described above.
    Both are lists with one entry per value of the variable.
    @rtype: tuple
    """
    # range() rather than xrange(): it exists in both Python 2 and 3
    children_records = [[] for _ in range(numvals)]
    counts = [0] * numvals
    for record in records:
        valindex = record[depth]
        children_records[valindex].append(record)
        counts[valindex] += record[-1]

    return children_records, counts
1740
1741 -class _IncrementalADTree(_ADTree):
1742 1743 __slots__ = ('_data','_count','_mcvindex') 1744
def __init__(self, numvals, records, rmin, depth=0):
    """Initialise an L{_IncrementalADTree} object

    When C{numvals}, C{records} and C{rmin} are all C{None} an empty
    shell is created (used by L{copy}, which fills it in afterwards).
    Otherwise the node is built from C{records}.

    @param numvals: The ith element of this list is the number of values
    for the ith variable in the tree
    @type numvals: list
    @param records: The data as a list of tuples of value indices plus a
    final count field
    @type records: list
    @param rmin: Record sets smaller than this are stored rather than split
    @type rmin: int
    @param depth: The depth of this node within the complete tree
    @type depth: int
    """
    is_shell = numvals is None and records is None and rmin is None
    if not is_shell:
        self._make_new_node(numvals, records, rmin, depth)
        return
    # copy construct
    self._data = self._count = self._mcvindex = None
1753
def __str__(self, descend=True):
    """Return a one-line description of this node

    @param descend: If true the children are described as well (but not
    their own children); otherwise a placeholder is printed for them
    @type descend: bool
    @rtype: String
    """
    payload = self._data
    is_sequence = isinstance(payload,tuple) or isinstance(payload,list)

    pieces = ['{%x:' % id(self)]
    pieces.append(' count: '+str(self._count)+', mcvindex: '+str(self._mcvindex)+', data: ')
    if is_sequence:
        pieces.append('('+str(len(payload))+')')

    if not descend:
        pieces.append('...skipped...')
    elif is_sequence:
        pieces.append('[')
        for child in payload:
            if isinstance(child,_IncrementalADTree):
                pieces.append(child.__str__(descend=False) + ' ')
            else:
                pieces.append(str(child) + ' ')
        pieces.append(']')
    elif isinstance(payload,_IncrementalADTree):
        pieces.append(payload.__str__(descend=False))
    else:
        pieces.append(str(payload))

    pieces.append('}')
    return ''.join(pieces)
1776
def str(self, numvals, rmin, n=0, mcv=True):
    """Return a multi-line text dump of the tree rooted at this node

    Every line starts with the nesting level C{n}. The branch of the most
    common value is marked with a C{*}.

    @param numvals: Handed on to C{_build_mcv} and to the recursive calls
    @param rmin: Handed on to C{_build_mcv} and to the recursive calls
    @param n: The nesting level of this node in the dump
    @type n: int
    @param mcv: If true, the branch of the most common value is replaced
    by the result of C{self._build_mcv(numvals, rmin, n+1)} before being
    dumped; otherwise the stored branch is dumped
    @type mcv: bool
    @rtype: String
    """
    prefix = str(n) + ' '*n
    s = prefix+'count: '+str(self._count)+', mcv:'+str(self._mcvindex)+'\n'
    if self._mcvindex is None:
        # record
        s += prefix+'records\n'
        s += prefix+str(self._data)+'\n'
    elif not isinstance(self._data,tuple):
        # singleton node
        s += prefix+'*'+str(self._mcvindex)+': singleton\n'
        if mcv:
            # presumably rebuilds the sub-tree holding only the records
            # of the most common value -- _build_mcv is defined elsewhere,
            # TODO confirm
            data = self._build_mcv(numvals, rmin, n+1)
        else:
            data = self._data
        if data is None:
            s += prefix+' '+'empty\n'
        else:
            s += data.str(numvals,rmin,n+1,mcv)
    else:
        # nodes
        for i, data in enumerate(self._data):
            if i == self._mcvindex:
                s += prefix+'*'+str(i)+':\n'
                if mcv:
                    data = self._build_mcv(numvals,rmin,n+1)
            else:
                s += prefix+' '+str(i)+':\n'
            if data is None:
                s += prefix+' '+'empty\n'
            else:
                s += data.str(numvals,rmin,n+1,mcv)
    return s
1809
def copy(self):
    """Return a copy of the tree rooted at this node

    Child nodes are copied recursively. Stored records are shared
    between the original and the copy rather than duplicated.

    @return: The copy
    @rtype: L{_IncrementalADTree} object
    """
    clone = _IncrementalADTree(None,None,None)
    clone._count = self._count
    clone._mcvindex = self._mcvindex

    if self._mcvindex is None:
        # records are copy-on-write (by _merge_records)
        clone._data = self._data
    elif self._have_compact_children():
        clone._data = self._data.copy()
    else:
        children = []
        for child in self._data:
            if child is not None:
                child = child.copy()
            children.append(child)
        clone._data = tuple(children)

    assert self._have_compact_children() == clone._have_compact_children()
    return clone
1830
def _compact_children(self, other_than_mcv=None):
    """Where possible convert self._data from a tuple (or list) into
    a singleton.

    If only the branch of the most common value is present, C{self._data}
    becomes that branch itself; otherwise a list of branches is turned
    into a tuple. Nodes without a most common value, and nodes whose data
    is not a tuple or list, are left alone.

    @param other_than_mcv: A hint. If other than None, it indicates
    whether some child other than the MCV has a non-zero count.
    @type other_than_mcv: bool or None
    """
    mcvindex = self._mcvindex
    if mcvindex is None:
        return
    if not (isinstance(self._data,tuple) or isinstance(self._data,list)):
        return

    if other_than_mcv is None:
        # no hint given: look for a child besides the mcv one
        other_than_mcv = False
        for i, child in enumerate(self._data):
            if i != mcvindex and child is not None:
                other_than_mcv = True
                break
    elif not other_than_mcv:
        # check the hint
        for i, child in enumerate(self._data):
            if i != mcvindex:
                assert child is None

    if other_than_mcv:
        if isinstance(self._data,list):
            self._data = tuple(self._data)
    else:
        assert isinstance(self._data[mcvindex],_IncrementalADTree)
        self._data = self._data[mcvindex]
1860 1861
def _have_compact_children(self):
    """Return whether C{self._data} is a single child node

    This is only ever true for nodes with a most common value.

    @rtype: bool
    """
    if self._mcvindex is None:
        return False
    return isinstance(self._data,_IncrementalADTree)
1864
def _update(self, numvals, records, rmin, depth=0):
    """Add data to an L{_ADTree} object

    Nothing happens if the counts of C{records} sum to zero. A node
    holding records is rebuilt from the merged old and new records; any
    other node hands the new records on to its children, creating
    children where necessary.

    @param numvals: The ith element of this list is the number of values
    for the ith variable in the tree
    @type numvals: list
    @param records: The data as a list. Each element is a tuple of value
    indices, one for each variable, plus an extra count field as the final
    element.
    @type records: list
    @param depth: The depth of this tree within its containing
    CompactFactor. Also the index of the variable associated with the top
    node of this tree
    @type depth: int
    @param rmin: If the number of records is below C{rmin} then tree
    growing stops and C{records} is stored. Note that a single record may
    represent many datapoints since all records have an extra 'count'
    field.
    @type rmin: int
    """

    # sanity checks on the shape of the new records
    for record in records:
        assert len(record)-1 == len(numvals)
        assert isinstance(record,tuple)
        for element in record:
            assert isinstance(element,int)

    # how many more?
    new_count = 0
    # each record has an extra 'count' field at the end
    for record in records:
        # all counts *must* be non-negative
        new_count += record[-1]

    # clearly if there are no new observations, there is no update.
    if new_count == 0:
        return

    self._count += new_count

    if depth == len(numvals):
        # no more variables to process
        self._data = []
        self._mcvindex = None
        return

    if self._mcvindex is None:
        # a bunch of records or empty
        merged_records = _merge_records(records, self._data, operator.add)

        # (XXX: repeated self._count calculation?)
        self._make_new_node(numvals, merged_records, rmin, depth)
        return

    # distribute records (XXX: hmm, calculate counts twice as above!)
    children_records, counts = _distribute_records(numvals[depth], depth, records)
    # the branch of the most common value receives *all* the new records
    counts[self._mcvindex] = new_count
    children_records[self._mcvindex] = records

    # make self._data into a mutable list...
    if self._have_compact_children():
        # check to see if more than one non-empty subtree
        # (which must be expanded)
        for i in xrange(len(counts)):
            if i == self._mcvindex:
                continue
            if counts[i] > 0:
                other_than_mcv = True
                break
        else:
            other_than_mcv = False

        # if not, then don't expand -- just update
        # (this early return skips the _correct_mcv and _verify calls
        # made at the end of this method)
        if not other_than_mcv:
            self._data._update(numvals,records,rmin,depth+1)
            return

        # else expand into full child set
        data = [None] * len(counts)
        data[self._mcvindex] = self._data
        self._data = data
        data = None
    else:
        self._data = list(self._data)

    # self._data is now a mutable list -- update each branch
    for i, child_records in enumerate(children_records):
        child = self._data[i]
        if counts[i] == 0:
            continue

        if child is not None:
            child._update(numvals, child_records, rmin, depth+1)
        else:
            self._data[i] = _IncrementalADTree(numvals, child_records, rmin, depth+1)

    # back to a tuple, or to a single child if only one branch is in use
    self._compact_children()
    # the branch of the most common value must account for every record
    if self._have_compact_children():
        assert self._count == self._data._count
    else:
        assert self._count == self._data[self._mcvindex]._count

    # the above update may result in the most common value no longer
    # being the most common value
    self._correct_mcv(numvals, rmin, depth)
    self._verify(numvals,rmin,depth)
1971
1972 - def _make_new_node(self,numvals,records,rmin,depth, force_expand=False):
1973 """Turn C{self} into a fresh node from records 1974 1975 @param numvals: The ith element of this list is the number of values 1976 for the ith variable in the tree 1977 @type numvals: list 1978 @param records: The data as a list. Each element is a tuple of value 1979 indices, one for each variable, plus an extra count field as the final 1980 element. 1981 @type records: list 1982 @param depth: The depth of this tree within its containing 1983 CompactFactor. Also the index of the variable associated with the top 1984 node of this tree 1985 @type depth: int 1986 @param rmin: If the number of records is below C{rmin} then tree growing stops and C{records} is stored. 1987 Note that a single record may represent many datapoints since all records have an extra 'count' field. 1988 @type rmin: int 1989 """ 1990 1991 for record in records: 1992 assert len(record)-1 == len(numvals) 1993 assert isinstance(record,tuple) 1994 for element in record: 1995 assert isinstance(element,int) 1996 1997 # all nodes have a count 1998 self._count = 0 1999 # each record has an extra 'count' field at the end 2000 for record in records: 2001 self._count += record[-1] 2002 2003 if depth == len(numvals): 2004 # no more variables to process 2005 self._data = [] 2006 self._mcvindex = None 2007 return 2008 2009 if not force_expand and len(records) < rmin: 2010 # Case 2) above 2011 self._data = records 2012 self._mcvindex = None 2013 return 2014 2015 children_records, counts = _distribute_records(numvals[depth], depth, records) 2016 2017 # find most common value 2018 mcvcount = -1 2019 for valindex, count in enumerate(counts): 2020 if count > mcvcount: 2021 self._mcvindex, mcvcount = valindex, count 2022 # mcv branch actually has all records 2023 # not just those for which first=mcv 2024 children_records[self._mcvindex] = records 2025 counts[self._mcvindex] = self._count 2026 2027 # make branches 2028 self._data = [] 2029 other_than_mcv = False 2030 2031 for i, child_records in 
enumerate(children_records): 2032 if counts[i] == 0: 2033 self._data.append(None) 2034 continue 2035 2036 other_than_mcv |= i != self._mcvindex 2037 self._data.append(_IncrementalADTree(numvals,child_records, rmin, depth+1)) 2038 2039 self._compact_children(other_than_mcv) 2040 if self._have_compact_children(): 2041 assert self._count == self._data._count 2042 else: 2043 assert self._count == self._data[self._mcvindex]._count 2044 self._verify(numvals,rmin,depth)
2045
2046 - def _expand_one_level(self, numvals, rmin, depth):
2047 """self is a record carrying node but we want to treat it like 2048 something bigger. Expand just this level. This is a special 2049 case of _make_new_node.""" 2050 assert self._mcvindex is None 2051 2052 self._make_new_node(numvals, self._data, rmin, depth, force_expand=True)
2053
2054 - def _correct_mcv(self,numvals,rmin,depth):
2055 if self._mcvindex is None: 2056 return 2057 2058 if self._have_compact_children(): 2059 self._data._correct_mcv(numvals,rmin,depth+1) 2060 return 2061 2062 assert self._count == self._data[self._mcvindex]._count 2063 2064 # calculate sum of all non-mcv values and find the most frequent of these values 2065 cur_mcv_count = self._count 2066 max_mcv = self._mcvindex 2067 max_mcv_count = 0 2068 for i, child in enumerate(self._data): 2069 if child is None or i == self._mcvindex: 2070 continue 2071 cur_mcv_count -= child._count 2072 if child._count > max_mcv_count: 2073 max_mcv, max_mcv_count = i, child._count 2074 2075 # check if the current mcv is the true mcv... 2076 if cur_mcv_count < max_mcv_count: 2077 # if not, then recalculate the tree. generate 2078 # the tree of the wrong mcv, and move the vary node 2079 # stored at the wrong mcv to the correct mcv. 2080 # NOTE: this differs from the paper as the paper has 2081 # a mistake in it (wrong mcv and new mcv are the wrong way 2082 # around). 2083 data = list(self._data) 2084 data[self._mcvindex] = self._build_mcv(numvals, rmin, depth+1) 2085 assert (cur_mcv_count == 0 and data[self._mcvindex] is None) or data[self._mcvindex]._count == cur_mcv_count 2086 data[max_mcv] = self._data[self._mcvindex] 2087 self._data = data 2088 data = None 2089 self._mcvindex = max_mcv 2090 assert self._count == self._data[self._mcvindex]._count 2091 2092 self._compact_children(cur_mcv_count != 0 or self._count > max_mcv_count) 2093 if self._have_compact_children(): 2094 self._data._correct_mcv(numvals,rmin,depth+1) 2095 return 2096 2097 # correct all the subtrees and other vary nodes 2098 for child in self._data: 2099 if child is not None: 2100 child._correct_mcv(numvals, rmin, depth+1)
2101
2102 - def _build_mcv(self, numvals, rmin, depth):
2103 assert self._mcvindex is not None 2104 2105 # calculate the count of the mcv -- this is done before 2106 # the potentially more expensive full mcv calculation in case 2107 # it is zero, and for the case when there are no other variables 2108 # to consider. 2109 mcv_count = self._count 2110 if self._have_compact_children(): 2111 vary_child = self._data 2112 children = [] 2113 else: 2114 for i, child in enumerate(self._data): 2115 if child is None or i == self._mcvindex: 2116 continue 2117 mcv_count -= child._count 2118 vary_child = self._data[self._mcvindex] 2119 children = [child for i, child in enumerate(self._data) if i != self._mcvindex] 2120 2121 if mcv_count == 0: 2122 return None 2123 2124 assert mcv_count > 0 2125 2126 # end of the line (no more variables)? 2127 if depth == len(numvals): 2128 mcv = _IncrementalADTree(numvals, [], 1, depth) 2129 mcv._count = mcv_count 2130 return mcv 2131 2132 # calculate the full tree of the most common value 2133 mcv = vary_child.copy() 2134 for child in children: 2135 mcv._subtract(numvals,rmin,depth,child) 2136 2137 # sanity check 2138 assert mcv._count == mcv_count 2139 2140 return mcv
2141
2142 - def _subtract(self, numvals, rmin, depth, lesser):
2143 2144 if lesser is None: 2145 return 2146 2147 # counts must be non-negative 2148 assert self._count >= lesser._count 2149 2150 # expand record nodes if one of the nodes is not a record node 2151 if self._mcvindex is None and lesser._mcvindex is not None: 2152 self._expand_one_level(numvals, rmin, depth) 2153 2154 # perform the subtraction 2155 self._count -= lesser._count 2156 2157 # termination criterion 2158 if depth == len(numvals): 2159 return 2160 2161 # deal with record nodes 2162 if self._mcvindex is None and lesser._mcvindex is None: 2163 self._data = _merge_records(self._data, lesser._data, operator.sub) 2164 return 2165 2166 # expand record nodes if one of the nodes is not a record node 2167 if lesser._mcvindex is None: 2168 # use a copy -- that way we don't make lesser any 2169 # larger 2170 lesser = lesser.copy() 2171 lesser._expand_one_level(numvals, rmin, depth) 2172 2173 self_single = not isinstance(self._data,tuple) 2174 less_single = not isinstance(lesser._data,tuple) 2175 2176 # self must be larger than lesser 2177 if self_single and not less_single: 2178 # see if we can make less single by correcting 2179 # the mcv 2180 lesser._correct_mcv(numvals, rmin, depth) 2181 less_single = not isinstance(lesser._data,tuple) 2182 2183 assert not self_single or less_single 2184 2185 if self_single and less_single: 2186 assert self._mcvindex == lesser._mcvindex 2187 self_vary = self._data 2188 less_vary = lesser._data 2189 assert self_vary is None or self_vary._count == self._count + lesser._count 2190 assert less_vary is None or less_vary._count == lesser._count 2191 elif less_single: 2192 self_vary = self._data[self._mcvindex] 2193 less_vary = lesser._data 2194 assert self_vary is None or self_vary._count == self._count + lesser._count 2195 assert less_vary is None or less_vary._count == lesser._count 2196 if self._mcvindex != lesser._mcvindex: 2197 b = lesser._build_mcv(numvals, rmin, depth+1) 2198 if b is not None or b._count != 0: 2199 assert 
self._data[lesser._mcvindex] is not None 2200 assert self._data[lesser._mcvindex]._count != 0 2201 if self._data[lesser._mcvindex]._count == b._count: 2202 self._data = list(self._data) 2203 self._data[lesser._mcvindex] = None 2204 self._compact_children() 2205 else: 2206 self._data[lesser._mcvindex]._subtract(numvals,rmin, depth+1,b) 2207 else: 2208 self_vary = self._data[self._mcvindex] 2209 less_vary = lesser._data[lesser._mcvindex] 2210 assert self_vary is None or self_vary._count == self._count + lesser._count 2211 assert less_vary is None or less_vary._count == lesser._count 2212 # subtract the children 2213 self._data = list(self._data) 2214 for i, (a, b) in enumerate(zip(self._data, lesser._data)): 2215 if i == self._mcvindex: 2216 continue 2217 2218 if i == lesser._mcvindex: 2219 b = lesser._build_mcv(numvals, rmin, depth+1) 2220 2221 if b is None or b._count == 0: 2222 continue 2223 2224 assert a is not None and a._count != 0 2225 if a._count == b._count: 2226 self._data[i] = None 2227 else: 2228 a._subtract(numvals, rmin, depth+1, b) 2229 2230 self._compact_children() 2231 2232 # check vary node count consistency... 2233 assert self_vary is None or self_vary._count == self._count + lesser._count 2234 assert less_vary is None or less_vary._count == lesser._count 2235 2236 if self_vary is not None: 2237 self_vary._subtract(numvals, rmin, depth+1, less_vary) 2238 else: 2239 assert less_vary._count == 0 2240 2241 # check vary node count consistency... 2242 assert self_vary is None or self_vary._count == self._count 2243 assert less_vary is None or less_vary._count == lesser._count
2244
def _verify(self, numvals, rmin, depth=0):
    """Consistency check hook -- currently disabled.

    The method returns immediately without inspecting the tree. The
    original (slow) checks are kept below, commented out with the '#S'
    prefix; their own docstring describes them as structurally necessary
    (but not sufficient) conditions for a correct tree.

    @param numvals: The ith element of this list is the number of values
    for the ith variable in the tree (unused while the checks are disabled)
    @type numvals: list
    @param rmin: Record threshold (unused while the checks are disabled)
    @type rmin: int
    @param depth: Depth of C{self} (unused while the checks are disabled)
    @type depth: int
    """
    return
    # slow running code
#S     """Ensure that the tree is in a consistent state. This imposes a set of
#S     structurally necessary (but not sufficient) conditions for a correct
#S     result."""
#S     if not __debug__:
#S         return
#S
#S     # an adtree represents a non-negative conditional contingency table
#S     assert self._count >= 0
#S
#S     # only the root node may have a zero count
#S     assert self._count != 0 or depth == 0
#S
#S     # if all at depth equal to the number of variables
#S     if depth == len(numvals):
#S         # it had better be that there is no mcv (since this must be a leaf
#S         # node)
#S         assert self._mcvindex is None
#S         # and there must be no associated data.
#S         assert self._data == []
#S         return
#S
#S     # if this node counts less than rmin instances
#S     if self._mcvindex is None:
#S         # and data must be a list of tuples...
#S         assert isinstance(self._data,list)
#S         # containing some count data (or this must be a zero count)...
#S         assert len(self._data) > 0 or self._count == 0
#S         # and these data must be tuples whose last
#S         # element is a non-negative integer
#S         s = 0
#S         for d in self._data:
#S             assert isinstance(d,tuple)
#S             assert d[-1] >= 0
#S             s += d[-1]
#S         # ensure correct count
#S         assert self._count == s
#S         return
#S
#S     # a non-leaf node had better have something interesting in it...
#S     assert self._count > 0
#S
#S     # is this a singleton internal node?
#S     if not isinstance(self._data, tuple):
#S         # better be an ad tree ...
#S         assert isinstance(self._data, _IncrementalADTree)
#S         vary_child = self._data
#S     else:
#S         # otherwise, this had better have a valid most common value
#S         assert self._mcvindex >= 0 and self._mcvindex < numvals[depth]
#S         # and enough children
#S         assert isinstance(self._data,tuple)
#S         assert len(self._data) == numvals[depth]
#S         # the sum of every non-mcv child must be less than the count of
#S         # this node
#S         children = [child for i, child in enumerate(self._data)
#S                         if i != self._mcvindex and child is not None]
#S         children_counts = [child._count for child in children]
#S         assert sum(children_counts) <= self._count
#S         # there must exists one element other than the mcv that is not None
#S         assert len(children) > 0
#S         mcv_count = self._count
#S         # verify the non-mcv children
#S         for child in children:
#S             child._verify(numvals, rmin, depth+1)
#S             mcv_count -= child._count
#S         assert mcv_count >= max(children_counts)
#S         vary_child = self._data[self._mcvindex]
#S
#S     # the vary child (stored at the index of the most common value) must
#S     # have the same count as this node
#S     if vary_child is not None:
#S         assert self._count == vary_child._count
#S     else:
#S         assert self._count == 0
#S     # verify the vary child
#S     vary_child._verify(numvals, rmin, depth+1)


class IncrementalCompactFactor(SubDomain):
    """Contingency table data held in an L{_IncrementalADTree}

    The tree is built from the records given at construction time and can
    later be extended with more records via L{update}. Marginal tables
    are produced by L{makeFactor}.
    """

    def __init__(self,rawdata,domain=None,rmin=200):
        """Build the tree from C{rawdata}

        @param rawdata: A sequence of four elements; the last three are
        used as the new domain variables, the variables and the records
        (the first element is ignored here).
        @param domain: Passed on to C{SubDomain.__init__}
        @param rmin: Record threshold handed to the underlying tree, both
        now and on every later L{update}
        @type rmin: int
        """
        new_domain_variables, variables, records = rawdata[1:]
        SubDomain.__init__(self,variables,domain,new_domain_variables,check=True)
        # changing rmin between construction and _update calls on incremental
        # adtrees only has an effect where some other change in the tree occurs
        # (i.e., where there is new data). to avoid confusion, store it here.
        self._rmin = rmin
        var_index = {}
        numvals = []
        for i, variable in enumerate(variables):
            var_index[variable] = i
            numvals.append(self._numvals[variable])
        # list of number of values, ordered according to self._var_index
        self._varnumvals = numvals
        self._tree = _IncrementalADTree(numvals,records,self._rmin)
        self._var_index = var_index
        self._tree._verify(self._varnumvals, self._rmin)

    def copy(self):
        """Return a copy of C{self} that owns its own copy of the tree

        The object itself is copied shallowly; only the tree is duplicated
        (through the tree's own C{copy} method), so updating the copy does
        not alter the tree of the original.
        """
        # Imported here under a distinct name: the file's top-level
        # imports do not bring in the copy module, and the local name
        # keeps the function apart from this method.
        from copy import copy as shallow_copy
        cpy = shallow_copy(self)
        cpy._tree = cpy._tree.copy()
        return cpy

    def __getitem__(self,variables):
        """Shorthand for C{self.makeFactor(variables, check=True)}"""
        return self.makeFactor(variables,check=True)

    def update(self, rawdata):
        """Add the records in C{rawdata} to the tree

        @param rawdata: Same layout as for the constructor; only the
        records (element 3) are used.
        """
        records = rawdata[3]
        self._tree._update(self._varnumvals, records, self._rmin)
        self._tree._verify(self._varnumvals, self._rmin)

    def __str__(self):
        return "Variables: %s\nTree:\n%s\n" % (
            sorted(self._variables), self._tree.str(self._varnumvals, self._rmin))

    def makeFactor(self,variables,check=False):
        """Return a marginal factor with C{variables}

        Unless C{variables} is empty in which case return
        the sum of C{self}'s data
        @param variables: Variables to project onto. It is iterated here
        and then handed to L{Factor} as is.
        @type variables: Iterable
        @param check: Not used by this implementation: every variable is
        looked up in the variable index regardless, so an unknown variable
        always raises C{KeyError}.
        @type check: Boolean
        @return: The marginal factor
        @rtype: L{Factor} object
        @raise KeyError: if C{variables} contains a variable not contained in the
        L{IncrementalCompactFactor}.
        """
        # one [tree index, number of values] entry per variable, in tree order
        variables_info = sorted(
            [self._var_index[var], self._numvals[var]] for var in variables)
        # size is the length of the list under each value of the
        # corresponding variable; it is appended as a third element
        size = 1
        for variable_info in reversed(variables_info):
            variable_info.append(size)
            size *= variable_info[1]
        return Factor(variables,self._tree._flatten(variables_info,0),SubDomain.copy(self))

    def makeCPT(self, child, parents, force_cpt=False, check=False, prior=0):
        """Return a L{CPT} for C{child} given C{parents}

        The marginal factor over the family (child and parents) is built
        with L{makeFactor}, C{prior} is added to it and the result is
        handed to L{CPT}.

        @param child: The child variable
        @param parents: The parent variables
        @type parents: Iterable
        @param force_cpt: Passed to L{CPT} as C{cpt_force}
        @param check: Passed to L{CPT} as C{cpt_check}
        @param prior: the Dirichlet prior parameter (the same parameter value
        is used for all instances!) Note there may be some problems with
        this method: a B{different} prior is used by the BDeu score. However,
        in practice, for parameter estimation, this prior method seems to be ok.
        I was lazy and it was simple to implement (cb). If prior is zero, then
        the parameters are the maximum likelihood estimation solutions.
        @return: The conditional probability table
        @rtype: L{CPT} object
        """
        family = set(parents) | set([child])
        f_child = self.makeFactor(family)
        return CPT(f_child+prior, child, cpt_check=check, cpt_force=force_cpt)

    def family_score(self, child, parents, precision=1.0):
        """Return the BDeu score of C{child} with C{parents}

        Computed as C{bdeu_score(precision)} of the L{CPT} returned by
        L{makeCPT} with its default (zero) prior.
        """
        return self.makeCPT(child, parents, force_cpt=False, check=False).bdeu_score(precision)

    def size(self):
        """Return the count held by the root of the underlying ADTree

        This is the C{_count} attribute of the root node, which the tree
        code maintains as the sum of the 'count' fields of the records
        added so far.

        NOTE(review): this docstring used to promise the number of nodes
        in the tree, which is not what the code returns; confirm what
        callers expect.

        @return: The root count of the tree
        @rtype: int
        """
        return self._tree._count
2415 2416 # try: 2417 # import psyco 2418 2419 # psyco.bind(_merge_records) 2420 # psyco.bind(_distribute_records) 2421 # psyco.bind(_IncrementalADTree) 2422 # psyco.bind(IncrementalCompactFactor) 2423 # except: 2424 # pass 2425 2426 2427 # def create_view(self,variables): 2428 # if variables: 2429 # cols = ','.join(sorted(variables)) 2430 # sql = ('SELECT %s, sum(value) FROM %s GROUP BY %s' % 2431 # (cols,self.table,cols)) 2432 # else: 2433 # sql = 'SELECT sum(value) FROM %s' % self.table 2434 # self.cursor.execute('CREATE TEMP VIEW IF NOT EXISTS view_%s AS %s' % (cols,sql)) 2435 # self._cached.append(frozenset(variables)) 2436 2437 # problem computing each view separately is inefficient since it involves many passes 2438 # and we only need one! 2439 # each pass is easier though. 2440 # if we find a high-dimensional database with few rows, then no need to store 2441 # lower dimensional ones. 2442 2443 # given a frozenset of variables see if the exact table is stored 2444 # else for v in variables get set of all frozensets containing that variable 2445 # intersect them all, then use smallest of survivors 2446 2447 # from math import log 2448 # x = frozenset(x) 2449 # y = frozenset(y) 2450 # #if x & y: 2451 # # raise ValueError('%s and %s are not disjoint' % (x,y)) 2452 2453 # # construct 3 marginal tables 2454 # tables = (self.table,'temp0','temp0') 2455 # valcol = ('value','value0','value0') 2456 # for i, vs in enumerate(((x|y),x,y)): 2457 # vs = sorted(vs) 2458 # cols = ','.join([v + ' INT' for v in vs]+['value%d INT' % i]) 2459 # sql = 'CREATE TABLE temp%d (%s)' % (i,cols) 2460 # self.cursor.execute(sql) 2461 # vs = ','.join(vs) 2462 # sql = ('INSERT INTO temp%d ( %s, value%d ) SELECT %s, sum(%s) FROM %s GROUP BY %s' % 2463 # (i,vs,i,vs,valcol[i],tables[i],vs)) 2464 # #print sql 2465 # self.cursor.execute(sql) 2466 # # compute total count 2467 # self.cursor.execute('SELECT sum(value2) FROM temp2') 2468 # n = float(self.cursor.fetchone()[0]) 2469 2470 # # print 
'xy' 2471 # # self.cursor.execute('SELECT * FROM temp0') 2472 # # print self.cursor.fetchall() 2473 2474 # # print 'x' 2475 # # self.cursor.execute('SELECT * FROM temp1') 2476 # # print self.cursor.fetchall() 2477 2478 # # print 'y' 2479 # # self.cursor.execute('SELECT * FROM temp2') 2480 # # print self.cursor.fetchall() 2481 2482 # # join tables 2483 # self.cursor.execute('SELECT value0, value1, value2 FROM (temp0 NATURAL JOIN temp1) NATURAL JOIN temp2') 2484 # mi = 0.0 2485 # for row in self.cursor: 2486 # [nxy,nx,ny] = row 2487 # #print [n,nxy,nx,ny] 2488 # mi += (nxy/n) * log((n*nxy)/float(nx*ny)) 2489 # for i in range(3): 2490 # self.cursor.execute('DROP TABLE temp%d' % i) 2491 # return mi 2492