Package gPy :: Module Models
[hide private]
[frames] | [no frames]

Source Code for Module gPy.Models

   1  """ 
   2  Factored representations of probability distributions and probability models 
   3   
   4  @var _version: Version of this module 
   5  @type _version: String 
   6  """ 
   7   
   8  from Variables import SubDomain 
   9  from Hypergraphs import Hypergraph, ReducedHypergraph, SimpleHypergraph, ReducedJoinForest 
  10  from Graphs import ADG, DiGraph, DiForest, MutilatedADG, EssentialGraph 
  11  from IO import GraphCanvas 
  12  from Utils import emptyset, pretty_str_set, member 
  13  from Parameters import Factor, CPT 
  14  from Data import CompactFactor 
  15  import Tkinter, operator 
  16   
  17  _version = '$Id: Models.py,v 1.14 2008/10/07 09:11:55 jc Exp $' 
class _AbsM(SubDomain):
    """Abstract model class

    Defined by a hypergraph and specification of values for the variables.
    Represents the set of all probability distributions which have a factored
    representation whose hypergraph is the specified one.

    Since the hypergraph is not required to be simple it is not useful to
    actually create objects of this class, hence it is an abstract class.
    The most general model class is that of log-linear models (LLM objects,
    not yet implemented) where the associated hypergraph is required to be
    simple.

    @ivar _hypergraph: The hypergraph associated with the model. There is
    exactly one hyperedge for each factor. The hyperedge for a given factor is
    just the set of that factor's variables (implemented as a frozenset). If
    the hypergraph is non-simple, then hyperedges can be repeated.
    @type _hypergraph: L{Hypergraph} object
    @ivar _factors: A mapping from each distinct hyperedge in the associated
    hypergraph to a (non-empty) list of factors having the variables of the
    hyperedge.
    @type _factors: Dictionary
    """

    def __init__(self, factors=(), domain=None, new_domain_variables=None,
                 must_be_new=False, check=False):
        """Initialise a hierarchical model

        Each factor has its domain expanded to be the domain of the model, if
        necessary.
        @param factors: The factors in the hierarchical model.
        (Alternatively an existing object of class L{FR} (or its subclasses),
        in which case C{self} shares that object's attributes.)
        @type factors: Iterable, each element of which is a
        L{Parameters.Factor} or L{Variables.SubDomain} object. (Alternatively
        an object of class L{FR} or one of its subclasses.)
        @param domain: A domain for the model.
        If None and all C{factors} have the same domain then this domain is
        used. If None and all C{factors} do not have the same domain then the
        internal default domain is used.
        @type domain: L{Variables.Domain} or None
        @param new_domain_variables: A dictionary containing a mapping from
        any new variables to their values. C{domain} is updated with these
        values
        @type new_domain_variables: Dict or None
        @param must_be_new: Whether domain variables in
        C{new_domain_variables} have to be new
        @type must_be_new: Boolean
        @param check: Whether to check that all variables exist in C{domain}
        @type check: Boolean
        @raise VariableError: If a variable in C{new_domain_variables}
        already exists with values different from its values in
        C{new_domain_variables}; or if C{must_be_new} is set and the variable
        already exists; or if C{check} is set and a variable of some factor
        is not in the domain
        """
        if isinstance(factors, FR):
            # Share (not copy) every attribute of the existing model.
            self.__dict__ = factors.__dict__
            return
        # The factors are iterated twice, indexed and sliced below, so a
        # one-shot iterable must be materialised first.
        factors = list(factors)
        self._factors = {}
        hypergraph = Hypergraph()
        variables = set()
        for factor in factors:
            hyperedge = factor.variables()
            variables.update(hyperedge)
            hypergraph.add_hyperedge(hyperedge)
            self._add_factor_to_dict(hyperedge, factor)
        if domain is None and factors:
            first_factor_domain = factors[0]._domain
            for factor in factors[1:]:
                if factor._domain is not first_factor_domain:
                    break
            else:
                # All factors share one domain: pass the first factor (a
                # SubDomain) so that domain is reused.
                # NOTE(review): relies on SubDomain.__init__ accepting a
                # SubDomain as its domain argument.
                domain = factors[0]
        SubDomain.__init__(self, variables, domain, new_domain_variables,
                           must_be_new, check)
        for factor in factors:
            self.common_domain(factor)
        self._hypergraph = hypergraph

    def __getitem__(self, hyperedge):
        """Return the factor(s) (not a copy) whose variables are C{hyperedge}

        @param hyperedge: The variables of the sought factor
        @type hyperedge: Iterable
        @return: The factor(s)
        @rtype: A L{Parameters.Factor} object (for simple) or a list of such
        objects (non-simple)
        @raise KeyError: If there is no factor with these variables
        """
        return self._factors[frozenset(hyperedge)]

    def __iter__(self):
        """Return an iterator over the factors in the model

        To allow C{for factor in model: ...} constructions.
        @return: An iterator over the factors in the model
        @rtype: Iterator
        """
        return iter(self.factors())

    def __len__(self):
        """Return the number of factors in the model

        @return: The number of hyperedges (with repeats) in the hypergraph,
        i.e. the number of factors
        @rtype: Int
        """
        return len(self._hypergraph)

    def __repr__(self):
        return 'FR(%s,None,%s)' % (self._factors.values(),
                                   self._domain)

    def __setitem__(self, hyperedge, factor):
        """Set the factor whose variables are C{hyperedge} to C{factor}

        All existing factors with these variables are replaced.

        @param hyperedge: The variables of the factor
        @type hyperedge: Iterable
        @param factor: Factor which will replace existing factor(s)
        @type factor: L{Parameters.Factor}
        @raise ValueError: If C{factor}'s variables are not C{hyperedge}
        @raise KeyError: If there is no factor with these variables
        """
        hyperedge = frozenset(hyperedge)
        if hyperedge != factor.variables():
            raise ValueError('Can only replace factors with other factors with the same variables: %s, %s' % (hyperedge, factor.variables()))
        if hyperedge not in self._factors:
            # Storing a brand-new entry here would bypass the hypergraph and
            # the model's variable set, leaving the model inconsistent.
            raise KeyError(hyperedge)
        self._put_factor_to_dict(hyperedge, factor)

    def __str__(self):
        """Print each factor separately

        Use lexicographical ordering on factor variables.

        @return: Pretty representation of a L{FR}
        @rtype: String
        """
        entries = [(sorted(hyperedge), factor)
                   for hyperedge, factor in self.items()]
        # Sort on the variable lists only: factors with identical variables
        # (non-simple models) must not be compared with each other.
        entries.sort(key=lambda entry: entry[0])
        return ''.join([str(factor) + '\n' for _, factor in entries])

    def __div__(self, other):
        """Return the result of dividing a hierarchical model by a scalar

        An arbitrary factor (the first one produced by iteration) is divided
        by the scalar.

        The returned value shares the same domain as C{self}. To avoid this
        make a deep copy and do in-place division.

        @param other: The scalar
        @type other: float or int
        """
        return self.copy().__idiv__(other)

    def __idiv__(self, other):
        """Divide a hierarchical model by a scalar

        An arbitrary factor (the first one produced by iteration) is divided
        by the scalar. A model with no factors is left unchanged.

        @param other: The scalar
        @type other: float or int
        """
        for factor in self:
            factor /= other
            break
        return self

    # Support true division (Python 3, or 'from __future__ import division').
    __truediv__ = __div__
    __itruediv__ = __idiv__

    def __imul__(self, other):
        """Multiply a hierarchical model by a factor, scalar or another L{FR}

        A scalar is multiplied into an arbitrary factor (the first one
        produced by iteration); a model with no factors is left unchanged.

        @param other: The factor, scalar or L{FR} object to be multiplied in
        @type other: L{Factor}, number or L{FR} object
        """
        if isinstance(other, Factor):
            self._add_factor(other)
        elif isinstance(other, FR):
            for factor in other:
                self._add_factor(factor)
        else:
            for factor in self:
                factor *= other
                break
        return self

    def __mul__(self, other):
        """Return the result of multiplying a hierarchical model by a factor,
        scalar or another L{FR}

        The returned value shares the same domain as C{self}. To avoid this
        make a deep copy and do in-place multiplication.

        @param other: The factor, scalar or L{FR} object to be multiplied in
        @type other: L{Factor}, number or L{FR} object
        """
        return self.copy().__imul__(other)

    def add_ident(self, hyperedge):
        """Add an ident factor for a hyperedge to a hierarchical model

        Typically C{hyperedge} will come from a L{Hypergraph} model.

        An ident factor maps all instantiations to 1.0
        @param hyperedge: The variables for the factor
        @type hyperedge: Iterable
        """
        hyperedge = frozenset(hyperedge)
        self._hypergraph.add_hyperedge(hyperedge)
        self._add_factor_to_dict(hyperedge, Factor(hyperedge))
        self._variables = frozenset(self._hypergraph.vertices())

    def condition(self, condition, keep_class=False):
        """Alter a distribution by effecting the restriction on
        variables given by C{condition}

        This alters the model's domain. Make a copy with C{copy_domain=True}
        if the original domain will be needed.
        @param condition: Dictionary of the form {var1:values1,var2:values2..}
        Each value of this dictionary must be an iterable
        @type condition: Dict
        @param keep_class: Passed on to each factor's C{data_restrict}
        @type keep_class: Boolean
        @return: The conditioned model
        @rtype: Same as C{self}
        @raise KeyError: If a variable is used that is not in the model
        @raise ValueError: If a value is used that is not a possible value of
        the variable it is attached to
        """
        self._check_condition(condition)
        # change data in each factor
        for factor in self:
            factor.data_restrict(condition, keep_class)
        # change stored values
        self.change_domain_variables(condition)
        return self

    def copy(self, copy_domain=False):
        """Return a deep copy of a hierarchical model

        @param copy_domain: If true C{self}'s domain is copied, otherwise the
        copy shares C{self}'s domain
        @type copy_domain: Boolean
        @return: A copy of C{self}
        @rtype: Same type as C{self}
        """
        cp = SubDomain.copy(self, copy_domain)
        cp.__class__ = self.__class__
        cp._hypergraph = self._hypergraph.copy()
        factors = {}
        for hyperedge, factorlist in self._factors.items():
            nw = []
            for factor in factorlist:
                fc = factor.copy()
                cp.common_domain(fc)
                nw.append(fc)
            factors[hyperedge] = nw
        cp._factors = factors
        return cp

    def cpt(self, child, parents=()):
        """Return a conditional probability table for specified child and parents

        Return the probability distribution of C{child} conditional on
        C{parents}. C{self} is not altered.

        @param child: Child of the CPT
        @type child: String
        @param parents: Parents of C{child} in the CPT
        @type parents: Sequence
        @return: The specified CPT
        @rtype: L{Parameters.CPT}
        """
        cp = self.copy()
        vs = tuple(parents) + (child,)
        cp.marginal(vs)
        joint = 1
        for f in cp:
            joint *= f
        return CPT(joint, child, cpt_force=True)

    def eliminate_variable(self, variable, trace=False):
        """Alter a factored representation by summing out a variable

        Removes one or more factors. If C{variable} is instantiated its
        factors simply have it dropped; otherwise the factors containing it
        are multiplied together, C{variable} is summed out of the product and
        the resulting message is multiplied back into the model.

        @param variable: The variable to eliminate
        @type variable: String
        @param trace: Whether to return information about the elimination
        @type trace: Boolean
        @return: If C{trace} is true, the tuple (variables of the product
        factor, variables of the message, hyperedges removed); the first two
        are None for an instantiated variable. Otherwise None.
        @raise KeyError: If C{variable} is not in the model
        """
        # Snapshot the star: remove() below mutates the hypergraph, and a
        # hyperedge repeated in a non-simple model must only be removed once
        # (remove() drops every factor on that hyperedge at once).
        hyperedges = frozenset(self._hypergraph.star(variable))
        if variable in self._instd:
            for hyperedge in hyperedges:
                # drop the variable from the factor
                nf = self.factor(hyperedge).drop_variable(variable)
                # remove old factor
                self.remove(hyperedge)
                # add in new one
                self *= nf
            if trace:
                return None, None, hyperedges
            return None
        prod_factor = 1
        for hyperedge in hyperedges:
            prod_factor *= self.factor(hyperedge)
            self.remove(hyperedge)
        message = prod_factor.sumout([variable])
        self *= message
        if trace:
            return prod_factor.variables(), message.variables(), hyperedges
        return None

    def factor(self, hyperedge):
        """Return the factor produced by multiplying all factors with
        variables C{hyperedge}

        If there is a single such factor it is returned itself (not a copy).
        Otherwise a new factor is returned and the stored factors are left
        untouched.

        @param hyperedge: Set of variables
        @type hyperedge: Iterable
        @return: Product of all factors with variables C{hyperedge}
        @rtype: L{Parameters.Factor} object
        @raise KeyError: if no factor has C{hyperedge} as variables.
        """
        factors = self._factors[frozenset(hyperedge)]
        if len(factors) == 1:
            return factors[0]
        # Multiply into a copy: multiplying into factors[0] in place would
        # leave the model holding both the product and the other factors.
        prod = factors[0].copy()
        for f in factors[1:]:
            prod *= f
        return prod

    def factors(self):
        """Return a list of the factors in the model

        @return: A list of the factors in the model
        @rtype: List
        """
        fs = []
        for factorlist in self._factors.values():
            fs.extend(factorlist)
        return fs

    def factors_containing_variable(self, variable):
        """Return a list of factors containing C{variable}

        Factors sharing the same variables (non-simple models) each appear
        in the (flat) list.

        @param variable: A variable
        @type variable: Immutable (usually string)
        @return: A list of factors containing C{variable}
        @rtype: List
        """
        fs = []
        for hyperedge in self._hypergraph[variable]:
            fs.extend(self._factors[hyperedge])
        return fs

    def items(self):
        """Return sequence of C{factor.variables(),factor} pairs
        for each factor in the model.

        @return: Sequence of C{factor.variables(),factor} pairs
        for each factor in the model
        @rtype: List
        """
        itms = []
        for hyperedge, factorlist in self._factors.items():
            for factor in factorlist:
                itms.append((hyperedge, factor))
        return itms

    def gui_display(self, parent, colours=None):
        """Display a GUI widget for displaying a model

        @param parent: A widget into which the GUI is placed.
        @type parent: Some suitable Tk object.
        @param colours: Mapping from hyperedges to colours
        @type colours: Dictionary
        """
        if colours is None:
            colours = {}
        gui = Tkinter.Frame(parent)
        gui.pack()
        top = Tkinter.Frame(gui)
        top.pack()
        fgs = []
        for hyperedge, factor in self.items():
            fgs.append(factor.gui_main(top, edit=False,
                                       bg=colours.get(hyperedge, 'grey')))
            fgs.append(Tkinter.Label(top, text='*'))
        if fgs:
            # drop the trailing '*' separator (none exists for an empty model)
            fgs.pop()
        for widget in fgs:
            widget.pack(side=Tkinter.LEFT)
        bottom = Tkinter.Frame(gui)
        bottom.pack()
        for (txt, cmd) in [('Done', gui.destroy),
                           ('Quit', parent.destroy)]:
            button = Tkinter.Button(bottom, text=txt, command=cmd)
            # Bind cmd as a default argument: a plain closure would see only
            # the loop's last value, so <Return> on 'Done' would quit.
            button.bind('<Return>', lambda event, cmd=cmd: cmd())
            button.pack(side=Tkinter.LEFT)

    def hypergraph(self):
        """Return (a copy of) the hypergraph associated with the model

        @return: (A copy of) the hypergraph associated with the model
        @rtype: L{Hypergraph}
        """
        return self._hypergraph.copy()

    def inc_from_rawdata(self, rawdata):
        """Increment C{self} with counts directly from C{rawdata}

        OK for parameter fitting not for structure learning
        @param rawdata: A tuple like that returned by L{IO.read_csv}; its
        last two elements are the variables and the records.
        @type rawdata: Tuple
        @raise IndexError: If C{self} has a variable missing from C{rawdata}
        """
        variables, records = rawdata[2:]
        factors = self.factors()
        factor_variables_info = [factor.get_variables_info(variables)
                                 for factor in factors]
        for record in records:
            for factor, info in zip(factors, factor_variables_info):
                factor.inc_from_record(info, record)

    def interaction_graph(self):
        """Return the interaction graph for a model

        The interaction graph contains an edge for any pair of variables
        which are members of a common factor.
        @return: The interaction graph
        @rtype: L{Graphs.UGraph}
        """
        return self._hypergraph.two_section()

    def is_simple(self):
        """Whether the model is simple

        @return: Whether the model is simple
        @rtype: Boolean
        """
        return self._hypergraph.is_simple()

    def make_decomposable(self, elimination_ordering=None):
        """Return a decomposable model using an elimination ordering

        If no C{elimination_ordering} is given
        L{Hypergraphs.ReducedHypergraph.maximum_cardinality_search} is used
        to provide one.

        Returned object shares many attributes with C{self}. Use a copy of
        self, if it still needed as an independent object.

        @param elimination_ordering: Order in which to eliminate variables
        @type elimination_ordering: Sequence
        @return: A decomposable model
        @rtype: L{DFR}
        """
        dg, destination = self.hypergraph().make_decomposable2(elimination_ordering)
        fs = {}
        for new_hyperedge in dg:
            fs[new_hyperedge] = Factor(new_hyperedge, domain=self)
        for old_hyperedge, new_hyperedge in destination.items():
            fs[new_hyperedge] *= self.factor(old_hyperedge)
        dm = DFR()
        dm.__dict__ = self.__dict__
        dm._hypergraph = dg
        # DFR inherits this class's list-valued _factors layout, so each
        # product factor must be wrapped in a list.
        dm._factors = dict((hyperedge, [factor])
                           for hyperedge, factor in fs.items())
        return dm

    def marginal(self, variables):
        """Alter a model to represent the marginal distribution
        on C{variables}

        Marginal only represented up to normalisation

        If C{self} is a BN and C{variables} is an ancestral set then C{self}
        remains a BN otherwise it becomes a general hierarchical model.
        """
        self.sumout(self.variables().difference(variables))

    def marginal_factor(self, variables):
        """Return the marginal distribution
        on C{variables} as a L{Parameters.Factor}

        Marginal is properly normalised. C{self} is not altered.
        """
        cp = self.copy()
        cp.sumout(self.variables().difference(variables))
        # The remaining factors need not all sit on the hyperedge
        # C{variables}, so multiply every surviving factor together.
        return reduce(operator.mul, cp).normalised()

    def marginalise_away(self, variables, naive=True):
        """Alter a factored representation by summing out variables

        TODO: If we just want a marginal, no need to bother altering self

        @param variables: Ordered variables to eliminate
        @type variables: Sequence type
        @param naive: If C{True} the variables are summed out in the order
        given by C{variables}. If C{False} conditional independence and the
        number of factors a variable is in is used to construct a more
        intelligent order.
        @type naive: Boolean
        @return: The altered C{self}
        @rtype: Class of C{self}
        """
        if not naive:
            # remove factors all of whose variables are conditionally
            # independent of those remaining
            insted = self._instd
            remaining = self.variables() - set(variables)
            hypergraph = self._hypergraph
            reachable = set(hypergraph.reachable(remaining - insted, insted))
            reachable.update(remaining)
            # Iterate over a snapshot of the distinct hyperedges: remove()
            # mutates the hypergraph, and must only be called once per
            # distinct hyperedge.
            for hyperedge in list(self._factors.keys()):
                if not (hyperedge & reachable):
                    self.remove(hyperedge)

            # only concerned with surviving variables
            variables = self.variables().intersection(variables)

            # choose fixed ordering depending on number of hyperedges a
            # variable is in
            def var_sort(v):
                return self.num_factors_containing_variable(v)
            variables = sorted(variables, key=var_sort)

        for variable in variables:
            self.eliminate_variable(variable)
        return self

    def makeDN(self, allow_dummies=False):
        """Make a dependency network where the CPT for each variable
        is its distribution conditional on its Markov blanket
        """
        return DN([self.markov_blanket_cpt(v, allow_dummies)
                   for v in self._variables])

    def markov_blanket(self, variable):
        """Return the Markov blanket for C{variable}

        @param variable: Variable in the model
        @type variable: Immutable
        @return: Markov blanket for C{variable}
        @rtype: Set
        @raise KeyError: If C{variable} is not in the model
        """
        return self._hypergraph.neighbours(variable)

    def markov_blanket_cpt(self, variable, allow_dummies=False):
        """Create a CPT for the distribution of C{variable}
        conditional on its Markov blanket

        @param variable: variable for which the conditional distribution is
        required
        @type variable: Immutable
        @param allow_dummies: Passed on to the L{Parameters.CPT} constructor
        @type allow_dummies: Boolean
        """
        factor = 1
        for hyperedge in self._hypergraph.star(variable):
            factor *= self.factor(hyperedge)
        return CPT(factor, variable, cpt_force=True,
                   allow_dummies=allow_dummies)

    def num_factors_containing_variable(self, variable):
        """Return the number of factors containing C{variable}

        As counted by the hypergraph's C{star_size}.
        """
        return self._hypergraph.star_size(variable)

    def remove(self, hyperedge):
        """Remove a factor or factors from a hierarchical model

        All factors with variables C{hyperedge} will be removed

        @param hyperedge: The variables of the factor
        @type hyperedge: Iterable
        @raise KeyError: If no factor with these variables exists
        """
        hyperedge = frozenset(hyperedge)
        del self._factors[hyperedge]
        self._hypergraph.remove_hyperedge(hyperedge)
        self._variables = frozenset(self._hypergraph.vertices())

    def red(self):
        """Reduce the model, returning any distinct redundant hyperedges

        The model is first simplified. Any factor all of whose variables are
        contained in another factor is multiplied into such a factor and
        removed. A factor with no variables is likewise absorbed into some
        other factor (or just removed if it is the only factor).
        @return: Redundant hyperedges
        @rtype: List
        """
        self.simplify()
        reds = self._hypergraph.redundant_hyperedges()
        for hyperedge, supersets in reds.items():
            # A superset may itself be redundant and already absorbed
            # earlier in this loop, so absorb into one still present.
            superset = [he for he in supersets
                        if frozenset(he) in self._factors][0]
            absorbing_factor = self.some_factor(superset)
            absorbing_factor *= self.factor(hyperedge)
            self.remove(hyperedge)
        reds = list(reds.keys())
        if emptyset in self._hypergraph:
            reds.append(emptyset)
            others = [he for he in self._factors if he != emptyset]
            if others:
                absorbing_factor = self.some_factor(others[0])
                absorbing_factor *= self.factor(emptyset)
            self.remove(emptyset)
        return reds

    def reduced(self):
        """Test whether a model is reduced

        @return: Whether a model is reduced
        @rtype: Boolean
        """
        return self._hypergraph.is_reduced()

    def simplify(self):
        """Simplify the model

        so that there are no factors with the same variable sets. Factors
        sharing variables are multiplied together (into the first of them)
        and the hypergraph is rebuilt without repeated hyperedges.
        """
        for hyperedge, factorlist in list(self._factors.items()):
            if len(factorlist) > 1:
                nfl = factorlist[0]
                for f in factorlist[1:]:
                    nfl *= f
                self._factors[hyperedge] = [nfl]
        if not self._hypergraph.is_simple():
            # Keep the hypergraph in step with _factors: one hyperedge each.
            hypergraph = Hypergraph()
            for hyperedge in self._factors:
                hypergraph.add_hyperedge(hyperedge)
            self._hypergraph = hypergraph

    def some_factor(self, hyperedge):
        """Return an arbitrary factor (not a copy) with variables C{hyperedge}

        @param hyperedge: Variable set
        @type hyperedge: Iterable
        @return: Factor
        @rtype: L{Parameters.Factor} object
        @raise KeyError: If no factor with C{hyperedge} as variables exists
        """
        return self._factors[frozenset(hyperedge)][0]

    def sumout(self, variables):
        """Sum out (marginalise away) variables, greedily choosing at each
        step a variable contained in the fewest factors

        C{variables} may be altered (if it is a set its elements are
        removed).

        @param variables: Variables to sum out
        @type variables: Iterable
        @return: The marginal model
        @rtype: Same as C{self}
        """
        # Iterative rather than recursive: one recursion level per variable
        # would hit the interpreter's recursion limit on large models.
        remaining = variables
        while remaining:
            fewest = len(self._hypergraph)
            best = None
            for variable in remaining:
                n = self._hypergraph.star_size(variable)
                if n == 1:
                    best = variable
                    break
                if n <= fewest:
                    fewest = n
                    best = variable
            self.eliminate_variable(best)
            if not isinstance(remaining, set):
                remaining = set(remaining)
            remaining.remove(best)
        return self

    variable_elimination = marginalise_away

    def variable_elimination_trace(self, variables):
        """As L{variable_elimination} but also build the corresponding
        L{Graphs.DiForest} object

        @param variables: Ordered variables to eliminate
        @type variables: Sequence type
        @return: A forest with a vertex for each cluster formed during
        elimination and an arrow from a cluster to the cluster that consumed
        its message
        @rtype: L{Graphs.DiForest}
        """
        junction_forest = DiForest()
        message_produced = {}
        for variable in variables:
            cluster, new_message, messages_used = self.eliminate_variable(variable, True)
            junction_forest.add_vertex(cluster)
            # Iterate over a snapshot: entries are deleted inside the loop.
            for old_cluster, old_message in list(message_produced.items()):
                if old_message in messages_used:
                    junction_forest.add_arrow(old_cluster, cluster)
                    del message_produced[old_cluster]
            message_produced[cluster] = new_message
        return junction_forest

    def z(self):
        """Return the sum of values associated with each full joint
        instantiation

        This is the partition function. Result is generally not 1.0 since
        models may represent distributions only up to normalisation. An
        expensive operation.
        @return: The sum of values associated with each full joint
        instantiation
        @rtype: Float
        @todo: Use sensible ordering of variables to eliminate
        """
        tmp = self.copy()
        tmp.variable_elimination(tmp._variables)
        return tmp.factor(emptyset).z()

    def zero(self):
        """Set all values in all factors to zero
        """
        for factor in self:
            factor.zero()

    def _add_factor(self, factor):
        """Multiply a hierarchical model by a factor

        @param factor: The factor to be multiplied in
        @type factor: L{Factor}
        """
        self.common_domain(factor)
        hyperedge = factor.variables()
        self._hypergraph.add_hyperedge(hyperedge)
        self._add_factor_to_dict(hyperedge, factor)
        self._variables = frozenset(self._hypergraph.vertices())

    def _add_factor_to_dict(self, hyperedge, factor):
        """Add a factor to the _factors dictionary

        @param hyperedge: The hyperedge for the factor
        @type hyperedge: Frozenset
        @param factor: The factor to be added
        @type factor: L{Factor}
        """
        self._factors.setdefault(hyperedge, []).append(factor)

    def _check_condition(self, condition):
        """Check that C{condition} only mentions model variables and values

        @raise KeyError: If a variable is not in the model
        @raise ValueError: If a value is not a possible value of its variable
        """
        for variable, values in condition.items():
            if variable not in self._variables:
                raise KeyError('%s not a variable of this model' % variable)
            for value in values:
                if value not in self._domain[variable]:
                    raise ValueError(
                        "%s has values %s. '%s' is not one of them" %
                        (variable, tuple(self._domain[variable]), value))

    def _put_factor_to_dict(self, hyperedge, factor):
        """Put a factor in the _factors dictionary

        Overwriting any previous entry

        @param hyperedge: The hyperedge for the factor
        @type hyperedge: Frozenset
        @param factor: The factor to be put
        @type factor: L{Factor}
        """
        self._factors[hyperedge] = [factor]
