1 """
2 Factored representations of probability distributions and probability models
3
4 @var _version: Version of this module
5 @type _version: String
6 """
7
8 from Variables import SubDomain
9 from Hypergraphs import Hypergraph, ReducedHypergraph, SimpleHypergraph, ReducedJoinForest
10 from Graphs import ADG, DiGraph, DiForest, MutilatedADG, EssentialGraph
11 from IO import GraphCanvas
12 from Utils import emptyset, pretty_str_set, member
13 from Parameters import Factor, CPT
14 from Data import CompactFactor
15 import Tkinter, operator
16
17 _version = '$Id: Models.py,v 1.14 2008/10/07 09:11:55 jc Exp $'
18
19 -class _AbsM(SubDomain):
20 """Abstract model class
21
22 Defined by a hypergraph and specification of values for the variables.
23 Represents the set of all probability distributions which have a factored representation whose
24 hypergraph is the specified one.
25
26 Since the hypergraph is not required to be simple it is not useful to actually create objects
27 of this class, hence it is an abstract class.
28 The most general model class is that of log-linear models (LLM objects, not yet implemented) where the
29 associated hypergraph is required to be simple.
30
31 @ivar _hypergraph: The hypergraph associated with the model. There is exactly
32 one hyperedge for each factor. The hyperedge for a given factor is just the set of that
33 factor's variables (implemented as a frozenset). If the hypergraph is non-simple, then hyperedges
34 can be repeated.
35 @type _hypergraph: L{Hypergraph} object
36 @ivar _factors: A mapping from each distinct hyperedge in the associated hypergraph to a list of
37 factors having the variables of the hyperedge.
38 @type _factors: Dictionary
39 """
40
def __init__(self,factors=(),domain=None,new_domain_variables=None,
             must_be_new=False,check=False):
    """Build a model from a collection of factors

    One hyperedge (the factor's variable set) is added to the model's
    hypergraph for every factor supplied, and each factor is registered
    under that hyperedge. Afterwards every factor is passed to
    C{common_domain} so that it uses the model's domain.

    @param factors: The factors of the model. If an L{FR} object (or an
    object of one of its subclasses) is given instead, C{self} simply takes
    over that object's attribute dictionary (the two objects then share it)
    and nothing else is done.
    @type factors: Sequence of factor objects, or an L{FR} object
    @param domain: Domain for the model. If C{None} and every factor has
    the very same C{_domain} object, the first factor is handed on as the
    domain; otherwise C{None} is handed on to L{SubDomain.__init__}.
    @type domain: L{Variables.Domain} or None
    @param new_domain_variables: Passed unchanged to L{SubDomain.__init__}
    @type new_domain_variables: Dict or None
    @param must_be_new: Passed unchanged to L{SubDomain.__init__}
    @type must_be_new: Boolean
    @param check: Passed unchanged to L{SubDomain.__init__}
    @type check: Boolean
    """
    if isinstance(factors,FR):
        # Adopt the existing model's state wholesale (shared, not copied).
        self.__dict__ = factors.__dict__
        return

    self._factors = {}
    model_hypergraph = Hypergraph()
    all_variables = set()
    for fac in factors:
        edge = fac.variables()
        all_variables.update(edge)
        model_hypergraph.add_hyperedge(edge)
        self._add_factor_to_dict(edge,fac)

    if domain is None and len(factors) > 0:
        shared = factors[0]._domain
        # Identity (not equality) decides whether the factors agree on a domain.
        if all(fac._domain is shared for fac in factors[1:]):
            domain = factors[0]

    SubDomain.__init__(self,all_variables,domain,new_domain_variables,must_be_new,check)
    for fac in factors:
        self.common_domain(fac)
    self._hypergraph = model_hypergraph
92
94 """Return the factor(s) (not a copy) whose variables are C{hyperedge}
95
96 @param hyperedge: The variables of the sought factor
97 @type hyperedge: Iterable
98 @return: The factor(s)
99 @rtype: A L{Parameters.Factor} object (for simple) or a list of such objects (non-simple)
100 @raise KeyError: If there is no factor with these variables
101 """
102 return self._factors[frozenset(hyperedge)]
103
105 """Return an iterator over the factors in the model
106
107 To allow C{for factor in model: ...} constructions.
108 @return: An iterator over the factors in the model
109 @rtype: Iterator
110 """
111 fs = []
112 for factorlist in self._factors.values():
113 fs.extend(factorlist)
114 return iter(fs)
115
117 """Return the number of factors in the model
118
119 @return: The number of factors in the model
120 @rtype: Int
121 """
122 return len(self._hypergraph)
123
125 return 'FR(%s,None,%s)' % (self._factors.values(),
126 self._domain)
127
129 """Set the factor whose variables are C{hyperedge} to factor
130
131 @param hyperedge: The variables of the factor
132 @type hyperedge: Iterable
133 @param factor: Factor which will replace existing factor(s)
134 @type factor: L{Parameters.Factor}
135 @raise KeyError: If there is no factor with these variables
136 """
137 hyperedge = frozenset(hyperedge)
138 if hyperedge != factor.variables():
139 raise ValueError('Can only replace factors with other factors with the same variables: %s, %s' % (hyperedge, factor.variables()))
140 self._put_factor_to_dict(hyperedge,factor)
141
142
144 """Print each factor separately
145
146 Use lexicographical ordering on factor variables
147
148 @return: Pretty representation of a L{FR}
149 @rtype: String
150 """
151 out = ''
152 tmp = []
153 for hyperedge,factor in self.items():
154 l = list(hyperedge)
155 l.sort()
156 tmp.append((l,factor))
157 tmp.sort()
158 for l, factor in tmp:
159 out += str(factor)
160 out += '\n'
161 return out
162
164 """Return the result of dividing a hierarchical model by a scalar
165
166 A randomly chosen factor is divided by the scalar
167
168 The returned value shares the same domain as C{self}. To avoid this make a deep copy and do
169 in-place division.
170
171 @param other: The scalar
172 @type other: float or int
173 """
174 return self.copy().__idiv__(other)
175
176
178 """Divide a hierarchical model by a scalar
179
180 A randomly chosen factor is divided by the scalar
181
182 @param other: The scalar
183 @type other: float or int
184 """
185 for factor in self:
186 factor /= other
187 break
188 return self
189
190
192 """Multiply a hierarchical model by a factor, scalar or another L{FR}
193
194 @param other: The factor or L{FR} object to be multiplied in
195 @type other: L{Factor} or L{FR} object
196 """
197 if isinstance(other,Factor):
198 self._add_factor(other)
199 elif isinstance(other,FR):
200 for factor in other:
201 self._add_factor(factor)
202 else:
203 for factor in self:
204 factor *= other
205 break
206 return self
207
209 """Return the result of multiplying a hierarchical model by a factor, scalar or another L{FR}
210
211 The returned value shares the same domain as C{self}. To avoid this make a deep copy and do
212 in-place multiplication.
213
214 @param other: The factor or L{FR} object to be multiplied in
215 @type other: L{Factor} or L{FR} object
216 """
217 return self.copy().__imul__(other)
218
220 """Add an ident factor for a hyperedge to a hierarchical model
221
222 Typically C{hyperedge} will come from a L{Hypergraph} model.
223
224 An ident factor maps all instantiations to 1.0
225 @param hyperedge: The variables for the factor
226 @type hyperedge: Iterable
227 """
228 hyperedge = frozenset(hyperedge)
229 self._hypergraph.add_hyperedge(hyperedge)
230 self._add_factor_to_dict(hyperedge,Factor(hyperedge))
231 self._variables = frozenset(self._hypergraph.vertices())
232
def condition(self,condition,keep_class=False):
    """Restrict the model in place to the values given by C{condition}

    The condition is first validated with C{_check_condition}; then every
    factor is restricted via its C{data_restrict} method and finally the
    model's domain variables are changed with C{change_domain_variables}.
    Because the domain is altered, work on a copy (made with
    C{copy_domain=True}) if the original domain is still needed.

    @param condition: Dictionary of the form {var1:values1,var2:values2..}
    where each value is an iterable of permitted values
    @type condition: Dict
    @param keep_class: Handed unchanged to each factor's C{data_restrict}
    @type keep_class: Boolean
    @return: C{self}, after conditioning
    @rtype: Same as C{self}
    @raise KeyError: If a variable is used that is not in the model
    @raise ValueError: If a value is used that is not a possible value of
    the variable it is attached to
    """
    restriction = condition
    self._check_condition(restriction)
    for each_factor in self:
        each_factor.data_restrict(restriction,keep_class)
    self.change_domain_variables(restriction)
    return self
257
def copy(self,copy_domain=False):
    """Return a deep copy of the model

    The hypergraph and every factor are copied; each copied factor is
    passed to the new model's C{common_domain}.

    @param copy_domain: If true C{self}'s domain is copied, otherwise the
    copy shares C{self}'s domain
    @type copy_domain: Boolean
    @return: A copy of C{self}
    @rtype: Same type as C{self}
    """
    duplicate = SubDomain.copy(self,copy_domain)
    duplicate.__class__ = self.__class__
    duplicate._hypergraph = self._hypergraph.copy()

    def clone(original):
        # Copy one factor and tie it to the duplicate's domain.
        twin = original.copy()
        duplicate.common_domain(twin)
        return twin

    copied = {}
    for edge, group in self._factors.items():
        copied[edge] = [clone(original) for original in group]
    duplicate._factors = copied
    return duplicate
