Skip to content

Commit

Permalink
fix gradient
Browse files Browse the repository at this point in the history
  • Loading branch information
fabian.froehlich committed Sep 10, 2016
1 parent 987d58b commit b84dd78
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 13 deletions.
1 change: 1 addition & 0 deletions logL_SCTL.m
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@
RR = zeros(size(data.SCTL.T(:,:,i)));
[ logL, bhat, Sim ] = logL_SCTL_si(xi, model, data, s, options, P, i);
% [g,g_fd_f,g_fd_b,g_fd_c]=testGradient(xi,@(xi) logL_SCTL_si(xi, model, data, s, options, P, i),1e-3,'val','dxi')
% [g,g_fd_f,g_fd_b,g_fd_c]=testGradient(xi,@(xi) logL_SCTL_si(xi, model, data, s, options, P, i),1e-3,'I','Idxi')
logLi_D(i,1) = logL.D;
logLi_T(i,1) = logL.T;
logLi_b(i,1) = logL.b;
Expand Down
29 changes: 18 additions & 11 deletions logL_SCTL_si.m
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,12 @@
% derivatives with respect to beta and delta
%
% testing:
% [g,g_fd_b,g_fd_f,g_fd_c] = testGradient(beta,@(beta) getBhat( beta, delta, bhat_si0, model, data, s, i, options, P),1e-5,'val','dbeta')
% [g,g_fd_b,g_fd_f,g_fd_c] = testGradient(delta,@(delta) getBhat( beta, delta, bhat_si0, model, data, s, i, options, P),1e-5,'val','ddelta')
% [g,g_fd_b,g_fd_f,g_fd_c] = testGradient(beta,@(beta) getBhat_B( beta, delta, bhat_si0, model, data, s, i, options, P),1e-3,'val','dbeta')
% [g,g_fd_b,g_fd_f,g_fd_c] = testGradient(delta,@(delta) getBhat_B( beta, delta, bhat_si0, model, data, s, i, options, P),1e-3,'val','ddelta')
% [g,g_fd_b,g_fd_f,g_fd_c] = testGradient(beta,@(beta) getBhat_J( beta, delta, bhat_si0, model, data, s, i, options, P),1e-3,'val','dbeta')
% [g,g_fd_b,g_fd_f,g_fd_c] = testGradient(delta,@(delta) getBhat_J( beta, delta, bhat_si0, model, data, s, i, options, P),1e-3,'val','ddelta')
% [g,g_fd_b,g_fd_f,g_fd_c] = testGradient(beta,@(beta) getBhat_G( beta, delta, bhat_si0, model, data, s, i, options, P),1e-3,'val','dbeta')
% [g,g_fd_b,g_fd_f,g_fd_c] = testGradient(delta,@(delta) getBhat_G( beta, delta, bhat_si0, model, data, s, i, options, P),1e-3,'val','ddelta')

beta = model.beta(xi);
delta = model.delta(xi);
Expand Down Expand Up @@ -50,7 +54,7 @@
dbdxi = chainrule(dbhat_sidbeta,dbetadxi) + chainrule(dbhat_siddelta,ddeltadxi);
bhat.dxi = dbdxi;

logL.dxi = chainrule(J.db,dbdxi);
logL.dxi = - chainrule(J.db,dbdxi) - chainrule(J.dbeta,dbetadxi) - chainrule(J.ddelta,ddeltadxi);

if(options.integration)
% laplace approximation
Expand All @@ -59,10 +63,10 @@
G.dbeta = G.dbeta + chainrule(G.db,dbhat_sidbeta);
G.ddelta = G.ddelta + chainrule(G.db,dbhat_siddelta);
G.dxi = chainrule(G.dbeta,dbetadxi) + chainrule(G.ddelta,ddeltadxi);

logL.dxi = logL.dxi - 0.5*permute(sum(sum(bsxfun(@times,permute(sum(bsxfun(@times,invG,permute(G.dxi,[4,1,2,3])),2),[1,3,4,2]),eye(length(bhat.val))),1),2),[1,3,2]); % 1/2*Tr(invG*dG)
logL.Idxi = - 0.5*permute(sum(sum(bsxfun(@times,permute(sum(bsxfun(@times,invG,permute(G.dxi,[4,1,2,3])),2),[1,3,4,2]),eye(length(bhat.val))),1),2),[1,3,2]); % 1/2*Tr(invG*dG)
logL.dxi = logL.dxi + logL.Idxi;
end
%%
% if options.nderiv >= 2
% % second order derivatives
%
Expand Down Expand Up @@ -162,17 +166,20 @@
% end
%
% end

%%
end
end

function J = objective_phi_J_D(model,data,phi,s,i,options,nderiv)
[J,~] = objective_phi(model,data,phi,s,i,options,nderiv);
function B = getBhat_B( beta, delta, bhat_si0, model, data, s, i, options, P)
[B,G,J,Sim] = getBhat( beta, delta, bhat_si0, model, data, s, i, options, P);
end

function G = getBhat_G( beta, delta, bhat_si0, model, data, s, i, options, P)
[B,G,J,Sim] = getBhat( beta, delta, bhat_si0, model, data, s, i, options, P);
end

function J = objective_phi_J_T(model,data,phi,s,i,options,nderiv)
[~,J] = objective_phi(model,data,phi,s,i,options,nderiv);
function J = getBhat_J( beta, delta, bhat_si0, model, data, s, i, options, P)
[B,G,J,Sim] = getBhat( beta, delta, bhat_si0, model, data, s, i, options, P);
end


4 changes: 2 additions & 2 deletions normal_param.m
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@
% dJ_bdb
J_b.db = transpose(invD*b);
% dJ_bddelta
J_b.ddelta = transpose(0.5*permute(sum(sum(bsxfun(@times,dinvDddelta,bsxfun(@times,permute(b,[2,1]),permute(b,[1,2]))),1),2),[1,3,2]) ... % 1/2*b*dinvD*b
+0.5*permute(sum(sum(sum(bsxfun(@times,invD.*eye(length(b)),permute(dDddelta,[4,1,2,3])),2),1),3),[1,4,3,2])); % 1/2*Tr(invD*dD)
J_b.ddelta = 0.5*permute(sum(sum(bsxfun(@times,dinvDddelta,bsxfun(@times,permute(b,[2,1]),permute(b,[1,2]))),1),2),[1,3,2]) ... % 1/2*b*dinvD*b
+0.5*permute(sum(sum(sum(bsxfun(@times,invD.*eye(length(b)),permute(dDddelta,[4,1,2,3])),2),1),3),[1,4,3,2]); % 1/2*Tr(invD*dD)
if nderiv >= 2
% ddJ_bdbdb
J_b.dbdb = invD;
Expand Down

0 comments on commit b84dd78

Please sign in to comment.