class FR(_AbsM):
    """Factored representations of probability distributions

    A distribution is held as a collection of L{Parameters.Factor} objects
    and equals the product of these factors (up to normalisation).
    Several factors may share the same variables; all behaviour comes
    from L{_AbsM}.
    """
805
class SFR(FR):
    """Factored representations of probability distributions (associated
    hypergraph is simple)

    Probability distributions represented by a set of factors.
    The distribution is equal to the product of these factors.
    Each factor is a L{Parameters.Factor}.

    No two factors have the same variables.
    If factors were (A,B), (A,B), (B,C) then would NOT be simple.
    Factors added (e.g. via the constructor or L{add_ident}) whose variables
    coincide with an existing factor's are multiplied into that factor, so
    simplicity is kept without losing information.

    @ivar _hypergraph: The hypergraph associated with the model. There is
    exactly one hyperedge for each factor. The hyperedge for a given factor
    is just the set of that factor's variables (implemented as a frozenset).
    @type _hypergraph: L{SimpleHypergraph} object
    @ivar _factors: A mapping from each hyperedge in the associated
    hypergraph to its associated factor, which has the hyperedge as its
    variables.
    @type _factors: Dictionary
    """

    def __init__(self, factors=(), domain=None, new_domain_variables=None,
                 must_be_new=False, check=False):
        """Initialise a simple factored representation

        Arguments are as for L{_AbsM.__init__}. Factors with identical
        variables are multiplied together.
        NOTE(review): if C{factors} is an L{FR} that is not an L{SFR}, its
        attributes (including list-valued C{_factors}) are shared unconverted.
        """
        FR.__init__(self, factors, domain, new_domain_variables,
                    must_be_new, check)
        # NOTE(review): assumes SimpleHypergraph collapses any repeated
        # hyperedges of the hypergraph built by the base constructor.
        self._hypergraph = SimpleHypergraph(self._hypergraph)

    def __iter__(self):
        """Return an iterator over the factors in the model

        To allow C{for factor in model: ...} constructions.
        @return: An iterator over the factors in the model
        @rtype: Iterator
        """
        return iter(self._factors.values())

    def _add_factor(self, factor):
        """Multiply a simple hierarchical model by a factor

        If a factor with the same variables already exists,
        then this existing factor is multiplied by C{factor} to maintain
        simplicity.

        @param factor: The factor to be multiplied in
        @type factor: L{Factor}
        """
        try:
            self._factors[factor.variables()] *= factor
        except KeyError:
            FR._add_factor(self, factor)

    def _add_factor_to_dict(self, hyperedge, factor):
        """Add a factor to the _factors dictionary

        If a factor is already stored for C{hyperedge}, C{factor} is
        multiplied into it (overwriting would silently drop a factor).

        @param hyperedge: The hyperedge for the factor
        @type hyperedge: Frozenset
        @param factor: The factor to be added
        @type factor: L{Factor}
        """
        if hyperedge in self._factors:
            self._factors[hyperedge] *= factor
        else:
            self._factors[hyperedge] = factor

    def _put_factor_to_dict(self, hyperedge, factor):
        """Put a factor in the _factors dictionary, overwriting any previous
        entry

        @param hyperedge: The hyperedge for the factor
        @type hyperedge: Frozenset
        @param factor: The factor to be put
        @type factor: L{Factor}
        """
        self._factors[hyperedge] = factor

    def copy(self, copy_domain=False):
        """Return a deep copy of a hierarchical model

        @param copy_domain: If true C{self}'s domain is copied, otherwise the
        copy shares C{self}'s domain
        @type copy_domain: Boolean
        @return: A copy of C{self}
        @rtype: Same type as C{self}
        """
        cp = SubDomain.copy(self, copy_domain)
        cp.__class__ = self.__class__
        cp._hypergraph = self._hypergraph.copy()
        factors = {}
        for hyperedge, factor in self._factors.items():
            fc = factor.copy()
            cp.common_domain(fc)
            factors[hyperedge] = fc
        cp._factors = factors
        return cp

    def factor(self, hyperedge):
        """Return the factor (not a copy) with variables C{hyperedge}

        @raise KeyError: If no factor has these variables
        """
        return self._factors[frozenset(hyperedge)]

    def factors(self):
        """Return a list of the factors in the model"""
        return list(self._factors.values())

    def factors_containing_variable(self, variable):
        """Return a list of the factors containing C{variable}"""
        return [self._factors[hyperedge]
                for hyperedge in self._hypergraph[variable]]

    def is_simple(self):
        """Whether the model is simple (always true for this class)"""
        return True

    def items(self):
        """Return a list of C{(hyperedge, factor)} pairs"""
        return list(self._factors.items())

    def simplify(self):
        """Do nothing: a simple model is already simplified"""
        pass

    def some_factor(self, hyperedge):
        """Return the factor (not a copy) with variables C{hyperedge}

        @raise KeyError: If no factor has these variables
        """
        return self._factors[frozenset(hyperedge)]