280
def cpt(self,child,parents=()):
    """Return the conditional probability table of C{child} given C{parents}

    A copy of the model is marginalised onto C{parents} plus C{child}, the
    remaining factors are multiplied together and the product is wrapped
    in a L{Parameters.CPT} (with C{cpt_force=True}). C{self} is not altered.

    @param child: Child of the CPT
    @type child: String
    @param parents: Parents of C{child} in the CPT
    @type parents: Sequence
    @return: The specified CPT
    @rtype: L{Parameters.CPT}
    """
    scratch = self.copy()
    scratch.marginal(tuple(parents) + (child,))
    product = 1
    for remaining_factor in scratch:
        product *= remaining_factor
    return CPT(product,child,cpt_force=True)
300
301
303 """Alter a factored representation by summing out a variable
304
305 Removes one or more factors.
306 @param variable: The variable to eliminate
307 @type variable: String
308 @raise KeyError: If C{variable} is not in the model
309 """
310 hyperedges = self._hypergraph.star(variable)
311 if variable in self._instd:
312 for hyperedge in hyperedges:
313
314 nf = self.factor(hyperedge).drop_variable(variable)
315
316 self.remove(hyperedge)
317
318 self *= nf
319 if trace:
320 return None, None, hyperedges
321 else:
322 return
323 prod_factor = 1
324 for hyperedge in hyperedges:
325 prod_factor *= self.factor(hyperedge)
326 self.remove(hyperedge)
327 message = prod_factor.sumout([variable])
328 self *= message
329 if trace:
330 return prod_factor.variables(), message.variables(), hyperedges
331
333 """Return the factor produced by multiplying all factors with variables C{hyperedge}
334
335 @param hyperedge: Set of variables
336 @type hyperedge: Iterable
337 @return: Product of all factors with variables C{hyperedge}
338 @rtype: L{Parameters.Factor} object
339 @raise KeyError: if no factor has C{hyperedge} as variables.
340 """
341 factors = self._factors[frozenset(hyperedge)]
342 prod = factors[0]
343 for f in factors[1:]:
344 prod *= f
345 return prod
346
348 """
349 Return a list of the factors in the model
350
351 @return: A list of the factors in the model
352 @rtype: List
353 """
354 fs = []
355 for factorlist in self._factors.values():
356 fs.extend(factorlist)
357 return fs
358
360 """Return a list of factors containing C{variable}
361
362 If non-simple, then a list of lists of factors is returned
363
364 @param variable: A variable
365 @type variable: Immutable (usually string)
366 @return: A list of factors containing C{variable}
367 @rtype: List
368 """
369 fs = []
370 for hyperedge in self._hypergraph[variable]:
371 fs.extend(self._factors[hyperedge])
372 return fs
373
374
376 """Return sequence of C{factor.variables(),factor} pairs
377 for each factor in the model.
378
379 @return: Sequence of C{factor.variables(),factor} pairs
380 for each factor in the model
381 @rtype: List
382 """
383 itms = []
384 for hyperedge, factorlist in self._factors.items():
385 for factor in factorlist:
386 itms.append((hyperedge,factor))
387 return itms
388
390 """Display a GUI widget for displaying a model
391
392 @param parent: A widget into which the GUI is placed.
393 @type parent: Some suitable Tk object.
394 @param colours: Mapping from hyperedges to colours
395 @type colours: Dictionary
396 """
397 if colours is None:
398 colours = {}
399 gui = Tkinter.Frame(parent)
400 gui.pack()
401 top = Tkinter.Frame(gui)
402 top.pack()
403 fgs = []
404 for hyperedge, factor in self.items():
405 fgs.append(factor.gui_main(top,edit=False,bg=colours.get(hyperedge,'grey')))
406 fgs.append(Tkinter.Label(top,text='*'))
407 fgs.pop()
408 for widget in fgs:
409 widget.pack(side=Tkinter.LEFT)
410 bottom = Tkinter.Frame(gui)
411 bottom.pack()
412 for (txt,cmd) in [('Done',gui.destroy),
413 ('Quit', parent.destroy)]:
414 button = Tkinter.Button(bottom,text=txt,command=cmd)
415 button.bind('<Return>', lambda event: cmd())
416 button.pack(side=Tkinter.LEFT)
417
419 """
420 Return (a copy of) the hypergraph associated with the model
421
422 @return: (A copy of) the hypergraph associated with the model
423 @rtype: L{Hypergraph}
424 """
425 return self._hypergraph.copy()
426
428 """
429 Increment C{self} with counts directly from C{rawdata}
430
431 OK for parameter fitting not for structure learning
432 @param rawdata: A tuple like that returned by L{IO.read_csv}.
433 @type rawdata: Tuple
434 @raise IndexError: If C{self} has a variable missing from C{rawdata}
435 """
436 variables,records = rawdata[2:]
437 factor_variables_info = []
438 factors = self.factors()
439 for factor in factors:
440 factor_variables_info.append(factor.get_variables_info(variables))
441 for record in records:
442 for i, factor in enumerate(factors):
443 factor.inc_from_record(factor_variables_info[i],record)
444
446 """Return the interaction graph for a model
447
448 The interaction graph contains an edge for any pair of variables
449 which are members of a common factor.
450 @return: The interaction graph
451 @rtype: L{Graphs.UGraph}
452 """
453 return self._hypergraph.two_section()
454
456 """Whether the model is simple
457
458 @return: Whether the model is simple
459 @rtype: Boolean
460 """
461 return self._hypergraph.is_simple()
462
464 """Return a decomposable model using an elimination ordering
465
466 If no C{elimination_ordering} is given
467 L{Hypergraphs.ReducedHypergraph.maximum_cardinality_search} is used to
468 provide one.
469
470 Returned object shares many attributes with C{self}. Use a copy of self,
471 if it still needed as an independent object.
472
473 @param elimination_ordering: Order in which to eliminate variables
474 @type elimination_ordering: Sequence
475 @return: A decomposable model
476 @rtype: L{DFR}
477 """
478
479 dg, destination = self.hypergraph().make_decomposable2(elimination_ordering)
480 fs = {}
481 for new_hyperedge in dg:
482 fs[new_hyperedge] = Factor(new_hyperedge,domain=self)
483 for old_hyperedge, new_hyperedge in destination.items():
484 fs[new_hyperedge] *= self.factor(old_hyperedge)
485 dm = DFR()
486 dm.__dict__ = self.__dict__
487 dm._hypergraph = dg
488 dm._factors = fs
489 return dm
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
515 """Alter a model to represent the marginal distribution
516 on C{variables}
517
518 Marginal only represented up to normalisation
519
520 If C{self} is a BN and C{variables} is an ancestral set then C{self} remains a BN otherwise
521 it becomes a general hierarchical model.
522 """
523 self.sumout(self.variables().difference(variables))
524
526 """Return the marginal distribution
527 on C{variables} as a L{Parameters.Factor}
528
529 Marginal is properly normalised
530
531 """
532 cp = self.copy()
533 cp.sumout(self.variables().difference(variables))
534
535
536 return reduce(operator.mul, cp).normalised()
537
539 """Alter a factored representation by summing out variables
540
541 TODO: If we just want a marginal, no need to bother altering self
542
543 @param variables: Ordered variables to eliminate
544 @type variables: Sequence type
545 @param naive: If C{True} the variables are summed out in the order given
546 by C{variables}. If C{False} conditional independence and the number of factors
547 a variable is in is used to construct a more intelligent order.
548 @type naive: Boolean
549 @return: The altered C{self}
550 @rtype: Class of C{self}
551 """
552 if not naive:
553
554
555 insted = self._instd
556 remaining = self.variables() - set(variables)
557 hypergraph = self._hypergraph
558 reachable = set(hypergraph.reachable(remaining-insted,insted))
559 reachable.update(remaining)
560 for hyperedge in hypergraph.hyperedges():
561 if not (hyperedge & reachable):
562 self.remove(hyperedge)
563
564
565 variables = self.variables().intersection(variables)
566
567
568 def var_sort(v):
569 return self.num_factors_containing_variable(v)
570 variables = sorted(variables,key=var_sort)
571
572 for variable in variables:
573 self.eliminate_variable(variable)
574 return self
575
def makeDN(self,allow_dummies=False):
    """Make a dependency network from the model

    One CPT is produced per model variable by calling
    C{markov_blanket_cpt}; the resulting list is handed to C{DN}.

    @param allow_dummies: Passed unchanged to C{markov_blanket_cpt}
    @type allow_dummies: Boolean
    @return: The dependency network built from those CPTs
    """
    cpts = []
    for variable in self._variables:
        cpts.append(self.markov_blanket_cpt(variable,allow_dummies))
    return DN(cpts)
