26 #include "logsumexp.hpp"
35 return "logsumexp(" + arg.at(0) +
")";
39 res[0] = logsumexp(arg[0]);
43 std::vector<std::vector<MX> >& fsens)
const {
44 MX max = mmax(
dep(0));
45 MX expmm = exp(
dep(0)-max);
47 for (casadi_int d=0; d<fsens.size(); ++d) {
49 fsens[d][0] =
dot(v, expmm)/s;
54 std::vector<std::vector<MX> >& asens)
const {
55 MX max = mmax(
dep(0));
56 MX expmm = exp(
dep(0)-max);
58 for (casadi_int d=0; d<aseed.size(); ++d) {
59 asens[d][0] += expmm*aseed[d][0]/s;
63 int LogSumExp::eval(
const double** arg,
double** res, casadi_int* iw,
double* w)
const {
64 return eval_gen<double>(arg, res, iw, w);
74 const std::vector<casadi_int>& arg,
75 const std::vector<casadi_int>& res,
76 const std::vector<bool>& arg_is_ref,
77 std::vector<bool>& res_is_ref)
const {
79 g << g.
workel(res[0]) <<
" = "
Helper class for C code generation.
std::string logsumexp(const std::string &A, casadi_int n)
std::string work(casadi_int n, casadi_int sz, bool is_ref) const
std::string workel(casadi_int n) const
void eval_mx(const std::vector< MX > &arg, std::vector< MX > &res) const override
Evaluate symbolically (MX)
std::string disp(const std::vector< std::string > &arg) const override
Print expression.
int eval_gen(const T **arg, T **res, casadi_int *iw, T *w) const
Evaluate the function (template)
LogSumExp(const MX &A)
Constructor.
void ad_forward(const std::vector< std::vector< MX > > &fseed, std::vector< std::vector< MX > > &fsens) const override
Calculate forward mode directional derivatives.
void ad_reverse(const std::vector< std::vector< MX > > &aseed, std::vector< std::vector< MX > > &asens) const override
Calculate reverse mode directional derivatives.
void generate(CodeGenerator &g, const std::vector< casadi_int > &arg, const std::vector< casadi_int > &res, const std::vector< bool > &arg_is_ref, std::vector< bool > &res_is_ref) const override
Generate code for the operation.
int eval(const double **arg, double **res, casadi_int *iw, double *w) const override
Evaluate the function numerically.
const Sparsity & sparsity() const
Get the sparsity.
casadi_int nnz(casadi_int i=0) const
const MX & dep(casadi_int ind=0) const
Dependencies - functions that have to be evaluated before this one.
void set_sparsity(const Sparsity &sparsity)
Set the sparsity.
void set_dep(const MX &dep)
Set unary dependency.
static Sparsity dense(casadi_int nrow, casadi_int ncol=1)
Create a dense rectangular sparsity pattern.
T1 casadi_logsumexp(const T1 *x, casadi_int n)
T dot(const std::vector< T > &a, const std::vector< T > &b)