class GFR(FR):
    """Factored representations of probability distributions whose
    associated hypergraph is graphical

    The factors sit exactly on the cliques of the interaction graph, so
    factors on (A,B), (A,C), (B,C) would NOT be graphical.

    @todo: Add behaviour specific to this class. For now L{GFR} objects act
    exactly like L{FR} objects and the class merely serves as a label
    I{claiming} that a hierarchical model is graphical.
    """
920
class RFR(SFR):
    """Factored representations of probability distributions (associated
    hypergraph is reduced)

    A representation with no redundant factors.
    If factors were (A), (A,B) then would NOT be reduced.

    @todo: Actually implement methods specific to this class. At present
    L{RFR} objects behave almost identically to L{SFR} objects so this class
    is mainly useful as a label I{claiming} that a hierarchical model is
    reduced.
    """

    def __init__(self, hm=(), check=True, modify=False):
        """Construct an RFR from an existing hierarchical model

        The new object shares its attributes with C{hm} when C{hm} is already
        an L{FR}.

        @param hm: An existing model, or factors from which an L{SFR} is built
        @type hm: L{FR} or sequence of factors
        @param check: Passed to L{ReducedHypergraph} (ignored if C{hm} is
        already an L{RFR})
        @type check: Boolean
        @param modify: Whether redundant factors are multiplied into a
        containing factor and removed (ignored if C{hm} is already an L{RFR})
        @type modify: Boolean
        """
        if not isinstance(hm, FR):
            hm = SFR(hm)
        if isinstance(hm, RFR):
            check = False
            modify = False
        rhg = ReducedHypergraph(hm._hypergraph, check, modify, trace=True)
        if modify:
            for redund, supersets in rhg.trace.items():
                superset = member(supersets)
                hm._factors[superset] *= hm._factors[redund]
                del hm._factors[redund]
        del rhg.trace
        if isinstance(hm, BN):
            del hm._adg
        # Store the reduced hypergraph: otherwise the model would keep the
        # redundant hyperedges whose factors were just absorbed above.
        hm._hypergraph = rhg
        self.__dict__ = hm.__dict__

    def _add_factor(self, factor):
        """Multiply a reduced factored representation by a factor

        If a factor containing all C{factor}'s variables exists then
        C{factor} is multiplied into it. Otherwise C{factor} is added and any
        existing factors whose variables it contains are absorbed into it.

        @param factor: The factor to be multiplied in
        @type factor: L{Factor}
        @todo: implement properly
        """
        hyperedge = factor.variables()
        smallerhes = []
        for he in self._hypergraph:
            if not smallerhes and he >= hyperedge:
                self._factors[he] *= factor
                return
            elif he <= hyperedge:
                smallerhes.append(he)
        SFR._add_factor(self, factor)
        for he in smallerhes:
            self._factors[hyperedge] *= self._factors[he]
            self.remove(he)

    def ipf(self, marginals, epsilon=0.001, max_iterations=None):
        """Iterative proportional fitting

        Iteratively alters the marginals of C{self} until
        they (almost) equal those supplied by C{marginals}.
        The existing factors in C{self} act only as a starting
        point for the iterative algorithm.
        Iteration stops once there is no marginal probability in C{self}
        differing by more than C{epsilon} from the one supplied by
        C{marginals}, or after C{max_iterations} sweeps if that is given.

        Each factor in C{self} must have exactly one corresponding marginal
        in C{marginals}.

        Typically C{marginals} will come from an empirical distribution, ie
        they will just be normalised counts from some data set.

        @param marginals: Marginals to fit
        @type marginals: Dictionary mapping hyperedges to marginals
        @param epsilon: Convergence criterion
        @type epsilon: Float
        @param max_iterations: Maximum number of sweeps over the factors;
        C{None} (default) means iterate until converged. A cap avoids looping
        forever on inconsistent marginals.
        @type max_iterations: Int or None
        @return: Whether the convergence criterion was met
        @rtype: Boolean
        """
        sweeps = 0
        converged = False
        while not converged:
            if max_iterations is not None and sweeps >= max_iterations:
                break
            sweeps += 1
            converged = True
            for hyperedge, factor in self._factors.items():
                current_marginal = self.marginal_factor(hyperedge)
                desired_marginal = marginals[hyperedge]
                factor *= desired_marginal / current_marginal
                if converged and current_marginal.differ(desired_marginal, epsilon):
                    converged = False
        return converged