581
583 """Return the Markov blanket for C{variable}
584
585 @param variable: Variable in the model
586 @type variable: Immutable
587 @return: Markov blanket for C{variable}
588 @rtype: Set
589 @raise KeyError: If C{variable} is not in the model
590 """
591 return self._hypergraph.neighbours(variable)
592
594 """Create a CPT for the distribution of C{variable}
595 conditional on its Markov blanket
596
597 @param variable: variable for which the conditional distribution is required
598 @type variable: Immutable
599 """
600 factor = 1
601 for hyperedge in self._hypergraph.star(variable):
602 factor *= self.factor(hyperedge)
603 return CPT(factor,variable,cpt_force=True,allow_dummies=allow_dummies)
604
606 """Return the number of distinct factors containing a given variable
607 """
608 return self._hypergraph.star_size(variable)
609
611 """Remove a factor or factors from a hierarchical model
612
613 All factors with variables C{hyperedge} will be removed
614
615 @param hyperedge: The variables of the factor
616 @type hyperedge: Iterable
617 @raise KeyError: If no factor with these variables exists
618 """
619 hyperedge = frozenset(hyperedge)
620 del self._factors[hyperedge]
621 self._hypergraph.remove_hyperedge(hyperedge)
622 self._variables = frozenset(self._hypergraph.vertices())
623
625 """Reduce the model, returning any distinct redundant hyperedges
626
627 Any factor all of whose variables are contained in another is removed
628 @return: Redundant hyperedges
629 @rtype: List
630 """
631 self.simplify()
632 reds = self._hypergraph.redundant_hyperedges()
633 for hyperedge, supersets in reds.items():
634 absorbing_factor = self.some_factor(supersets.pop())
635 absorbing_factor *= self.factor(hyperedge)
636 self.remove(hyperedge)
637 reds = reds.keys()
638 if emptyset in self._hypergraph:
639 reds.append(emptyset)
640 self.remove_hyperedge(emptyset)
641 return reds
642
644 """Test whether a model is reduced
645
646 @return: Whether a model is reduced
647 @rtype: Boolean
648 """
649 return self._hypergraph.is_reduced()
650
652 """Simplify the model
653
654 so that there are no factors with the same variable sets
655 """
656 for hyperedge, factorlist in self._factors.items():
657 nfl = factorlist[0]
658 for f in factorlist[1:]:
659 nfl *= f
660 self._factors[hyperedge] = [nfl]
661
def some_factor(self,hyperedge):
    """Return an arbitrary factor with variables C{hyperedge}

    The first factor stored for the hyperedge is returned (not a copy).

    @param hyperedge: Variable set
    @type hyperedge: Iterable
    @return: Factor
    @rtype: L{Parameters.Factor} object
    @raise KeyError: If no factor with C{hyperedge} as variables exists
    """
    # The lookup result must be returned; previously it was computed and
    # discarded, so callers always received None.
    return self._factors[frozenset(hyperedge)][0]
672
def marginalise_away(self,variables):
    """Sum out (marginalise away) variables

    The variable contained in the fewest hyperedges (according to the
    hypergraph's C{star_size}) is eliminated first; a variable found in
    exactly one hyperedge is taken immediately. The remaining variables
    are then handed to C{sumout}.

    If C{variables} is a set it is altered (the first eliminated variable
    is removed from it).

    @param variables: Variables to sum out
    @type variables: Iterable
    @return: The marginal model, i.e. C{self} (also when C{variables} is empty)
    @rtype: Same as C{self}
    """
    if not variables:
        # Nothing to eliminate: still return the model, as documented.
        return self
    fewest = len(self._hypergraph)
    for variable in variables:
        n = self._hypergraph.star_size(variable)
        if n == 1:
            # Cannot do better than a single hyperedge; stop searching.
            best = variable
            break
        if n <= fewest:
            # '<=' means that among ties the last candidate wins.
            fewest = n
            best = variable
    self.eliminate_variable(best)
    if not isinstance(variables,set):
        variables = set(variables)
    variables.remove(best)
    self.sumout(variables)
    return self
701
702 variable_elimination = marginalise_away
703
705 """As L{variable_elimination} but also return the corresponding
706 L{Graphs.DiForest} object
707
708 If all variables are summed out then the relevant scalar is returned
709 @param variables: Ordered variables to eliminate
710 @type variables: Sequence type
711 @return: C{None} usually. The relevant scalar if all variables are
712 summed out
713 @rtype: C{None} or C{Float}
714 """
715 junction_forest = DiForest()
716 message_produced = {}
717 for variable in variables:
718 cluster, new_message, messages_used = self.eliminate_variable(variable,True)
719 junction_forest.add_vertex(cluster)
720 for old_cluster, old_message in message_produced.items():
721 if old_message in messages_used:
722 junction_forest.add_arrow(old_cluster,cluster)
723 del message_produced[old_cluster]
724 message_produced[cluster] = new_message
725 return junction_forest
726
727
729 """Return the sum of values associated with each full joint instantiation
730
731 This is the partition function. Result is generally not 1.0 since models may
732 represent distributions only up to normalisation. An expensive operation.
733 @return: The sum of values associated with each full joint instantiation
734 @rtype: Float
735 @todo: Use sensible ordering of variables to eliminate
736 """
737 tmp = self.copy()
738 tmp.variable_elimination(tmp._variables)
739 return tmp.factor(emptyset).z()
740
742 """Set all values in all factors to zero
743 """
744 for factor in self:
745 factor.zero()
746
760
762 """Add a factor to the _factors dictionary
763
764 @param hyperedge: The hyperedge for the factor
765 @type hyperedge: Frozenset
766 @param factor: The factor to be added
767 @type factor: L{Factor}
768 """
769 try:
770 self._factors[hyperedge].append(factor)
771 except KeyError:
772 self._factors[hyperedge] = [factor]
773
775 for variable, values in condition.items():
776 if variable not in self._variables:
777 raise KeyError('%s not a variable of this model' % variable)
778 for value in values:
779 if value not in self._domain[variable]:
780 raise ValueError(
781 "%s has values %s. '%s' is not one of them" %
782 (variable, tuple(self._domain[variable]), value))
783
785 """Put a factor in the _factors dictionary
786
787 Overwriting any previous entry
788
789 @param hyperedge: The hyperedge for the factor
790 @type hyperedge: Frozenset
791 @param factor: The factor to be put
792 @type factor: L{Factor}
793 """
794 self._factors[hyperedge] = [factor]
795
797 """Factored representations of probability distributions
798
799 Probability distributions represented by a set of factors.
800 The distribution is equal to the product of these factors (up to
801 normalisation).
802 Each factor is a L{Parameters.Factor}.
803 """
804 pass
805
807 """Factored representations of probability distributions (associated hypergraph is simple)
808
809 Probability distributions represented by a set of factors.
810 The distribution is equal to the product of these factors.
811 Each factor is a L{Parameters.Factor}.
812
813 No two factors have the same variables.
814 If factors were (A,B), (A,B), (B,C) then would NOT be simple.
815
816 @ivar _hypergraph: The hypergraph associated with the model. There is exactly
817 one hyperedge for each factor. The hyperedge for a given factor is just the set of that
818 factor's variables (implemented as a frozenset).
819 @type _hypergraph: L{SimpleHypergraph} object
820 @ivar _factors: A mapping from each hyperedge in the associated hypergraph to its associated factor,
821 who will have the hyperedge as its variables.
822 @type _factors: Dictionary
823 """
824 - def __init__(self,factors=(),domain=None,new_domain_variables=None,
825 must_be_new=False,check=False):
828
830 """Return an iterator over the factors in the model
831
832 To allow C{for factor in model: ...} constructions.
833 @return: An iterator over the factors in the model
834 @rtype: Iterator
835 """
836 return self._factors.itervalues()
837
839 """Multiply a simple hierarchical model by a factor
840
841 If a factor with the same variables already exists,
842 then this existing factor is multiplied by C{factor} to maintain
843 simplicity.
844
845 @param factor: The factor to be multiplied in
846 @type factor: L{Factor}
847 """
848 try:
849 self._factors[factor.variables()] *= factor
850 except KeyError:
851 FR._add_factor(self,factor)
852
853
855 """Add a factor to the _factors dictionary
856
857 @param hyperedge: The hyperedge for the factor
858 @type hyperedge: Frozenset
859 @param factor: The factor to be added
860 @type factor: L{Factor}
861 """
862 self._factors[hyperedge] = factor
863
864 _put_factor_to_dict = _add_factor_to_dict
865
866
def copy(self,copy_domain=False):
    """Return a deep copy of the model

    Here each hyperedge maps to a single factor (not a list), so one
    factor copy is made per hyperedge. Each copied factor is passed to the
    new model's C{common_domain}.

    @param copy_domain: If true C{self}'s domain is copied, otherwise the
    copy shares C{self}'s domain
    @type copy_domain: Boolean
    @return: A copy of C{self}
    @rtype: Same type as C{self}
    """
    twin = SubDomain.copy(self,copy_domain)
    twin.__class__ = self.__class__
    twin._hypergraph = self._hypergraph.copy()
    replicas = {}
    for edge in self._factors:
        replica = self._factors[edge].copy()
        twin.common_domain(replica)
        replicas[edge] = replica
    twin._factors = replicas
    return twin
