! Demonstration of the EM algorithm for fitting two lines to a bunch
! of points.  In the E step, we estimate the probability that each
! point is actually on each of the lines.  Then in the M step we
! redraw the lines to maximize the likelihood of the points.
!
! Rob Malouf (malouf@let.rug.nl), 25 May 2000
!
! Alfa Informatica
! University of Groningen
! Postbus 716
! 9737NT Groningen
!
! This program uses the PGPLOT subroutine library, available at:
!
!    ftp://astro.caltech.edu/pub/pgplot/
!
program em1
  implicit none

  integer, parameter :: npts = 10       ! number of data points
  integer, parameter :: max_iter = 200  ! hard cap on EM iterations
  real, parameter :: s2 = 0.1           ! scale dividing the squared residuals
  real, parameter :: tol = 1.0e-6       ! stop when no parameter moves more than this

  ! PGOPEN is documented as an INTEGER function; without this declaration
  ! implicit typing would treat it as REAL and the status test below
  ! would read a mis-typed result.
  integer, external :: pgopen

  ! Default REAL is kept on purpose: these arrays are passed straight to
  ! PGPLOT routines (pgpt, pgline, pgbin).
  real :: x(npts), y(npts), yd(npts)
  real :: w1(npts), w2(npts), r1(npts), r2(npts)
  real :: model(2,2), previous(2,2)     ! model(k,1) = slope, model(k,2) = intercept
  integer, allocatable :: seed(:)
  integer :: nseed, clock, i, iter

  ! Set up display
  if (pgopen('/XWINDOW') .lt. 1) stop 'em1: cannot open /XWINDOW'

  ! Seed the generator from the clock.  The PUT array must have the
  ! size reported by random_seed(size=...), not a fixed size of 1.
  call random_seed(size=nseed)
  allocate(seed(nseed))
  call system_clock(count=clock)
  seed = ieor(clock, 37 * (/ (i - 1, i = 1, nseed) /))
  call random_seed(put=seed)
  deallocate(seed)

  ! Set up data: first half on y = x + 1, second half on y = -x
  x = (/ .1, .2, .3, .4, .5, .6, .7, .8, .9, 1. /)
  y(1:5) = x(1:5) + 1
  y(6:10) = -x(6:10)

  ! Initialize model with random slopes and intercepts
  call random_number(model)

  do iter = 1, max_iter

    ! E step: residuals of every point against both lines, then the
    ! weight with which each point belongs to each line.
    r1 = model(1,1) * x + model(1,2) - y
    r2 = model(2,1) * x + model(2,2) - y
    call responsibilities(r1, r2, w1, w2)

    ! Display results: data and current lines in the top panel,
    ! the weights of line 1 in the bottom panel.
    call pgsubp(1, 2)
    call pgsci(1)
    call pgenv(0., 1., minval(y), maxval(y), 0, 0)
    call pglab('x', 'y', 'EM Algorithm')
    call pgsci(2)
    call pgpt(npts, x, y, 2)
    call pgsci(3)
    yd = x * model(1,1) + model(1,2)
    call pgline(npts, x, yd)
    call pgsci(4)
    yd = x * model(2,1) + model(2,2)
    call pgline(npts, x, yd)

    call pgsci(1)
    call pgenv(0., 1., 0., 1., 0, 0)
    call pgsci(2)
    call pgbin(npts, x, w1, .true.)

    ! M step: weighted least-squares refit of each line.
    previous = model
    call fit_weighted_line(w1, x, y, model(1,1), model(1,2))
    call fit_weighted_line(w2, x, y, model(2,1), model(2,2))

    ! Stop once the parameters have settled instead of looping forever.
    if (maxval(abs(model - previous)) < tol) exit
  end do

  print *, 'line 1: slope, intercept =', model(1,1), model(1,2)
  print *, 'line 2: slope, intercept =', model(2,1), model(2,2)

  call pgclos

contains

  ! Turn the residuals ra, rb of each point against the two lines into
  ! weights wa, wb with wa + wb = 1.  Algebraically this is
  !   wa = exp(-ra**2/s2) / (exp(-ra**2/s2) + exp(-rb**2/s2))
  ! but the smaller exponent is subtracted from both first, so the
  ! denominator is never below 1 and cannot underflow to 0/0.
  subroutine responsibilities(ra, rb, wa, wb)
    real, intent(in)  :: ra(:), rb(:)
    real, intent(out) :: wa(:), wb(:)
    real :: qa(size(ra)), qb(size(ra)), shift(size(ra))
    real :: ea(size(ra)), eb(size(ra))

    qa = ra * ra / s2
    qb = rb * rb / s2
    shift = min(qa, qb)
    ea = exp(shift - qa)
    eb = exp(shift - qb)
    wa = ea / (ea + eb)
    wb = eb / (ea + eb)
  end subroutine responsibilities

  ! Weighted least-squares fit of py = slope * px + intercept with
  ! weights w (closed-form normal equations).  If the system is
  ! singular, e.g. the line has lost all of its weight, slope and
  ! intercept are left unchanged rather than set to Inf/NaN.
  subroutine fit_weighted_line(w, px, py, slope, intercept)
    real, intent(in)    :: w(:), px(:), py(:)
    real, intent(inout) :: slope, intercept
    real :: sw, swx, swy, swxx, swxy, det

    sw   = sum(w)
    swx  = sum(w * px)
    swy  = sum(w * py)
    swxx = sum(w * px * px)
    swxy = sum(w * px * py)

    det = swxx * sw - swx * swx
    if (abs(det) <= tiny(det)) return

    slope     = (sw * swxy - swx * swy) / det
    intercept = (swxx * swy - swx * swxy) / det
  end subroutine fit_weighted_line

end program em1