# Commented-out alternative implementation of RFR.ipf which fits the
# marginals on a join forest (JFR) copy of the model; kept for reference.
#        jf = JFR(self.copy(),modify=True)
#        store = []
#        for old, new in jf.trace.items():
#            store.append((old, new, jf.perfect_sequence(new))
#        converged = False
#        while not converged:
#            converged = True # assumed converged until proven otherwise
#            for hyperedge, root, perfect_sequence in store:
#                correction = (marginals[hyperedge] /
#                              jf.marginal(hyperedge,root,perfect_sequence))
#                jf[root] *= correction
#                if converged:
#                    for x in correction.data():
#                        if abs(x) > epsilon:
#                            converged = False
#                            break
#        for hyperedge, root, perfect_sequence in store:
#            self._factors[hyperedge] = jf[root].copy().marginalise_onto(hyperedge)


class DFR(GFR):
    """Factored representations of probability distributions whose
    associated hypergraph is decomposable

    The interaction graph is decomposable (triangulated).

    @note: This class adds no methods of its own, so L{DFR} objects behave
    exactly like L{FR} objects; the class merely serves as a label I{claiming}
    that a hierarchical model is decomposable. Methods that exploit
    decomposability live on L{JFR}, which is a (reduced) decomposable model
    together with a particular join forest.
    """
class RDFR(RFR,DFR):
    """Factored representations of probability distributions (associated hypergraph is decomposable and reduced)

    Interaction graph is decomposable (triangulated)

    @note: There are no methods specific to this class. L{RDFR} objects
    inherit from L{RFR} (construction, L{RFR._add_factor}, L{RFR.ipf}) and
    otherwise behave like L{FR} objects, so this class is only useful as a
    label I{claiming} that a hierarchical model is decomposable and reduced.
    However there are methods for L{JFR} objects, which are just reduced
    decomposable models with a particular join forest specified.
    """
class JFR(RDFR):
    """Join forest representations of probability distributions

    A reduced, decomposable model where the factors are clique potentials and
    separators, related by a join forest.

    @ivar _hypergraph: The join forest associated with the model.
    @type _hypergraph: L{Hypergraphs.ReducedJoinForest}
    @ivar _factors: Maps each clique (node of the join_forest) to its associated factor
    @type _factors: Dictionary
    @ivar _separators: Maps each separator (edge of the join_forest, keyed by the
    frozenset of its two cliques) to its associated factor
    @type _separators: Dictionary
    """

    def __init__(self,hm=(),domain=None,new_domain_variables=None,
                 must_be_new=False,check=False,modify=False,elimination_order=None):
        """Construct a join forest model from an existing hierarchical model (or factors)

        If C{hm} is a hierarchical model, C{self} ends up sharing its attribute
        dictionary with C{hm}, so typically a copy of an existing model is used.

        @param hm: Hierarchical model or sequence of factors
        @type hm: L{FR} or sequence
        @param domain: B{Only used if C{hm} is not an L{FR} object.} A domain for the model.
        If None the internal default domain is used.
        @type domain: L{Variables.Domain} or None
        @param new_domain_variables: B{Only used if C{hm} is not an L{FR} object.} A dictionary
        containing a mapping from any new
        variables to their values. C{domain} is updated with these values
        @type new_domain_variables: Dict or None
        @param must_be_new: B{Only used if C{hm} is not an L{FR} object.} Whether domain
        variables in C{new_domain_variables} have to be new
        @type must_be_new: Boolean
        @param check: B{Only used if C{hm} is not an L{FR} object.} Whether to check
        that all variables exist in C{domain}
        @type check: Boolean
        @param modify: Whether to modify C{hm} to make it decomposable.
        @type modify: Boolean
        @param elimination_order: If supplied
        and C{modify=True}, the elimination order to use to make the C{hm}
        decomposable. (If not supplied maximum cardinality search is used to generate an order.)
        @type elimination_order: Sequence
        @raise VariableError: If a variable in C{new_domain_variables}
        already exists with values different from
        its values in C{new_domain_variables};
        Or if C{must_be_new} is set and the variable already exists.
        Or if C{check} is set and a variable in C{variables} is not in the domain
        @raise DecomposabilityError: If C{modify=False} and C{hm} is not decomposable.
        """
        if not isinstance(hm,FR):
            hm = FR(hm,domain,new_domain_variables,must_be_new,check)
        join_forest = ReducedJoinForest(hm._hypergraph,modify,True,elimination_order)
        if modify:
            # A new Factor for each clique of the join forest; each old factor
            # is multiplied into the clique its hyperedge was mapped to.
            fs = {}
            for hyperedge in join_forest:
                fs[hyperedge] = Factor(hyperedge,domain=hm)
            for old_hyperedge, hyperedge in join_forest.trace.items():
                fs[hyperedge] *= hm._factors[old_hyperedge]
            hm._factors = fs
        del join_forest.trace
        # one separator factor per join forest edge, over the cliques' intersection
        seps = {}
        for clique1, clique2 in join_forest._uforest.lines():
            seps[frozenset([clique1,clique2])] = Factor(clique1 & clique2,
                                                        domain=hm)
        hm._hypergraph = join_forest
        hm._separators = seps
        self.__dict__ = hm.__dict__

    def __str__(self):
        """Show the clique factors, the separator factors and the join forest"""
        return 'Cliques:\n%sSeparators:\n%sJoin Forest:\n%s' % (
            RFR.__str__(self),
            FR(self._separators.values()),
            self._hypergraph)