886
888 return self._factors[frozenset(hyperedge)]
889
891 return self._factors.values()
892
894 return [self._factors[hyperedge] for
895 hyperedge in self._hypergraph[variable]]
896
899
901 return self._factors.items()
902
905
907 return self._factors[frozenset(hyperedge)]
908
910 """Factored representations of probability distributions (associated hypergraph is graphical)
911
912 Factors are defined exactly on the cliques of its interaction graph.
913 If factors were (A,B), (A,C), (B,C) then would NOT be graphical.
914
915 @todo: Actually implement methods specific to this class. At present L{GFR} objects
916 behave identically to L{FR} objects so this class only useful as a label I{claiming}
917 that a hierarchical model is graphical.
918 """
919 pass
920
922 """Factored representations of probability distributions (associated hypergraph is reduced)
923
924 A representation with no redundant factors.
925 If factors were (A), (A,B) then would NOT be reduced.
926
927 @todo: Actually implement methods specific to this class. At present L{RFR} objects
928 behave identically to L{SFR} objects so this class only useful as a label I{claiming}
929 that a hierarchical model is reduced.
930 """
931
932 - def __init__(self,hm=(),check=True,modify=False):
933 """Construct RFR from existing hierarchical model"""
934 if not isinstance(hm,FR):
935 hm = SFR(hm)
936 if isinstance(hm,RFR):
937 check = False
938 modify = False
939 rhg = ReducedHypergraph(hm._hypergraph,check,modify,trace=True)
940 if modify:
941 for redund, supersets in rhg.trace.items():
942 superset = member(supersets)
943 hm._factors[superset] *= hm._factors[redund]
944 del hm._factors[redund]
945 del rhg.trace
946 if isinstance(hm,BN):
947 del hm._adg
948 self.__dict__ = hm.__dict__
949
def _add_factor(self,factor):
    """Multiply a reduced factored representation by a factor

    If a hyperedge containing all of C{factor}'s variables is met before
    any hyperedge contained in them, C{factor} is multiplied into the
    factor of that hyperedge and nothing else changes. Otherwise C{factor}
    is added via L{SFR._add_factor}, and every existing factor whose
    variables are contained in C{factor}'s variables is multiplied into
    the factor stored for C{factor}'s variables and then removed.

    @param factor: The factor to be multiplied in
    @type factor: L{Factor}
    @todo: implement properly
    """
    hyperedge = factor.variables()
    smallerhes = []
    for he in self._hypergraph:
        if not smallerhes and he >= hyperedge:
            self._factors[he] *= factor
            return
        elif he <= hyperedge:
            smallerhes.append(he)
    SFR._add_factor(self,factor)
    for he in smallerhes:
        # The attribute is '_factors'; the former '_factor' raised
        # AttributeError whenever a smaller factor had to be absorbed.
        self._factors[hyperedge] *= self._factors[he]
        self.remove(he)
972
def ipf(self,marginals,epsilon=0.001):
    """Iterative proportional fitting

    Sweeps repeatedly over the factors of C{self}. For each factor the
    model's current marginal on the factor's hyperedge (obtained from
    C{marginal_factor}) is compared with the desired one and the factor is
    multiplied in place by the ratio desired/current. Sweeping stops after
    the first sweep in which no current marginal differs (as judged by its
    C{differ} method and C{epsilon}) from the desired one. The existing
    factors act only as a starting point.

    Each factor in C{self} must have exactly one corresponding marginal
    in C{marginals}. Typically these are normalised counts from a data set.

    @param marginals: Marginals to fit
    @type marginals: Dictionary mapping hyperedges to marginals
    @param epsilon: Convergence criterion
    @type epsilon: Float
    """
    while True:
        all_close = True
        for edge, fitted in self._factors.items():
            current = self.marginal_factor(edge)
            target = marginals[edge]
            fitted *= target / current
            # Once one mismatch has been seen in this sweep no further
            # comparisons are made, only updates.
            if all_close and current.differ(target,epsilon):
                all_close = False
        if all_close:
            break
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
class DFR(GFR):
    """Factored representations of probability distributions whose
    associated hypergraph is decomposable

    The interaction graph is decomposable (triangulated).

    @note: This class defines no methods of its own, so it serves only as a
    label I{claiming} that a hierarchical model is decomposable. Methods
    that exploit decomposability live on L{JFR} objects, which are (reduced)
    decomposable models with a particular join forest specified.
    """
1037
1038
class RDFR(RFR,DFR):
    """Factored representations of probability distributions whose
    associated hypergraph is both decomposable and reduced

    The interaction graph is decomposable (triangulated).

    @note: This class defines no methods of its own, so it serves only as a
    label I{claiming} that a hierarchical model is decomposable and reduced.
    Methods that exploit this live on L{JFR} objects, which are reduced
    decomposable models with a particular join forest specified.
    """
1050
1051
1052 -class JFR(RDFR):
1053 """Join forest representations of probability distributions
1054
1055 A reduced, decomposable model where the factors are clique potentials and
1056 separators, related by a join forest.
1057
1058 @ivar _hypergraph: The join forest associated with the model.
1059 @type _hypergraph: L{Hypergraphs.ReducedJoinForest}
1060 @ivar _factors: Maps each clique (node of the join_forest) to its associated factor
1061 @type _factors: Dictionary
1062 @ivar _separators: Maps each separator (edge of the join_forest) to its associated factor
1063 @type _separators: Dictionary
1064 """
1065
def __init__(self, hm=(), domain=None, new_domain_variables=None,
             must_be_new=False, check=False, modify=False,
             elimination_order=None):
    """Build a join forest model from a hierarchical model (or from factors)

    If C{hm} is an L{FR} object it is altered in place and C{self} ends up
    sharing its attribute dictionary, so typically a copy of an existing
    model is passed in.

    @param hm: Hierarchical model or sequence of factors
    @type hm: L{FR} or sequence
    @param domain: B{Only used if C{hm} is not an L{FR} object.} A domain
    for the model. If None the internal default domain is used.
    @type domain: L{Variables.Domain} or None
    @param new_domain_variables: B{Only used if C{hm} is not an L{FR}
    object.} A dictionary mapping any new variables to their values.
    C{domain} is updated with these values.
    @type new_domain_variables: Dict or None
    @param must_be_new: B{Only used if C{hm} is not an L{FR} object.}
    Whether the variables in C{new_domain_variables} have to be new.
    @type must_be_new: Boolean
    @param check: B{Only used if C{hm} is not an L{FR} object.} Whether to
    check that all variables exist in C{domain}.
    @type check: Boolean
    @param modify: Whether to modify C{hm} to make it decomposable.
    @type modify: Boolean
    @param elimination_order: If supplied and C{modify=True}, the
    elimination order used to make C{hm} decomposable. (If not supplied
    maximum cardinality search is used to generate an order.)
    @type elimination_order: Sequence
    @raise VariableError: If a variable in C{new_domain_variables} already
    exists with different values; or if C{must_be_new} is set and the
    variable already exists; or if C{check} is set and a variable is not in
    the domain.
    @raise DecomposabilityError: If C{modify=False} and C{hm} is not
    decomposable.
    """
    if not isinstance(hm, FR):
        hm = FR(hm, domain, new_domain_variables, must_be_new, check)
    # NOTE(review): the meaning of the positional True is defined by
    # ReducedJoinForest, which is not visible here.
    forest = ReducedJoinForest(hm._hypergraph, modify, True, elimination_order)
    if modify:
        # Start every clique of the forest with a newly constructed factor,
        # then multiply each original factor into the clique that
        # forest.trace maps its hyperedge to.
        potentials = dict([(clique, Factor(clique, domain=hm))
                           for clique in forest])
        for source_edge, clique in forest.trace.items():
            potentials[clique] *= hm._factors[source_edge]
        hm._factors = potentials
        del forest.trace
    # One separator factor per line of the forest, over the variables the
    # two joined cliques have in common.
    separators = dict([(frozenset([left, right]),
                        Factor(left & right, domain=hm))
                       for left, right in forest._uforest.lines()])
    hm._hypergraph = forest
    hm._separators = separators
    # self takes over hm's state wholesale (shared dictionary, not a copy).
    self.__dict__ = hm.__dict__
1119
def __str__(self):
    """Return a printable description of the join forest model

    The text lists the clique factors (as printed by L{RFR}), then the
    separator factors (printed as an L{FR} built from them) and finally the
    join forest itself.

    @return: Printable description of C{self}
    @rtype: String
    """
    cliques = RFR.__str__(self)
    separators = FR(self._separators.values())
    forest = self._hypergraph
    return 'Cliques:\n%sSeparators:\n%sJoin Forest:\n%s' % (
        cliques, separators, forest)
1125
1126
1127
1128
1129
1130
1131
1132
1133
1134
1136 """Alter a JFR so that the factors associated with both cliques and
1137 separators are the appropriate marginal distributions.
1138
1139 In each tree in the forest, messages are first passed to a root and then
1140 out again.
1141 """
1142 jf = self._hypergraph.join_forest()
1143 edges = jf.ordered_edges_towards_root()
1144 for (frm,to) in edges:
1145 self.send_message(frm,to)
1146 for (frm,to) in reversed(edges):
1147 self.send_message(to,frm)
1148
1149
1169
def copy(self, copy_domain=False):
    """Return a deep copy of a JFR

    The clique factors are copied by L{FR.copy}; every separator factor is
    then copied as well and attached to the new model.

    @param copy_domain: If true C{self}'s domain is copied, otherwise the
    copy shares C{self}'s domain
    @type copy_domain: Boolean
    @return: A copy of C{self}
    @rtype: L{JFR}
    """
    duplicate = FR.copy(self, copy_domain)
    new_separators = {}
    for edge, factor in self._separators.items():
        factor_copy = factor.copy()
        # Presumably ties the copied separator to the copy's domain.
        duplicate.common_domain(factor_copy)
        new_separators[edge] = factor_copy
    duplicate._separators = new_separators
    return duplicate