#    def calibrate(self):
#        """Alter a JFR so that the factors associated with both cliques and
#        separators are the appropriate marginal distributions
#        """
#        perfect_sequence = self._hypergraph.perfect_sequence()
#        self.send_messages(perfect_sequence[:])
#        perfect_sequence.reverse()
#        self.send_messages(perfect_sequence)

    def calibrate(self):
        """Alter a JFR so that the factors associated with both cliques and
        separators are the appropriate marginal distributions.

        In each tree in the forest, messages are first passed to a root and then
        out again.
        """
        jf = self._hypergraph.join_forest()
        edges = jf.ordered_edges_towards_root()
        for (frm,to) in edges:
            self.send_message(frm,to)
        for (frm,to) in reversed(edges):
            self.send_message(to,frm)

    def condition(self,condition):
        """Alter a JFR by effecting the restriction on
        variables given by C{condition}

        Both clique and separator factors are restricted.
        This alters the model's domain. Make a copy with C{copy_domain=True}
        if the original domain will be needed
        @param condition: Dictionary of the form {var1:value1,var2:value2..}
        @type condition: Dict
        @return: C{self}
        @rtype: L{JFR}
        """
        self._check_condition(condition)
        # change data in each factor
        for factor in self._factors.values():
            factor.data_restrict(condition)
        # change data in each separator
        for sep in self._separators.values():
            sep.data_restrict(condition)
        # change stored values
        self.change_domain_variables(condition)
        return self

    def copy(self,copy_domain=False):
        """Return a deep copy of a JFR

        @param copy_domain: If true C{self}'s domain is copied, otherwise the copy
        shares C{self}'s domain
        @type copy_domain: Boolean
        @return: A copy of C{self}
        @rtype: L{JFR}
        """
        cp = FR.copy(self,copy_domain)
        seps = {}
        # bug corrected by Jon Ronson Tue Dec 12 17:40:33 GMT 2006
        for hyperedge, sep in self._separators.items():
            sc = sep.copy()
            cp.common_domain(sc)
            seps[hyperedge] = sc
        cp._separators = seps
        return cp

    def gui_calibrate(self,parent):
        """Build a Tkinter interface for stepping through calibration

        Shows the clique factors, the separator factors and the join forest,
        plus a 'Next' button. Each press pops the next clique of a perfect
        sequence and sends its messages; once every clique has sent towards
        the root, the sequence is reversed and messages are sent out again.
        The sending clique is drawn blue and the receiving clique red.

        @param parent: Tkinter widget to build the interface in
        """
        clique_win = Tkinter.Frame(parent)
        clique_win.pack()
        self._clique_disp(clique_win,self._factors)
        sep_win = Tkinter.Frame(parent)
        sep_win.pack()
        self._clique_disp(sep_win,self._separators)
        gc = GraphCanvas(self._hypergraph._uforest,parent,edit=False,pp_vertex=pretty_str_set,
                         colour_user_actions=False)
        gc.pack()
        bottom_win = Tkinter.Frame(parent)
        bottom_win.pack()
        perfect_sequence = self._hypergraph.perfect_sequence()
        cpps = perfect_sequence[:]
        banned = set()
        # one-element list so the nested handler can update the flag
        updone = [False]
        def handler(self=self,cliques=cpps,banned=banned,clique_win=clique_win,
                    sep_win=sep_win,updone=updone,gc=gc):
            if not cliques:
                if updone[0]:
                    print('All messages sent and received')
                    return
                # switch to the outward pass: root is now popped first
                cliques[:] = perfect_sequence
                cliques.reverse()
                banned.clear()
                updone[0] = True
                print('ROOT has received all messages')
                if not cliques:
                    # empty model: nothing to send
                    return
            clique = cliques.pop()
            for nbr in self._hypergraph.clique_neighbours(clique,banned):
                self.send_message(clique,nbr)
                print('Sent from %s to %s' % (clique,nbr))
                for clique2 in self._hypergraph:
                    gc.vertex_config(clique2,fill='black')
                gc.vertex_config(clique,fill='blue')
                gc.vertex_config(nbr,fill='red')
            banned.add(clique)
            self._clique_disp(clique_win,self._factors)
            self._clique_disp(sep_win,self._separators)
        Tkinter.Button(bottom_win,text='Next',command=handler).pack()

    def marginal(self,variables,root=None,perfect_sequence=None):
        """Compute the marginal distribution of C{variables} B{assuming they
        are a subset of some hyperedge (ie clique)}

        Just chooses an appropriate hyperedge as a root and sends messages to it.
        Typically used just after C{self} has been altered so that, at least,
        this root clique has the correct marginal.

        Used in iterative proportional fitting.

        @param variables: The variables for which a marginal is sought.
        An empty collection is contained in every clique.
        @type variables: Iterable
        @param root: If supplied, a clique which is simply B{assumed} to contain
        the variables. If not supplied one is computed.
        @type root: Frozenset
        @param perfect_sequence: If supplied an ordering B{assumed} to be a perfect
        sequence with C{root} as its first element. It is not altered, so it
        can be reused across calls.
        @type perfect_sequence: Sequence
        @return: The marginal of C{variables} computed from the root clique's factor
        @raise ValueError: If no clique contains C{variables}
        """
        variables = frozenset(variables)
        if root is None:
            # Cliques containing every variable. Start from all cliques (so
            # empty variables is handled) and build new sets rather than
            # intersecting in place, so the set returned by star() is never
            # modified.
            containers = set(self._hypergraph)
            for v in variables:
                containers = containers & self._hypergraph.star(v)
            if not containers:
                raise ValueError("No clique contains %s" % (list(variables),))
            root = member(containers)
        if perfect_sequence is None:
            # NOTE(review): JFR defines no perfect_sequence method (elsewhere
            # self._hypergraph.perfect_sequence() is used); confirm an
            # inherited one exists.
            perfect_sequence = self.perfect_sequence(root)
        # send_messages consumes the list it is given, so pass a copy
        self.send_messages(list(perfect_sequence))
        return self._factors[root].copy().marginalise_onto(variables)

    def send_message(self,frm,to):
        """Send a message from a clique to a neighbouring clique

        The receiving clique's factor is multiplied by the sender's marginal on
        the separator divided by the old separator factor, and the separator
        factor is replaced by that marginal.

        @param frm: The clique sending the message
        @type frm: Frozenset (hyperedge)
        @param to: The clique receiving the message
        @type to: Frozenset (hyperedge)
        @raise KeyError: If C{frm} and C{to} are not neighbours in the join
        forest
        """
        frm_marginal = self._factors[frm].sumout(frm - to)
        edge = frozenset([frm,to])
        self._factors[to] *= (frm_marginal/self._separators[edge])
        self._separators[edge] = frm_marginal

    def send_messages(self,cliques,banned=None):
        """Send messages using a fixed ordering of cliques

        C{cliques} acts as a I{stack}. The first clique which is allowed
        to send a message is the I{last} element of the list C{cliques}.
        Thus the root is the first element.
        (This is just a little quicker than starting with the first.)
        A clique can only send a message to a clique 'underneath' it in the stack.

        No message to a clique in C{banned} is allowed to be sent.
        C{cliques} is emptied and every clique in it is added to C{banned}.
        @param cliques: Stack of cliques
        @type cliques: List
        @param banned: Cliques which are banned from receiving messages
        @type banned: set
        """
        if banned is None:
            banned = set()
        # iterative, so the number of cliques is not limited by recursion depth
        while cliques:
            clique = cliques.pop()
            for nbr in self._hypergraph.clique_neighbours(clique,banned):
                self.send_message(clique,nbr)
            banned.add(clique)

    def separator_items(self):
        """Return (separator,separator_factor) pairs for all separators

        The order is arbitrary.

        @return: (separator,separator_factor) pairs for all separators
        @rtype: List
        """
        return self._separators.items()

    def separator_factors(self):
        """Return an iterator over the factors associated with separators in the model

        The order is arbitrary.
        @return: An iterator over the factors associated with separators in the model
        @rtype: Iterator
        """
        return self._separators.itervalues()

    def trace(self):
        """Return the C{trace} attribute of the model's join forest

        NOTE(review): L{__init__} executes C{del join_forest.trace} before
        storing the join forest as C{self._hypergraph}, so unless
        L{Hypergraphs.ReducedJoinForest} supplies C{trace} some other way this
        likely raises AttributeError; confirm.
        """
        return self._hypergraph.trace

    def var_marginal(self,variable):
        """Compute the marginal distribution for C{variable},
        assuming that C{self} is calibrated

        The marginal is taken from the smallest factor containing C{variable}.

        @param variable: The variable whose marginal is sought
        @return: The normalised marginal as a CPT with C{variable} as child
        @rtype: L{CPT}
        @raise KeyError: If no factor contains C{variable}
        """
        smallest_factor = None
        l = len(self._variables)
        for factor in self.factors_containing_variable(variable):
            this_l = len(factor.variables())
            if this_l <= l:
                smallest_factor = factor
                l = this_l
        if smallest_factor is None:
            raise KeyError(variable)
        return CPT(smallest_factor.copy().marginalise_onto([variable]).normalised(),variable)

    def _clique_disp(self,clique_win,factor_dict):
        """Redraw C{clique_win} to show each factor in C{factor_dict},
        with '*' labels between consecutive factors

        @param clique_win: Tkinter container to (re)fill
        @param factor_dict: Mapping whose values are the factors to show
        @type factor_dict: Dictionary
        """
        for child in clique_win.winfo_children():
            child.destroy()
        fgs = []
        for factor in factor_dict.values():
            fgs.append(factor.gui_main(clique_win,edit=False))
            fgs.append(Tkinter.Label(clique_win,text='*'))
        if fgs:
            # drop the trailing '*' (there is none when factor_dict is empty,
            # e.g. the separators of a single-clique forest)
            fgs.pop()
        for widget in fgs:
            widget.pack(side=Tkinter.LEFT)