1188
1190 clique_win = Tkinter.Frame(parent)
1191 clique_win.pack()
1192 self._clique_disp(clique_win,self._factors)
1193 sep_win = Tkinter.Frame(parent)
1194 sep_win.pack()
1195 self._clique_disp(sep_win,self._separators)
1196 gc = GraphCanvas(self._hypergraph._uforest,parent,edit=False,pp_vertex=pretty_str_set,
1197 colour_user_actions=False)
1198 gc.pack()
1199 bottom_win = Tkinter.Frame(parent)
1200 bottom_win.pack()
1201 perfect_sequence = self._hypergraph.perfect_sequence()
1202 cpps = perfect_sequence[:]
1203 banned = set()
1204 updone=[False]
1205 def handler(self=self,cliques=cpps,banned=banned,clique_win=clique_win,
1206 sep_win=sep_win,updone=updone,gc=gc):
1207 if not cliques:
1208 if updone[0]:
1209 print 'All messages sent and received'
1210 return
1211 cliques[:] = perfect_sequence
1212 cliques.reverse()
1213 banned.clear()
1214 updone[0] = True
1215 print 'ROOT has received all messages'
1216 clique = cliques.pop()
1217 for nbr in self._hypergraph.clique_neighbours(clique,banned):
1218 self.send_message(clique,nbr)
1219 print 'Sent from %s to %s' % (clique,nbr)
1220 for clique2 in self._hypergraph:
1221 gc.vertex_config(clique2,fill='black')
1222 gc.vertex_config(clique,fill='blue')
1223 gc.vertex_config(nbr,fill='red')
1224 banned.add(clique)
1225 self._clique_disp(clique_win,self._factors)
1226 self._clique_disp(sep_win,self._separators)
1227 Tkinter.Button(bottom_win,text='Next',command=handler).pack()
1228
def marginal(self, variables, root=None, perfect_sequence=None):
    """Compute the marginal distribution of C{variables} B{assuming they
    are a subset of some hyperedge (ie clique)}

    Chooses an appropriate clique as a root, sends messages to it and
    marginalises a copy of the root's factor onto C{variables}.
    Typically used just after C{self} has been altered so that, at least,
    this root clique has the correct marginal.

    Used in iterative proportional fitting.

    @param variables: The variables for which a marginal is sought.
    @type variables: Iterable
    @param root: If supplied, a clique which is simply B{assumed} to contain
    the variables. If not supplied one is computed.
    @type root: Frozenset
    @param perfect_sequence: If supplied an ordering B{assumed} to be a
    perfect sequence with C{root} as its first element. It is not modified.
    @type perfect_sequence: Sequence
    @return: The root clique's factor marginalised onto C{variables}
    @raise ValueError: If no clique contains C{variables}
    """
    variables = frozenset(variables)
    if root is None:
        containers = None
        for v in variables:
            star = self._hypergraph.star(v)
            if containers is None:
                # Work on a copy: intersecting in place would otherwise
                # alter the set handed out by the hypergraph.
                containers = set(star)
            else:
                containers.intersection_update(star)
        if containers is None:
            # No variables given: every clique contains the empty set.
            containers = set(self._hypergraph)
        if not containers:
            raise ValueError("No clique contains %s"
                             % (sorted(variables, key=str),))
        root = member(containers)
    if perfect_sequence is None:
        perfect_sequence = self.perfect_sequence(root)
    # send_messages pops cliques off its argument, so pass a private list
    # and leave the caller's sequence intact.
    self.send_messages(list(perfect_sequence))
    return self._factors[root].copy().marginalise_onto(variables)
1263
def send_message(self, frm, to):
    """Pass a message from one clique to a neighbouring clique

    The sender's factor has the variables not shared with the receiver
    summed out. The receiver's factor is multiplied by that marginal divided
    by the current separator factor, and the marginal then becomes the new
    separator factor.

    @param frm: The clique sending the message
    @type frm: Frozenset (hyperedge)
    @param to: The clique receiving the message
    @type to: Frozenset (hyperedge)
    @raise KeyError: If C{frm} and C{to} are not neighbours in the join
    forest
    """
    message = self._factors[frm].sumout(frm - to)
    separator = frozenset([frm, to])
    update = message / self._separators[separator]
    self._factors[to] *= update
    self._separators[separator] = message
1278
def send_messages(self, cliques, banned=None):
    """Send messages using a fixed ordering of cliques

    C{cliques} acts as a I{stack} and is emptied by this method. The first
    clique which is allowed to send a message is the I{last} element of
    the list C{cliques}. Thus the root is the first element.
    A clique can only send a message to a clique 'underneath' it in the
    stack: once a clique has sent its messages it is added to C{banned}.

    No message to a clique in C{banned} is allowed to be sent.

    The stack is processed in a loop rather than by recursion, so its
    length is not limited by the interpreter's recursion depth.

    @param cliques: Stack of cliques
    @type cliques: List
    @param banned: Cliques which are banned from receiving messages.
    Updated in place if supplied.
    @type banned: set
    """
    if banned is None:
        banned = set()
    while cliques:
        clique = cliques.pop()
        for nbr in self._hypergraph.clique_neighbours(clique, banned):
            self.send_message(clique, nbr)
        banned.add(clique)
1302
1304 """Return (separator,separator_factor) pairs for all separators
1305
1306 The order is arbitrary.
1307
1308 @return: (separator,separator_factor) pairs for all separators
1309 @rtype: List
1310 """
1311 return self._separators.items()
1312
1313
1315 """Return an iterator over the factors associated with separators in the model
1316
1317 The order is arbitrary.
1318 @return: An iterator over the factors associated with separators in the model
1319 @rtype: Iterator
1320 """
1321 return self._separators.itervalues()
1322
1324 return self._hypergraph.trace
1325
1337
def _clique_disp(self, clique_win, factor_dict):
    """Redraw a row of factor widgets, separated by '*' labels

    Every existing child widget of C{clique_win} is destroyed, then one
    widget per factor is created and packed left to right with a '*' label
    between consecutive factors. An empty C{factor_dict} simply leaves the
    frame empty.

    @param clique_win: Frame in which the factors are displayed
    @type clique_win: Tkinter widget
    @param factor_dict: The factors to display (only the values are used)
    @type factor_dict: Dictionary
    """
    for child in clique_win.winfo_children():
        child.destroy()
    widgets = []
    for factor in factor_dict.values():
        if widgets:
            # Separator only between factors, never after the last one.
            widgets.append(Tkinter.Label(clique_win, text='*'))
        widgets.append(factor.gui_main(clique_win, edit=False))
    for widget in widgets:
        widget.pack(side=Tkinter.LEFT)
1348
1349
1350 -class DN(SFR):
1351 """Dependency network representations of probability distributions
1352
1353 A SFR where the factors consist of a CPT
1354 for each variable. And where the associated digraph B{need not be acyclic}.
1355
1356 @ivar _adg: digraph, not necessarily acyclic despite the confusing name (!), giving the DN's structure
1357 @type _adg: L{DiGraph} object
1358 """
1359
def __init__(self, factors=(), domain=None, new_domain_variables=None,
             must_be_new=False, check=False, dg=None):
    """DN constructor.

    Each factor has its domain expanded to be the domain of the model, if
    necessary.

    If C{factors} is a sequence of factors then each is merely
    B{assumed} to be a CPT object, but there is no check done.

    @param factors: The CPTs in the dependency network.
    (Alternatively an existing object of class L{FR} (or its subclasses),
    in which case C{self} has identical attributes, except possibly the
    attribute holding its digraph.)
    @type factors: Sequence, each element of which is a
    L{Parameters.Factor} or L{Variables.SubDomain} object. (Alternatively
    an object of class L{FR} or one of its subclasses.)
    @param domain: A domain for the model.
    If None the internal default domain is used.
    @type domain: L{Variables.Domain} or None
    @param new_domain_variables: A dictionary containing a mapping from any
    new variables to their values. C{domain} is updated with these values
    @type new_domain_variables: Dict or None
    @param must_be_new: Whether domain variables in
    C{new_domain_variables} have to be new
    @type must_be_new: Boolean
    @param check: Whether to check that all variables exist in C{domain}
    @type check: Boolean
    @param dg: The digraph for the DN. This is simply B{assumed} to
    be the correct digraph, no check is made. If C{dg} is not supplied the
    digraph is built from the child and parents of every CPT.
    @type dg: L{Graphs.DiGraph}
    @raise VariableError: If a variable in C{new_domain_variables}
    already exists with values different from its values in
    C{new_domain_variables}; or if C{must_be_new} is set and the variable
    already exists; or if C{check} is set and a variable is not in the
    domain
    @raise AttributeError: If C{dg} was not supplied and any of the
    supplied factors are not L{Parameters.CPT} objects.
    """
    SFR.__init__(self, factors, domain, new_domain_variables,
                 must_be_new, check)
    if dg is not None:
        self._adg = dg
        return
    graph = DiGraph()
    for cpt in self:
        graph.put_family(cpt.child(), cpt.parents())
    self._adg = graph