class DN(SFR):
    """Dependency network representations of probability distributions

    A SFR where the factors consist of a CPT
    for each variable. And where the associated digraph B{need not be acyclic}.

    @ivar _adg: digraph, not necessarily acyclic despite the confusing name (!), giving the DN's structure
    @type _adg: L{DiGraph} object
    """

    def __init__(self,factors=(),domain=None,new_domain_variables=None,
                 must_be_new=False,check=False,dg=None):
        """DN constructor.
        Each factor has its domain expanded to be the domain of the model, if necessary.

        If C{factors} is a sequence of factors then each is merely
        B{assumed} to be a CPT object, but there is no check done.

        @param factors: The CPTs in the dependency network.
        (Alternatively an existing object of class L{FR} (or its subclasses), in
        which case C{self} has identical attributes, except possibly the attribute specifying its
        digraph.)
        @type factors: Sequence, each element of which is a L{Parameters.Factor} or
        L{Variables.SubDomain} object. (Alternatively an object of class L{FR}
        or one of its subclasses.)
        @param domain: A domain for the model.
        If None the internal default domain is used.
        @type domain: L{Variables.Domain} or None
        @param new_domain_variables: A dictionary containing a mapping from any new
        variables to their values. C{domain} is updated with these values
        @type new_domain_variables: Dict or None
        @param must_be_new: Whether domain variables in C{new_domain_variables} have
        to be new
        @type must_be_new: Boolean
        @param check: Whether to check that all variables exist in C{domain}
        @type check: Boolean
        @param dg: The digraph for the DN. This is simply B{assumed} to
        be the correct digraph, no check is made. If C{dg} is not supplied the correct
        digraph is created.
        @type dg: L{Graphs.DiGraph}
        @raise VariableError: If a variable in C{new_domain_variables}
        already exists with values different from
        its values in C{new_domain_variables};
        Or if C{must_be_new} is set and the variable already exists.
        Or if C{check} is set and a variable in C{variables} is not in the domain
        @raise AttributeError: If C{dg} was not supplied and any of the supplied
        factors are not L{Parameters.CPT} objects.
        """
        SFR.__init__(self,factors,domain,new_domain_variables,must_be_new,check)
        if dg is None:
            dg = DiGraph()
            for cpt in self:
                dg.put_family(cpt.child(),cpt.parents())
        self._adg = dg

    def __getitem__(self,key):
        """Return the CPT (not a copy) corresponding to C{key}

        C{key} can be all variables in
        CPT or just the child variable. The former is quicker.
        @param key: Family or child variable
        @type key: Iterable
        @return: The CPT for var
        @rtype: L{CPT}
        @raise KeyError: If there is no corresponding CPT.
        """
        if isinstance(key,str):
            return self._factors[frozenset([key]) | self._adg.parents(key)]
        else:
            return SFR.__getitem__(self,key)

    def __repr__(self):
        """Formal string representation of a DN

        @return: Formal string representation of a DN
        @rtype: String
        """
        cpts = ','.join([cpt.repr_nodomain() for cpt in self._factors.values()])
        dkt = dict([(v,self._domain[v]) for v in self._variables])
        return 'DN([%s],Domain(new_domain_variables=%s),dg=%s)' % (cpts,dkt,repr(self._adg))

    def __str__(self):
        """Print each CPT separately

        Use lexicographical order on child name

        @return: Pretty representation of a L{DN}
        @rtype: String
        """
        by_child = [(cpt.child(),cpt) for cpt in self._factors.values()]
        by_child.sort(key=lambda pair: pair[0])
        return ''.join([str(cpt) + '\n' for child, cpt in by_child])

    def add_cpts(self,cpts):
        """Adds new CPTs to a BN (or DN), if possible

        CPTs are added one at a time; a CPT is added only once its child is
        new to the model and all its parents are already in the model.
        The list C{cpts} is destroyed by this process

        @param cpts: CPTs to add
        @type cpts: list
        @raise AttributeError: If any of the C{cpts} is not a L{Parameters.CPT} object
        @raise ValueError: If by adding these CPTs C{self} would no longer be a BN (or DN).
        """
        while cpts:
            for i, cpt in enumerate(cpts):
                if (cpt.child() not in self._variables
                    and cpt.parents() <= self._variables):
                    self *= cpt
                    for parent in cpt.parents():
                        self._adg.add_arrow(parent,cpt.child())
                    del cpts[i]
                    break
            else:
                raise ValueError('Could not add CPTs: %s' % cpts)

    def bdeu_score(self,data,precision=1.0):
        """
        Return the BDeu score of C{self} (and its component scores) on C{data}.

        @param data: The data
        @type data: L{Parameters.CompactFactor}
        @param precision: Prior precision
        @type precision: Float
        @return: C{(score,variable_scores)} where C{score} is the BDeu score and
        C{variable_scores} is a list of component scores, one for each variable (in order)
        in C{self}.
        @rtype: Tuple
        @raise NameError: If L{gPyC} was not successfully imported
        """
        variable_scores = []
        score = 0.0
        for variable in self.variables():
            new_score = self[variable].get_counts(data).bdeu_score(precision)
            score += new_score
            variable_scores.append(new_score)
        return score, variable_scores

    def bdeu_score_using_other(self,data,other,other_score,other_variable_scores,precision=1.0):
        """
        Return the BDeu score of C{self} (and its component scores) on C{data}
        using the existing BDeu score (and component scores) for C{other}
        computed on C{data} previously.

        Only the components for variables whose parents differ between C{self}
        and C{other} are recomputed.
        Clearly, for this to give the right answer, the existing score must have used the same C{precision}.

        @param data: The data
        @type data: L{Parameters.CompactFactor}
        @param other: Previously scored BN (or DN)
        @type other: L{BN}
        @param other_score: The BDeu score of C{other} on C{data}
        @type other_score: Float
        @param other_variable_scores: For each variable (in order) in C{other}, the
        component of the BDeu score for that variable. It is not altered.
        @type other_variable_scores: List
        @param precision: Prior precision
        @type precision: Float
        @return: C{(score,variable_scores)} where C{score} is the BDeu score and
        C{variable_scores} is a new list of component scores, one for each variable in
        C{self}.
        @rtype: Tuple
        @raise NameError: If L{gPyC} was not successfully imported
        """
        # copy so the caller's list of scores for other is left untouched
        variable_scores = list(other_variable_scores)
        score = other_score
        for i, variable in enumerate(self.variables()):
            if self.parents(variable) != other.parents(variable):
                new_score = self[variable].get_counts(data).bdeu_score(precision)
                score += (new_score - other_variable_scores[i])
                variable_scores[i] = new_score
        return score, variable_scores

    def children(self,variable):
        """Return the children of C{variable}

        @param variable: Variable in the BN (or DN)
        @type variable: Immutable (usually a string)
        @return: The children of C{variable}
        @rtype: set
        @raise KeyError: If C{variable} is not in the BN (or DN).
        """
        return self._adg.children(variable)

    def gibbs_sample(self,evidence=None,burnin=100,iterations=1000,inst=None,order=None):
        """Run Gibbs sampling (a generator)

        Each sweep resamples, in C{order}, every variable not in C{evidence}
        from its CPT given the current values of its parents. C{burnin} sweeps
        are discarded, then the instantiation is yielded after each of
        C{iterations} further sweeps.

        @param evidence: Mapping from observed variables to their values
        @type evidence: Dict or None
        @param burnin: Number of sweeps to discard
        @type burnin: Integer
        @param iterations: Number of samples to yield
        @type iterations: Integer
        @param inst: Initial instantiation, one value per variable, aligned with
        C{order}. If not supplied an arbitrary value of each variable is used.
        It is updated in place.
        @type inst: List or None
        @param order: Order in which variables are sampled (default:
        C{tuple(self._variables)})
        @type order: Sequence or None
        @return: An iterator yielding the same list object after each sweep
        """
        if evidence is None:
            evidence = {}

        parent_indices, cpts, order = self._info_for_sampling_dn(evidence,order)

        if inst is None:
            # inst[i] holds the value of order[i], so build it in that order
            inst = [member(self.values(v)) for v in order]

        uninstantiated = []
        for i, v in enumerate(order):
            if v in evidence:
                inst[i] = evidence[v]
            else:
                uninstantiated.append(i)

        def sweep():
            # resample each unobserved variable in order
            for i in uninstantiated:
                inst[i] = cpts[i].sample(tuple([inst[j] for j in parent_indices[i]]))

        for t in range(burnin):
            sweep()

        for t in range(iterations):
            sweep()
            yield inst

    def parents(self,variable):
        """Return the parents of C{variable}

        @param variable: Variable in the BN (or DN)
        @type variable: Immutable (usually a string)
        @return: The parents of C{variable}
        @rtype: set
        @raise KeyError: If C{variable} is not in the BN (or DN).
        """
        return self._adg.parents(variable)

    def _info_for_sampling_dn(self,evidence=None,order=None):
        """Return C{(parent_indices,cpts,order)} for L{gibbs_sample}

        C{parent_indices[i]} lists the positions in C{order} of the (sorted)
        parents of C{order[i]}, and C{cpts[i]} is its CPT. Samplers are
        initialised only for variables not in C{evidence}.
        """
        if evidence is None:
            evidence = {}
        if order is None:
            order = tuple(self._variables)
        position = dict([(v,i) for i, v in enumerate(order)])
        parent_indices = []
        cpts = []
        for variable in order:
            cpt = self[variable]
            parent_indices.append([position[p] for p in sorted(cpt.parents())])
            if variable not in evidence:
                cpt.initialise_sampler()
            cpts.append(cpt)
        return parent_indices, cpts, order
1641 -class BN(DN):
1642 """Bayesian network representations of probability distributions 1643 1644 A SFR where the factors consist of a CPT 1645 for each variable. And where the associated digraph B{must be acyclic}. 1646 1647 Joint is thus exactly the product of the factors 1648 @ivar _adg: The acyclic digraph giving the BN's structure 1649 @type _adg: L{ADG} object 1650 """ 1651
1652 - def __init__(self,factors=(),domain=None,new_domain_variables=None, 1653 must_be_new=False,check=False,adg=None):
1654 """BN constructor. 1655 Each factor has its domain expanded to be the domain of the model, if necessary. 1656 1657 If C{factors} is a sequence of factors then each is merely 1658 B{assumed} to be a CPT object, but there is no check done. 1659 1660 @param factors: The CPTs in the Bayesian network. 1661 (Alternatively an existing object of class L{FR} (or its subclasses), in 1662 which case C{self} has identical attributes, except possibly the attribute specifying its 1663 C{adg}.) 1664 @type factors: Sequence, each element of which is a L{Parameters.Factor} or 1665 L{Variables.SubDomain} object. (Alternatively an object of class L{FR} 1666 or one of its subclasses.) 1667 @param domain: A domain for the model. 1668 If None the internal default domain is used. 1669 @type domain: L{Variables.Domain} or None 1670 @param new_domain_variables: A dictionary containing a mapping from any new 1671 variables to their values. C{domain} is updated with these values 1672 @type new_domain_variables: Dict or None 1673 @param must_be_new: Whether domain variables in C{new_domain_variables} have 1674 to be new 1675 @type must_be_new: Boolean 1676 @param check: Whether to check that all variables exist in C{domain} 1677 @type check: Boolean 1678 @param adg: The acyclic digraph for the BN. This is simply B{assumed} to 1679 be the correct ADG, no check is made. If C{adg} is not supplied the correct ADG 1680 is created if possible, otherwise a C{DirectedCycleError} is raised. 1681 @type adg: L{Graphs.ADG} 1682 @raise VariableError: If a variable in C{new_domain_variables} 1683 already exists with values different from 1684 its values in C{new_domain_variables}; 1685 Or if C{must_be_new} is set and the variable already exists. 1686 Or if C{check} is set and a variable in C{variables} is not in the domain 1687 @raise AttributeError: If C{adg} was not supplied and any of the supplied 1688 factors are not L{Parameters.CPT} objects. 
1689 @raise DirectedCycleError: If C{adg} was not supplied and there is no ADG 1690 consistent with the CPTs provided. 1691 """ 1692 SFR.__init__(self,factors,domain,new_domain_variables,must_be_new,check) 1693 if adg is None: 1694 adg = ADG() 1695 for cpt in self: 1696 adg.put_family(cpt.child(),cpt.parents()) 1697 self._adg = adg
1698 1699
1700 - def __repr__(self):
1701 """Formal string representation of a BN 1702 1703 @return: Formal string representation of a BN 1704 @rtype: String 1705 """ 1706 cpts = ','.join([cpt.repr_nodomain() for cpt in self._factors.values()]) 1707 dkt = dict([(v,self._domain[v]) for v in self._variables]) 1708 return 'BN([%s],Domain(new_domain_variables=%s),adg=%s)' % (cpts,dkt,repr(self._adg))
1709 1710 1711
1712 - def _add_factor(self,factor):
1713 """Include a factor in a BN 1714 1715 If the factor is a CPT, its child is a new variable and all 1716 its parents are existing variables, then C{self} 1717 remains a BN, otherwise it becomes a L{FR} object 1718 """ 1719 if (isinstance(factor,CPT) and 1720 factor.child() not in self._variables and 1721 factor.parents() <= self._variables): 1722 self._adg.put_family(factor.child(),factor.parents()) 1723 else: 1724 self.__class__ = FR 1725 del self._adg 1726 FR._add_factor(self,factor)
1727
1728 - def adg(self):
1729 """Return the ADG associated with BN 1730 1731 A copy of the BN's ADG is returned. 1732 @return: The ADG associated with BN 1733 @rtype: L{Graphs.ADG} 1734 """ 1735 return self._adg.copy()
1736
1737 - def condition(self,condition,keep_class=False):
1738 """Alter a Bayesian network by effecting the restriction on 1739 variables given by C{condition} 1740 1741 Generally conditioning a BN produces something which is not 1742 a BN, see C{keep_class} argument. 1743 1744 This alters the model's domain. Make a copy with C{copy_domain=True} 1745 if the original domain will be needed 1746 @param condition: Dictionary of the form {var1:values1,var2:values2..} 1747 Each value of this dictionary must be an iterable 1748 @type condition: Dict 1749 @param keep_class: If C{False}, the object will cease to be a 1750 L{BN} object, becoming a {SFR} object, and all its factors will become 1751 L{Parameters.Factor} object. If C{True} remains a BN. 1752 @type keep_class: Boolean 1753 @return: The conditioned model 1754 @rtype: Same as C{self} 1755 @raise KeyError: If a variable is used that is not in the model 1756 @raise ValueError: If a value is used that is not a possible value of 1757 the variable it is attached to 1758 """ 1759 1760 SFR.condition(self,condition,keep_class) 1761 if not keep_class: 1762 self.__class__ = SFR 1763 return self
1764
1765 - def copy(self,copy_domain=False):
1766 """ 1767 Return a copy 1768 1769 @return: A copy of C{self} 1770 @rtype: L{BN} 1771 """ 1772 cp = SFR.copy(self,copy_domain) 1773 cp._adg = self._adg.copy() 1774 return cp
1775
1776 - def eliminate_variable(self,variable,trace=False):
1777 """Alter a Bayesian network by summing out a variable 1778 1779 C{self} will cease to be a L{BN} object unless C{variable} 1780 has no children. 1781 @param variable: The variable to eliminate 1782 @type variable: String 1783 """ 1784 if not self._adg.children(variable): 1785 if trace: 1786 return SFR.eliminate_variable(self,variable,True) 1787 self.remove(variable) 1788 else: 1789 self.__class__ = SFR 1790 return self.eliminate_variable(variable,trace)
1791
1792 - def family_hyperedge(self,child):
1793 """Return the hyperedge corresponding to the CPT with C{child} 1794 as child 1795 1796 @param child: Variable 1797 @type child: Immutable 1798 @return: Hyperedge corresponding to the CPT with C{child} 1799 as child 1800 @rtype: Frozenset 1801 """ 1802 return frozenset([child]) | self._adg.parents(child)
1803
1804 - def _info_for_sampling(self,evidence=None):
1805 order = self.topological_order() 1806 dkt = dict(zip(order,range(len(order)))) 1807 parent_indices = [] 1808 cpts = [] 1809 for variable in order: 1810 cpt = self[variable] 1811 parent_indices.append([dkt[p] for p in sorted(cpt.parents())]) 1812 if evidence is None or variable not in evidence: 1813 cpt.initialise_sampler() 1814 else: 1815 ev_val = evidence[variable] 1816 for i, val in enumerate(sorted(cpt.values(variable))): 1817 if val == ev_val: 1818 break 1819 tmp = {} 1820 itr = cpt.parent_insts_data() 1821 for parent_inst in cpt.parent_insts(): 1822 tmp[tuple(parent_inst)] = itr.next()[i] 1823 cpt = tmp 1824 cpts.append(cpt) 1825 1826 1827 return parent_indices, cpts, order
1828
1829 - def likelihood_weighting(self,evidence,iterations=1000):
1830 1831 parent_indices, cpts, order = self._info_for_sampling(evidence) 1832 1833 indices = [] 1834 uninstantiated = set() 1835 inst = [] 1836 for i, v in enumerate(order): 1837 indices.append(i) 1838 if v in evidence: 1839 inst.append(evidence[v]) 1840 else: 1841 uninstantiated.add(i) 1842 inst.append(None) 1843 1844 t = 0 1845 while t < iterations: 1846 weight = 1 1847 for i in indices: 1848 parent_inst = tuple([inst[j] for j in parent_indices[i]]) 1849 if i in uninstantiated: 1850 inst[i] = cpts[i].sample(parent_inst) 1851 else: 1852 weight *= cpts[i][parent_inst] 1853 yield weight, inst 1854 t += 1
1855
def forward_sample(self,iterations=1000):
    """Generate samples by forward (ancestral) sampling

    Variables are visited in the order returned by
    C{self._info_for_sampling}.  Each one is sampled from its CPT given
    the values already sampled for its parents.  No evidence is used.

    @param iterations: Number of samples to generate
    @type iterations: Int
    @return: Generator of lists, each with one value per variable in the
    sampling order.  A fresh list is yielded each time, so samples may
    safely be stored.
    """
    parent_indices, cpts, order = self._info_for_sampling()

    n = len(order)
    indices = range(n)
    inst = [None] * n
    t = 0
    while t < iterations:
        for i in indices:
            inst[i] = cpts[i].sample(
                tuple([inst[j] for j in parent_indices[i]]))
        # yield a copy: inst is overwritten on the next iteration
        yield list(inst)
        t += 1
1870
def from_bif(self,info):
    """Add CPTs read from a C{.bif} file

    The variables and their values are added to the domain first.  Then
    one L{CPT} is built per child and filled from the probability tables.

    @param info: Output from L{IO.read_bif}: a triple
    C{(varvalues, parents, cpts)}.  C{varvalues} maps each variable to
    its values.  C{parents} maps each child to a tuple of its parents.
    C{cpts[child]} maps each parent instantiation to the probabilities
    of the child's values, in the order of C{varvalues[child]}.
    @type info: Tuple
    """
    varvalues, family_parents, tables = info
    self.add_domain_variables(varvalues)
    built = []
    for child, parents in family_parents.items():
        cpt = CPT(Factor(parents+(child,),domain=self),child=child)
        for parent_inst, probs in tables[child].items():
            assignment = dict(zip(parents,parent_inst))
            for index, prob in enumerate(probs):
                assignment[child] = varvalues[child][index]
                cpt[assignment] = prob
        built.append(cpt)
    self.add_cpts(built)
1890
def from_dnet(self,info):
    """Add CPTs read from a C{.dnet} file

    @param info: Output from L{IO.read_dnet}: either
    C{(dnet_variables, named_cpts)} or
    C{(dnet_variables, named_cpts, coords)}.  C{named_cpts} maps each
    child to a C{(parents, data)} pair.  If C{coords} is present, it is
    passed to the ADG's C{set_vertex_positions}.
    @type info: Tuple
    """
    # Track explicitly whether coordinates were supplied, instead of
    # catching NameError later (which would also hide genuine NameErrors
    # raised inside set_vertex_positions).
    try:
        dnet_variables, named_cpts = info
        has_coords = False
    except ValueError:
        dnet_variables, named_cpts, coords = info
        has_coords = True
    cpts = []
    for variable, (parents,data) in named_cpts.items():
        cpts.append(CPT(
            Factor(parents+[variable],data,self,
                   new_domain_variables=dnet_variables,
                   check=True),variable,cpt_check=True))
    self.add_cpts(cpts)
    if has_coords:
        self._adg.set_vertex_positions(coords)
1913 1914 1915
def remove(self,key):
    """Remove a CPT where C{key} is either the child or a tuple giving
    the family for the CPT.

    C{key} is first checked against the model's variables, so a variable
    is never mistaken for a family.  For example, a string child C{'AB'}
    must not be treated as the hyperedge C{frozenset('AB')} =
    C{{'A','B'}}.  Using the child as key is faster.  If the child has
    children of its own, C{self} is turned into an L{SFR} (its ADG is
    deleted) before the CPT is removed.

    @raise TypeError: If C{key} is neither a variable nor an iterable family
    """
    try:
        is_variable = key in self._variables
    except TypeError:
        # unhashable keys (e.g. a family given as a list) cannot be variables
        is_variable = False
    hyperedge = None
    if not is_variable:
        key_hyperedge = frozenset(key)   # TypeError if key is not iterable
        if key_hyperedge in self._factors:
            hyperedge = key_hyperedge
            # the child is the member whose family is exactly this hyperedge
            for child in hyperedge:
                if hyperedge == self._adg.parents(child) | set([child]):
                    break
    if hyperedge is None:
        child = key
        hyperedge = self.family_hyperedge(child)
    if not self._adg.children(child):
        self._adg.remove_vertex(child)
    else:
        # can do this since a SFR has same _factors dictionary
        self.__class__ = SFR
        del self._adg
    SFR.remove(self,hyperedge)
1940 1941
def sample(self,fobj,samples=100):
    """Write a sample in CSV form to C{fobj}

    Uses forward sampling, so assumes no instantiations.  The output is:
     - one C{variable:value1,value2,...} line per variable;
     - a header line of the sampler's variables;
     - one comma-separated line per sampled instantiation.

    Values are joined with C{','.join}, so they are expected to be strings.

    @param fobj: Writable file
    @type fobj: File
    @param samples: How many samples to produce (no sample lines are
    written if this is zero or negative)
    @type samples: Int
    """
    from gPy.Samplers import BNSampler
    for v in self._variables:
        print >>fobj, '%s:%s' % (v, ','.join(self._domain[v]))
    sampler = BNSampler(self)
    print >>fobj, ','.join(sampler.variables())
    # '> 0' rather than truthiness: a negative count used to loop forever
    while samples > 0:
        print >>fobj, ','.join(sampler.forward_sample())
        samples -= 1
1960
def sample_sqlite(self,dbfilename=':memory:',table='data',samples=100):
    """Write a sample to a sqlite database

    Uses forward sampling, so assumes no instantiations.  Sampling is
    done on a deep copy of C{self} whose variable values are replaced by
    the integers C{0 .. k-1}, so stored values are value indices.
    Identical instantiations are aggregated: each distinct sampled
    instantiation is passed to C{populate} once, with its count appended
    as the last element.

    @param dbfilename: Name of sqlite database.  sqlite treats only the
    exact name C{':memory:'} as an in-memory database, which is the
    default (the previous default C{':memory'} lacked the final colon).
    @type dbfilename: String
    @param table: Name of the table to write to
    @type table: String
    @param samples: How many samples to produce
    @type samples: Int
    @return: sqlite database
    @rtype: L{Parameters.SqliteFactor} object
    """
    from gPy.Samplers import BNSampler
    from gPy.Parameters import SqliteFactor
    sqlfactor = SqliteFactor(self.variables(),dbfilename,table,domain=self)
    newself = self.copy(True) # deep copy
    dom = newself._domain
    # replace values with numbers
    for v, vals in dom.items():
        dom[v] = range(len(vals))
    sampler = BNSampler(newself)
    counts = {}
    # the counter must be decremented: the original loop never terminated
    while samples > 0:
        inst = tuple(sampler.forward_sample())
        counts[inst] = counts.get(inst, 0) + 1
        samples -= 1
    sqlfactor.populate([inst+(count,) for inst,count in counts.items()],
                       sampler.variables())
    return sqlfactor
1992 1993
def topological_order(self):
    """Return the vertices of the associated acyclic digraph in topological order

    Every variable appears after all of its parents.  The ordering is
    whatever the ADG's own C{topological_order} method returns.

    @return: A topological ordering of the vertices
    @rtype: List
    """
    adg = self._adg
    return adg.topological_order()