1404
def __getitem__(self, key):
    """Return the CPT (not a copy) corresponding to C{key}

    A string C{key} is taken to be a child variable: its family is built
    from the digraph's parents and looked up among the factors. Any other
    C{key} is handed to L{SFR.__getitem__} unchanged.

    @param key: Family or child variable
    @type key: Iterable
    @return: The CPT for C{key}
    @rtype: L{CPT}
    @raise KeyError: If there is no corresponding CPT.
    """
    if not isinstance(key, str):
        return SFR.__getitem__(self, key)
    family = frozenset([key]) | self._adg.parents(key)
    return self._factors[family]
1420
1422 """Formal string representation of a DN
1423
1424 @return: Formal string representation of a DN
1425 @rtype: String
1426 """
1427 cpts = ','.join([cpt.repr_nodomain() for cpt in self._factors.values()])
1428 dkt = dict([(v,self._domain[v]) for v in self._variables])
1429 return 'DN([%s],Domain(new_domain_variables=%s),dg=%s)' % (cpts,dkt,repr(self._adg))
1430
1431
1433 """Print each CPT separately
1434
1435 Use lexicographical order on child name
1436
1437 @return: Pretty representation of a L{BN}
1438 @rtype: String
1439 """
1440 out = ''
1441 tmp = []
1442 for cpt in self._factors.values():
1443 tmp.append((cpt.child(),cpt))
1444 tmp.sort()
1445 for l, cpt in tmp:
1446 out += str(cpt)
1447 out += '\n'
1448 return out
1449
def add_cpts(self, cpts):
    """Add new CPTs to a BN (or DN), if possible

    On each pass the first CPT whose child is not yet in the model and
    whose parents all are is multiplied into C{self}, an arrow is added from
    each of its parents to its child, and the CPT is removed from C{cpts}.
    This repeats until C{cpts} is empty, so the list C{cpts} is destroyed by
    this process.

    @param cpts: CPTs to add
    @type cpts: list
    @raise AttributeError: If any of the C{cpts} is not a
    L{Parameters.CPT} object
    @raise ValueError: If a pass finds no CPT which can be added, i.e. by
    adding these CPTs C{self} would no longer be a BN (or DN).
    """
    while cpts:
        for position, candidate in enumerate(cpts):
            if candidate.child() in self._variables:
                continue
            if not candidate.parents() <= self._variables:
                continue
            self *= candidate
            for parent in candidate.parents():
                self._adg.add_arrow(parent, candidate.child())
            del cpts[position]
            break
        else:
            raise ValueError('Could not add CPTs: %s' % cpts)
1471
1472
1474 """
1475 Return the BDeu score of C{self} (and its component scores) on C{data}.
1476
1477 @param data: The data
1478 @type data: L{Parameters.CompactFactor}
1479 @param precision: Prior precision
1480 @type precision: Float
1481 @return: C{(score,variable_scores)} where C{score} is the BDeu score and
1482 C{variable_scores} is a list of component scores, one for each variable (in order)
1483 in C{self}.
1484 @rtype: Tuple
1485 @raise NameError: If L{gPyC} was not successfully imported
1486 """
1487 variable_scores = []
1488 score = 0.0
1489 for variable in self.variables():
1490 new_score = self[variable].get_counts(data).bdeu_score(precision)
1491 score += new_score
1492 variable_scores.append(new_score)
1493 return score, variable_scores
1494
1496 """
1497 Return the BDeu score of C{self} (and its component scores) on C{data}
1498 using the existing BDeu score (and component scores) for C{other}
1499 computed on C{data} previously.
1500
1501 Clearly, for this to give the right answer, the existing score must have used the same C{precision}.
1502
1503 @param data: The data
1504 @type data: L{Parameters.CompactFactor}
1505 @param other: Previously scored BN
1506 @type other: L{BN}
1507 @param other_score: The BDeu score of C{other} on C{data}
1508 @type other_score: Float
1509 @param other_variable_scores: For each variable (in order) in C{other}, the
1510 component of the BDeu score for that variable
1511 @type other_variable_scores: List
1512 @param precision: Prior precision
1513 @type precision: Float
1514 @return: C{(score,variable_scores)} where C{score} is the BDeu score and
1515 C{variable_scores} is a list of component scores, one for each variable in
1516 C{self}.
1517 @rtype: Tuple
1518 @raise NameError: If L{gPyC} was not successfully imported
1519 """
1520 variable_scores = other_variable_scores
1521 score = other_score
1522 for i, variable in enumerate(self.variables()):
1523 if self.adg.parents(variable) != other.adg.parents(variable):
1524 new_score = self[variable].get_counts(data).bdeu_score(precision=1.0)
1525 score += (new_score - other_variable_scores[i])
1526 variable_scores[i] = new_score
1527 return score, variable_scores
1528
1530 """Return the children of C{variable}
1531
1532 @param variable: Variable in the BN (or DN)
1533 @type variable: Immutable (usually a string)
1534 @return: The children of C{variable}
1535 @rtype: set
1536 @raise KeyError: If C{variable} is not in the BN (or DN).
1537 """
1538 return self._adg.children(variable)
1539
def gibbs_sample(self, evidence=None, burnin=100, iterations=1000,
                 inst=None, order=None):
    """Run Gibbs sampling, yielding the state after each post-burn-in sweep

    This is a generator. In one sweep every variable without evidence is
    resampled in turn from its CPT, given the current values of its
    parents. The first C{burnin} sweeps are discarded; each of the next
    C{iterations} sweeps yields the current instantiation.

    The B{same} list object is yielded every time and is updated in place
    by later sweeps, so callers wanting to keep a sample must copy it.

    @param evidence: Maps observed variables to their values. These
    variables are fixed to those values and never resampled.
    @type evidence: Dict or None
    @param burnin: Number of initial sweeps which are not yielded
    @type burnin: Int
    @param iterations: Number of sweeps which are yielded
    @type iterations: Int
    @param inst: Initial instantiation: one value per variable, in
    sampling order. B{Modified in place.} If None, a value is picked for
    each variable with L{Utils.member}.
    @type inst: List or None
    @param order: The order in which variables are visited, as accepted by
    C{_info_for_sampling_dn}.
    @type order: Sequence or None
    @return: Generator of instantiations (lists of values in sampling
    order)
    """
    if evidence is None:
        evidence = {}

    parent_indices, cpts, order = self._info_for_sampling_dn(evidence, order)

    if inst is None:
        # Positions in inst are positions in order, so the starting values
        # must be chosen variable by variable in that same order.
        inst = [member(self.values(v)) for v in order]

    # Indices (into order) of the variables that have to be sampled.
    free = []
    for i, v in enumerate(order):
        if v in evidence:
            inst[i] = evidence[v]
        else:
            free.append(i)

    def sweep():
        # Resample each unobserved variable given its parents' current values.
        for i in free:
            parent_inst = tuple([inst[j] for j in parent_indices[i]])
            inst[i] = cpts[i].sample(parent_inst)

    t = 0
    while t < burnin:
        sweep()
        t += 1

    t = 0
    while t < iterations:
        sweep()
        yield inst
        t += 1
1577
1578
1579
1580
1581
1582
1583
1584
1585
1586
1587
1588
1589
1590
1591
1592
1593
1594
1595
1596
1597
1598
1599
1600
1601
1602
1603
1604
1605
1606
1607
1608
1609
1610
1611
1612
1614 """Return the parents of C{variable}
1615
1616 @param variable: Variable in the BN (or DN)
1617 @type variable: Immutable (usually a string)
1618 @return: The parents of C{variable}
1619 @rtype: set
1620 @raise KeyError: If C{variable} is not in the BN (or DN).
1621 """
1622 return self._adg.parents(variable)
1623
1625 if evidence is None:
1626 evidence = {}
1627 if order is None:
1628 order = tuple(self._variables)
1629 dkt = dict(zip(order,range(len(order))))
1630 parent_indices = []
1631 cpts = []
1632 for variable in order:
1633 cpt = self[variable]
1634 parent_indices.append([dkt[p] for p in sorted(cpt.parents())])
1635 if variable not in evidence:
1636 cpt.initialise_sampler()
1637 cpts.append(cpt)
1638
1639 return parent_indices, cpts, order
1640
1642 """Bayesian network representations of probability distributions
1643
1644 A SFR where the factors consist of a CPT
1645 for each variable. And where the associated digraph B{must be acyclic}.
1646
1647 Joint is thus exactly the product of the factors
1648 @ivar _adg: The acyclic digraph giving the BN's structure
1649 @type _adg: L{ADG} object
1650 """
1651
def __init__(self, factors=(), domain=None, new_domain_variables=None,
             must_be_new=False, check=False, adg=None):
    """BN constructor.

    Each factor has its domain expanded to be the domain of the model, if
    necessary.

    If C{factors} is a sequence of factors then each is merely
    B{assumed} to be a CPT object, but there is no check done.

    @param factors: The CPTs in the Bayesian network.
    (Alternatively an existing object of class L{FR} (or its subclasses),
    in which case C{self} has identical attributes, except possibly the
    attribute holding its C{adg}.)
    @type factors: Sequence, each element of which is a
    L{Parameters.Factor} or L{Variables.SubDomain} object. (Alternatively
    an object of class L{FR} or one of its subclasses.)
    @param domain: A domain for the model.
    If None the internal default domain is used.
    @type domain: L{Variables.Domain} or None
    @param new_domain_variables: A dictionary containing a mapping from any
    new variables to their values. C{domain} is updated with these values
    @type new_domain_variables: Dict or None
    @param must_be_new: Whether domain variables in
    C{new_domain_variables} have to be new
    @type must_be_new: Boolean
    @param check: Whether to check that all variables exist in C{domain}
    @type check: Boolean
    @param adg: The acyclic digraph for the BN. This is simply B{assumed}
    to be the correct ADG, no check is made. If C{adg} is not supplied the
    ADG is built from the child and parents of every CPT if possible,
    otherwise a C{DirectedCycleError} is raised.
    @type adg: L{Graphs.ADG}
    @raise VariableError: If a variable in C{new_domain_variables}
    already exists with values different from its values in
    C{new_domain_variables}; or if C{must_be_new} is set and the variable
    already exists; or if C{check} is set and a variable is not in the
    domain
    @raise AttributeError: If C{adg} was not supplied and any of the
    supplied factors are not L{Parameters.CPT} objects.
    @raise DirectedCycleError: If C{adg} was not supplied and there is no
    ADG consistent with the CPTs provided.
    """
    SFR.__init__(self, factors, domain, new_domain_variables,
                 must_be_new, check)
    if adg is not None:
        self._adg = adg
        return
    graph = ADG()
    for cpt in self:
        graph.put_family(cpt.child(), cpt.parents())
    self._adg = graph
1698
1699
1701 """Formal string representation of a BN
1702
1703 @return: Formal string representation of a BN
1704 @rtype: String
1705 """
1706 cpts = ','.join([cpt.repr_nodomain() for cpt in self._factors.values()])
1707 dkt = dict([(v,self._domain[v]) for v in self._variables])
1708 return 'BN([%s],Domain(new_domain_variables=%s),adg=%s)' % (cpts,dkt,repr(self._adg))
1709
1710
1711
def _add_factor(self, factor):
    """Include a factor in a BN

    C{self} remains a BN only if the factor is a CPT whose child is a new
    variable and whose parents are all existing variables; in that case
    the family is added to the ADG. Otherwise C{self} becomes a L{FR}
    object and loses its ADG. Either way the factor is then added by
    L{FR._add_factor}.

    @param factor: The factor to include
    @type factor: L{Parameters.Factor}
    """
    stays_bn = (isinstance(factor, CPT)
                and factor.child() not in self._variables
                and factor.parents() <= self._variables)
    if not stays_bn:
        self.__class__ = FR
        del self._adg
    else:
        self._adg.put_family(factor.child(), factor.parents())
    FR._add_factor(self, factor)
1727
1729 """Return the ADG associated with BN
1730
1731 A copy of the BN's ADG is returned.
1732 @return: The ADG associated with BN
1733 @rtype: L{Graphs.ADG}
1734 """
1735 return self._adg.copy()
1736
def condition(self, condition, keep_class=False):
    """Alter a Bayesian network by effecting the restriction on
    variables given by C{condition}

    Generally conditioning a BN produces something which is not
    a BN, see the C{keep_class} argument.

    This alters the model's domain. Make a copy with C{copy_domain=True}
    if the original domain will be needed.

    @param condition: Dictionary of the form {var1:values1,var2:values2..}
    Each value of this dictionary must be an iterable
    @type condition: Dict
    @param keep_class: If C{False}, the object will cease to be a
    L{BN} object, becoming a L{SFR} object, and all its factors will
    become L{Parameters.Factor} objects. If C{True} it remains a BN.
    @type keep_class: Boolean
    @return: The conditioned model (C{self})
    @rtype: Same as C{self}
    @raise KeyError: If a variable is used that is not in the model
    @raise ValueError: If a value is used that is not a possible value of
    the variable it is attached to
    """
    SFR.condition(self, condition, keep_class)
    if keep_class:
        return self
    # The conditioned factors no longer form a BN: demote the object.
    self.__class__ = SFR
    return self
1764
def copy(self, copy_domain=False):
    """Return a copy of a BN with its own copy of the ADG

    @param copy_domain: Passed on to L{SFR.copy}
    @type copy_domain: Boolean
    @return: A copy of C{self}
    @rtype: L{BN}
    """
    duplicate = SFR.copy(self, copy_domain)
    graph = self._adg.copy()
    duplicate._adg = graph
    return duplicate
1775
1777 """Alter a Bayesian network by summing out a variable
1778
1779 C{self} will cease to be a L{BN} object unless C{variable}
1780 has no children.
1781 @param variable: The variable to eliminate
1782 @type variable: String
1783 """
1784 if not self._adg.children(variable):
1785 if trace:
1786 return SFR.eliminate_variable(self,variable,True)
1787 self.remove(variable)
1788 else:
1789 self.__class__ = SFR
1790 return self.eliminate_variable(variable,trace)
1791
def family_hyperedge(self, child):
    """Return the hyperedge of the CPT which has C{child} as its child

    The hyperedge is C{child} together with its parents in the ADG.

    @param child: Variable
    @type child: Immutable
    @return: Hyperedge corresponding to the CPT with C{child} as child
    @rtype: Frozenset
    """
    singleton = frozenset([child])
    parents = self._adg.parents(child)
    return singleton | parents
1803
def _info_for_sampling(self, evidence=None):
    """Gather, in topological order, what the samplers need for each variable

    For each variable the (sorted) parents are translated to positions in
    the topological order. A variable without evidence keeps its CPT, whose
    sampler is initialised. A variable with evidence is replaced by a
    dictionary mapping each parent instantiation to the entry for the
    observed value in the corresponding row of the CPT's data.

    @param evidence: Maps observed variables to their values
    @type evidence: Dict or None
    @return: C{(parent_indices, cpts, order)}: per-variable lists of parent
    positions, per-variable CPTs (or dictionaries, for observed variables)
    and the topological order used
    @rtype: Tuple
    @raise ValueError: If an observed value is not a value of its variable
    """
    order = self.topological_order()
    position = dict(zip(order, range(len(order))))
    parent_indices = []
    cpts = []
    for variable in order:
        cpt = self[variable]
        parent_indices.append([position[p] for p in sorted(cpt.parents())])
        if evidence is None or variable not in evidence:
            cpt.initialise_sampler()
        else:
            ev_val = evidence[variable]
            # Column of the observed value within each data row; refuse an
            # unknown value instead of silently using the last column.
            try:
                column = sorted(cpt.values(variable)).index(ev_val)
            except ValueError:
                raise ValueError('%s is not a value of %s'
                                 % (ev_val, variable))
            likelihood = {}
            for parent_inst, row in zip(cpt.parent_insts(),
                                        cpt.parent_insts_data()):
                likelihood[tuple(parent_inst)] = row[column]
            cpt = likelihood
        cpts.append(cpt)

    return parent_indices, cpts, order
1828
1830
1831 parent_indices, cpts, order = self._info_for_sampling(evidence)
1832
1833 indices = []
1834 uninstantiated = set()
1835 inst = []
1836 for i, v in enumerate(order):
1837 indices.append(i)
1838 if v in evidence:
1839 inst.append(evidence[v])
1840 else:
1841 uninstantiated.add(i)
1842 inst.append(None)
1843
1844 t = 0
1845 while t < iterations:
1846 weight = 1
1847 for i in indices:
1848 parent_inst = tuple([inst[j] for j in parent_indices[i]])
1849 if i in uninstantiated:
1850 inst[i] = cpts[i].sample(parent_inst)
1851 else:
1852 weight *= cpts[i][parent_inst]
1853 yield weight, inst
1854 t += 1
1855
1857
1858 parent_indices, cpts, order = self._info_for_sampling()
1859
1860 n = len(order)
1861 indices = range(n)
1862 inst = [None] * n
1863 t = 0
1864 while t < iterations:
1865 for i in indices:
1866 inst[i] = cpts[i].sample(
1867 tuple([inst[j] for j in parent_indices[i]]))
1868 yield inst
1869 t += 1
1870
1890
1892 """Add CPTs from a C{.dnet} file
1893
1894 @param info: Output from L{IO.read_dnet}
1895 @type info: Tuple
1896
1897 """
1898 try:
1899 dnet_variables, named_cpts = info
1900 except ValueError:
1901 dnet_variables, named_cpts, coords = info
1902 cpts = []
1903 for variable, (parents,data) in named_cpts.items():
1904 cpts.append(CPT(
1905 Factor(parents+[variable],data,self,
1906 new_domain_variables=dnet_variables,
1907 check=True),variable,cpt_check=True))
1908 self.add_cpts(cpts)
1909 try:
1910 self._adg.set_vertex_positions(coords)
1911 except NameError:
1912 pass
1913
1914
1915
def remove(self, key):
    """Remove a CPT where C{key} is either the child or an iterable giving
    the family for the CPT.

    If C{frozenset(key)} is the hyperedge of a factor, the child is found
    by searching that family for the variable whose parents plus itself
    make up the family. Otherwise C{key} is taken to be the child.

    If the child has children of its own in the ADG, C{self} ceases to be
    a BN: it becomes an L{SFR} object and loses its ADG. Otherwise the
    child is simply removed from the ADG.

    @param key: Child variable or family
    @raise TypeError: If C{key} is neither a variable or a family
    """
    family = frozenset(key)
    if family not in self._factors:
        child = key
        hyperedge = self.family_hyperedge(child)
    else:
        hyperedge = family
        for child in hyperedge:
            if hyperedge == self._adg.parents(child) | set([child]):
                break
    if self._adg.children(child):
        # Removing a variable with children leaves CPTs conditioned on a
        # variable with no CPT, so the result is no longer a BN.
        self.__class__ = SFR
        del self._adg
    else:
        self._adg.remove_vertex(child)
    SFR.remove(self, hyperedge)