2002
2003 # def sample(self,times=1): 2004 # """UNIMPLEMENTED 2005 2006 # Simple rejection sampling from the (perhaps conditional) joint 2007 # Generated sample may be smaller than times due to rejections. 2008 # """ 2009 # sample = [] 2010 # count = 0 2011 # save_instantiation = self.copy_instantiation() 2012 # while count < times: 2013 # inst_list = [] 2014 # for v in self.topological_order: 2015 # value = self[v].sample_cpt() 2016 # if value == None: 2017 # break 2018 # self.update_instantiation({v:[value]}) 2019 # inst_list.append(value) 2020 # else: 2021 # sample.append(tuple(inst_list)) 2022 # count += 1 2023 # self.update_instantiation(save_instantiation) 2024 # return sample 2025 2026 # def _gettopological_order(self): 2027 # try: 2028 # return self._topological_order 2029 # except AttributeError: 2030 # self._topological_order = self.adg.topological_order() 2031 # return self._topological_order 2032 # def _deltopological_order(self): del self._topological_order 2033 # _doctopoligical_order = """ 2034 # A topological ordering of the associated acyclic digraph 2035 2036 # Computed on demand and then cached. 2037 # """ 2038 # topological_order = property(fget=_gettopological_order, 2039 # fdel=_deltopological_order, 2040 # doc=_doctopoligical_order) 2041 2042 # self._hypergraph = dm._hypergraph 2043 # self._factors = dm._factors 2044 # join_forest = dm.hypergraph().join_forest() 2045 # TODO!! 2046 # fs = {} 2047 # seps = {} 2048 # for hyperedge in join_forest: 2049 # fs[hyperedge] = Factor(hyperedge) 2050 # for edge in join_forest.edges: 2051 # clique1, clique2 = tuple(edge) 2052 # seps[edge] = Factor(clique1 & clique2) 2053 # destination = join_forest.destination 2054 # for hyperedge, factor in hm.items(): 2055 # fs[destination[hyperedge]] *= factor 2056 # self._factors = fs 2057 # self._hypergraph = join_forest 2058 # self._separators = seps 2059 2060 2061 2062 -class CBN(BN):
def __init__(self,factors=(),domain=None,adg=None):
    """Create a causal Bayesian network

    All arguments are forwarded unchanged, by keyword, to the superclass
    constructor.

    @param factors: The CPTs of the network
    @param domain: Domain to use (or C{None})
    @param adg: Acyclic digraph of the network (or C{None})
    """
    super(CBN, self).__init__(adg=adg, domain=domain, factors=factors)
@staticmethod
def from_bn(bn):
    """Construct a L{CBN} from a L{BN}.

    A fresh L{BN} is created and its instance dictionary is replaced by
    C{bn.__dict__}.  It is then re-classed as a L{CBN` and its ADG is
    converted with C{MutilatedADG.from_adg}.

    NOTE(review): the result shares the very same C{__dict__} object with
    C{bn}.  C{bn._adg} is therefore also replaced by the MutilatedADG, and
    later in-place changes to the CBN (e.g. L{intervene}, which edits
    C{_factors}, C{_hypergraph} and C{_adg}) are visible through C{bn}.
    Confirm this aliasing is intended.
    """
    cbn = BN()
    cbn.__dict__ = bn.__dict__
    cbn.__class__ = CBN
    original_adg = cbn._adg
    cbn._adg = MutilatedADG.from_adg(original_adg)
    return cbn
@staticmethod
def from_adg_data(adg, data, prior=1):
    """Construct a L{CBN} from an L{ADG} and estimates of the parameters from
    some observations.

    One CPT per variable of C{data} is estimated with
    C{data.makeCPT(child, adg.parents(child), force_cpt=True, prior=prior)}.
    The network uses C{data} as its domain and a copy of C{adg}.

    @type adg: L{ADG}
    @type data: L{CausalWorld}
    @param prior: the Dirichlet prior parameter (the same parameter value
    is used for all instances!) Note there may be some problems with
    this method: a B{different} prior is used by the BDeu score. However,
    in practice, for parameter estimation, this prior method seems to be ok.
    I was lazy and it was simple to implement (cb). If prior is zero, then
    the parameters are the maximum likelihood estimation solutions.
    @rtype: L{CBN}
    """
    estimated = [data.makeCPT(variable, adg.parents(variable),
                              force_cpt=True, prior=prior)
                 for variable in data.variables()]
    return CBN(factors=estimated, domain=data, adg=adg.copy())
2093
def _mutilate(self, variables):
    """Cut each of C{variables} off from its parents, giving it a uniform CPT

    For each variable:
     - its CPT is replaced by a uniform distribution over its
       C{self._numvals[variable]} values;
     - its family hyperedge is replaced by the singleton hyperedge;
     - every arrow from one of its parents is removed from the ADG.

    @param variables: The variables to mutilate
    @type variables: Iterable
    """
    for variable in variables:
        child = frozenset([variable])
        n = self._numvals[variable]
        uniform = [1/float(n)] * n
        self._replace_factor(variable,
                             Factor(variables=child,data=uniform,domain=self).makeCPT(variable),
                             allow_hyperedge_change=True)
        # Snapshot the parents: the ADG is modified in the loop below, and
        # parents() may return the ADG's live set.
        parents = frozenset(self._adg.parents(variable))
        self._hypergraph.remove_hyperedge(child | parents)
        for parent in parents:
            self._adg.remove_arrow(parent, variable)
        self._hypergraph.add_hyperedge(child)
2107
def intervene(self, intervention):
    """Perform an intervention on C{self}, in place

    Every intervened variable is first cut off from its parents and given
    a uniform CPT (via C{_mutilate}).  The model is then conditioned on
    C{intervention}, passing C{keep_class=True} to C{condition}.

    @param intervention: A dictionary mapping variables in the L{CBN} to
    the value(s) they are forced to, in the form accepted by
    C{condition}.  L{good_interventions} builds these as singleton
    frozensets, e.g. C{{'A': frozenset(['a1'])}}.
    """
    forced_variables = intervention.keys()
    self._mutilate(forced_variables)
    self.condition(intervention, keep_class=True)
2115
2116 - def _replace_factor(self, child, factor, allow_hyperedge_change=False):
2117 """Replace the factor of C{child} with C{factor}, updating the 2118 hypergraph. Note the ADG must be update separately!""" 2119 hyperedge = self.family_hyperedge(child) 2120 new_hyperedge = frozenset(factor.variables()) 2121 if new_hyperedge != hyperedge and not allow_hyperedge_change: 2122 raise ValueError,'new factor does not have same variables as hyperedge',hyperedge 2123 del self._factors[hyperedge] 2124 self._factors[new_hyperedge] = factor
2125
def estimate_parameters(self, data, prior=1.0):
    """Replace the parameters of C{self} with estimates from C{data}, keeping the
    same structure.

    For every variable, a CPT over its current ADG parents is obtained
    from C{data.makeCPT(..., force_cpt=True, prior=prior)}.  It replaces
    the existing factor without changing the family hyperedge.

    @param data: Source of the estimates; must provide C{makeCPT}
    @param prior: the Dirichlet prior parameter (the same parameter value
    is used for all instances!) Note there may be some problems with
    this method: a B{different} prior is used by the BDeu score. However,
    in practice, for parameter estimation, this prior method seems to be ok.
    I was lazy and it was simple to implement (cb). If prior is zero, then
    the parameters are the maximum likelihood estimation solutions.
    """
    parents_of = self._adg.parents
    for variable in self._variables:
        estimate = data.makeCPT(variable, parents_of(variable),
                                force_cpt=True, prior=prior)
        self._replace_factor(variable, estimate)
2138
def extend_to(self, other):
    """Extend the domains of the variables of C{other} to the values C{other} gives them

    Each factor of C{self} is extended (with C{keep_class=True}) on the
    variables it shares with C{other}.  The parameters corresponding to
    newly introduced values are set to zero.  Afterwards the domain of
    C{self} is changed to match C{other} for every variable of C{other}.
    When conditioning, L{BN} removes elements from the domains of
    variables.  Sometimes it is useful to add the zeroes back, e.g. when
    using an additive binary operator.

    @param other: Domain-like object providing C{variables()} (a set of
    variables) and C{values(variable)}.  Note the current domain must be
    a subset of the new domain.
    """
    for factor in self:
        shared = other.variables() & factor.variables()
        extension = dict((variable, other.values(variable)) for variable in shared)
        factor.data_extend(extension, keep_class=True)

    # factors should have the same domain as self
    for variable in other.variables():
        self.change_domain_variable(variable, other.values(variable))
2156
def good_interventions(self):
    """Obtain a heuristically ``good'' set of interventions from which
    structures can be learnt.

    For every set returned by L{good_sets_forced_variables}, one
    intervention is produced per joint instantiation of that set.  Each
    intervention maps every forced variable to a singleton frozenset
    holding its forced value.

    @return: The interventions, in the form accepted by L{intervene}
    @rtype: List of dictionaries
    """
    result = []
    for variable_set in self.good_sets_forced_variables():
        forced = tuple(variable_set)
        # intervene at these variables, trying every combination of value
        for joint_values in self.insts(forced):
            singletons = [frozenset([value]) for value in joint_values]
            result.append(dict(zip(forced, singletons)))
    return result
2169
def good_sets_forced_variables(self):
    """Return the good independent sets of variables to force. This is an
    approximation since ideally you want the edge which implies the most
    propagation when resolved. The approximation will resolve all edges.
    However some forcing may be unnecessary.

    Works on an L{EssentialGraph} built from the ADG's essential graph.
    Each round resolves implied orientations and drops isolated vertices.
    It then greedily picks the highest fan-out (neighbour count) vertices
    that are not adjacent to an already-picked vertex of the same round,
    and removes the picked vertices from the graph.

    @return: One set of variables per round
    @rtype: List of sets
    """
    interventions = []
    g = EssentialGraph.from_graph(self._adg.essential_graph())
    while g.vertices():
        # propagate any implied orientations
        g.resolve()

        # set of vertices adjacent to force variables
        blanket = set()

        # map from vertex to fanout
        fanout = {}

        # add another set of forced variables to the list
        interventions.append(set())

        # determine the fan-out of each node; iterate over a snapshot since
        # isolated vertices are removed from g inside the loop
        for v in list(g.vertices()):
            f = len(g.neighbours(v))
            if f < 1:
                g.remove_vertex(v)
                continue
            fanout[v] = f

        # order nodes by decreasing fan-out (stable for ties, as before)
        fanout_order = sorted(fanout, key=fanout.__getitem__, reverse=True)

        # intervene at the highest fan-out nodes first
        for v in fanout_order:
            # ... provided they are not adjacent to previous force
            # variables
            if v in blanket:
                continue
            interventions[-1].add(v)
            # update the adjacent nodes
            blanket |= g.neighbours(v)
            # remove this node so that it is not included in the next
            # intervention fan-out calculations
            g.remove_vertex(v)
    return interventions
2216