1940
1941
def sample(self, fobj, samples=100):
    """Write a sample in CSV form to C{fobj}

    First one C{variable:value1,value2,...} line is written per variable,
    then a line of comma-separated variable names as given by the sampler,
    then one comma-separated line per sample.

    Uses forward sampling, so assumes no instantiations.

    @param fobj: Writable file
    @type fobj: File
    @param samples: How many samples to produce. Nothing is sampled if
    this is zero or negative.
    @type samples: Int
    """
    from gPy.Samplers import BNSampler
    for v in self._variables:
        fobj.write('%s:%s\n' % (v, ','.join(self._domain[v])))
    sampler = BNSampler(self)
    fobj.write(','.join(sampler.variables()) + '\n')
    # Count down with an explicit bound so that a negative or fractional
    # request cannot loop forever.
    while samples > 0:
        fobj.write(','.join(sampler.forward_sample()) + '\n')
        samples -= 1
1960
def sample_sqlite(self, dbfilename=':memory', table='data', samples=100):
    """Write a sample to a sqlite database

    Sampling is done on a copy of C{self} (with its own domain) in which
    the values of every variable are replaced by C{0,1,...}. Identical
    samples are counted and each distinct instantiation is stored once,
    together with its count.

    Uses forward sampling, so assumes no instantiations.

    @param dbfilename: Name of sqlite database
    @type dbfilename: String
    @param table: Name of the table, passed to L{Parameters.SqliteFactor}
    @type table: String
    @param samples: How many samples to produce. Nothing is sampled if
    this is zero or negative.
    @type samples: Int
    @return: sqlite database
    @rtype: L{Parameters.SqliteFactor} object
    """
    # NOTE(review): sqlite spells its in-memory database ':memory:'; the
    # default above has no trailing colon -- confirm this is intended.
    from gPy.Samplers import BNSampler
    from gPy.Parameters import SqliteFactor
    sqlfactor = SqliteFactor(self.variables(), dbfilename, table, domain=self)
    newself = self.copy(True)
    dom = newself._domain

    for v, vals in list(dom.items()):
        dom[v] = list(range(len(vals)))
    sampler = BNSampler(newself)
    counts = {}
    # The counter has to be decremented here, otherwise the loop never ends.
    while samples > 0:
        inst = tuple(sampler.forward_sample())
        counts[inst] = counts.get(inst, 0) + 1
        samples -= 1
    sqlfactor.populate([inst + (count,) for inst, count in counts.items()],
                       sampler.variables())
    return sqlfactor
1992
1993
def topological_order(self):
    """Return a topological ordering of the associated acyclic digraph

    The ordering is computed by the ADG itself.
    Children come after their parents in the ordering.

    @return: A topological ordering of the vertices
    @rtype: List
    """
    ordering = self._adg.topological_order()
    return ordering
2002
2003
2004
2005
2006
2007
2008
2009
2010
2011
2012
2013
2014
2015
2016
2017
2018
2019
2020
2021
2022
2023
2024
2025
2026
2027
2028
2029
2030
2031
2032
2033
2034
2035
2036
2037
2038
2039
2040
2041
2042
2043
2044
2045
2046
2047
2048
2049
2050
2051
2052
2053
2054
2055
2056
2057
2058
2059
2060
2061
2062 -class CBN(BN):
2063 - def __init__(self,factors=(),domain=None,adg=None):
2065
2066 @staticmethod
2068 """Construct a L{CBN} from a L{BN}."""
2069 self = BN()
2070 self.__dict__ = bn.__dict__
2071 self.__class__ = CBN
2072 self._adg = MutilatedADG.from_adg(self._adg)
2073 return self
2074
2075 @staticmethod
2077 """Construct a L{CBN} from an L{ADG} and estimates of the parameters from
2078 some observations.
2079 @type adg: L{ADG}
2080 @type data: L{CausalWorld}
2081 @param prior: the Dirichlet prior parameter (the same parameter value
2082 is used for all instances!) Note there may be some problems with
2083 this method: a B{different} prior is used by the BDeu score. However,
2084 in practice, for parameter estimation, this prior method seems to be ok.
2085 I was lazy and it was simple to implement (cb). If prior is zero, then
2086 the parameters are the maximum likelihood estimation solutions.
2087 """
2088 cpts = []
2089 for child in data.variables():
2090 cpts.append(data.makeCPT(child,adg.parents(child),force_cpt=True, prior=prior))
2091 self = CBN(factors=cpts, domain=data, adg=adg.copy())
2092 return self
2093
2107
2109 """
2110 @param intervention: A dictionary mapping variables in the L{CBN} to
2111 a single value in the domain.
2112 """
2113 self._mutilate(intervention.keys())
2114 self.condition(intervention, keep_class=True)
2115
2117 """Replace the factor of C{child} with C{factor}, updating the
2118 hypergraph. Note the ADG must be update separately!"""
2119 hyperedge = self.family_hyperedge(child)
2120 new_hyperedge = frozenset(factor.variables())
2121 if new_hyperedge != hyperedge and not allow_hyperedge_change:
2122 raise ValueError,'new factor does not have same variables as hyperedge',hyperedge
2123 del self._factors[hyperedge]
2124 self._factors[new_hyperedge] = factor
2125
2127 """Replace the parameters of C{self} with estimates from C{data}, keeping the
2128 same structure.
2129 @param prior: the Dirichlet prior parameter (the same parameter value
2130 is used for all instances!) Note there may be some problems with
2131 this method: a B{different} prior is used by the BDeu score. However,
2132 in practice, for parameter estimation, this prior method seems to be ok.
2133 I was lazy and it was simple to implement (cb). If prior is zero, then
2134 the parameters are the maximum likelihood estimation solutions.
2135 """
2136 for child in self._variables:
2137 self._replace_factor(child, data.makeCPT(child,self._adg.parents(child),force_cpt=True, prior=prior))
2138
2140 """Extend the domains of each of the variables in C{other}
2141 to the corresponding values. The parameters corresponding to values
2142 introduced to the domains are set to zero. When conditioning, L{BN}
2143 removes elements from the domains of variables. Sometimes it is useful to
2144 add the zeroes e.g. when using an additive binary operator.
2145 @param other: a dictionary mapping variables to their domains. Note
2146 the current domain must be a subset of the new domain.
2147 """
2148 for factor in self:
2149 extension = dict([(variable, other.values(variable))
2150 for variable in other.variables() & factor.variables()])
2151 factor.data_extend(extension, keep_class=True)
2152
2153
2154 for variable in other.variables():
2155 self.change_domain_variable(variable, other.values(variable))
2156
2158 """Obtain a heuristically ``good'' set of interventions from which
2159 structures can be learnt."""
2160 forces_vars = self.good_sets_forced_variables()
2161
2162 interventions = []
2163 for forced_vars in forces_vars:
2164 forced_vars = tuple(forced_vars)
2165 for inst in self.insts(forced_vars):
2166 inst = map(lambda x: frozenset([x]), inst)
2167 interventions.append(dict(zip(forced_vars,inst)))
2168 return interventions
2169
def good_sets_forced_variables(self):
    """Return the good independent sets of variables to force.

    This is an approximation since ideally you want the edge which implies
    the most propagation when resolved. The approximation will resolve all
    edges. However some forcing may be unnecessary.

    Works on an essential graph built from the ADG. In each round the
    graph is resolved, vertices without neighbours are dropped, and the
    remaining vertices are visited from most to fewest neighbours. A vertex
    is forced, and removed, unless it neighbours one already forced in this
    round. Rounds continue until no vertices are left.

    @return: One set of variables to force per round, in order
    @rtype: List of sets
    """
    interventions = []
    g = EssentialGraph.from_graph(self._adg.essential_graph())
    while g.vertices():
        # NOTE(review): presumably orients whatever the earlier removals
        # now determine -- confirm against EssentialGraph.resolve.
        g.resolve()

        # Number of neighbours of every vertex that still has any.
        # Iterate over a snapshot because vertices are removed on the way.
        fanout = {}
        for v in list(g.vertices()):
            degree = len(g.neighbours(v))
            if degree < 1:
                g.remove_vertex(v)
            else:
                fanout[v] = degree

        # NOTE(review): when only isolated vertices were left this round
        # contributes an empty set -- confirm callers expect that.
        forced = set()
        interventions.append(forced)
        blanket = set()
        # Highest fan-out first; the sort is stable, so ties keep the
        # dictionary's iteration order.
        for v in sorted(fanout, key=fanout.get, reverse=True):
            if v in blanket:
                continue
            forced.add(v)
            # Neighbours of a forced vertex are not forced in this round.
            blanket |= g.neighbours(v)
            g.remove_vertex(v)
    return interventions
2216