26 #ifndef CASADI_BINARY_MX_IMPL_HPP
27 #define CASADI_BINARY_MX_IMPL_HPP
29 #include "binary_mx.hpp"
30 #include "casadi_misc.hpp"
31 #include "global_options.hpp"
32 #include "serializing_stream.hpp"
38 template<
bool ScX,
bool ScY>
39 BinaryMX<ScX, ScY>::BinaryMX(Operation op,
const MX& x,
const MX& y) : op_(op) {
48 template<
bool ScX,
bool ScY>
49 BinaryMX<ScX, ScY>::~BinaryMX() {
52 template<
bool ScX,
bool ScY>
53 std::string BinaryMX<ScX, ScY>::disp(
const std::vector<std::string>& arg)
const {
54 return casadi_math<double>::print(op_, arg.at(0), arg.at(1));
57 template<
bool ScX,
bool ScY>
58 void BinaryMX<ScX, ScY>::eval_mx(
const std::vector<MX>& arg, std::vector<MX>& res)
const {
59 casadi_math<MX>::fun(op_, arg[0], arg[1], res[0]);
62 template<
bool ScX,
bool ScY>
63 void BinaryMX<ScX, ScY>::ad_forward(
const std::vector<std::vector<MX> >& fseed,
64 std::vector<std::vector<MX> >& fsens)
const {
67 casadi_math<MX>::der(op_, dep(0), dep(1), shared_from_this<MX>(), pd);
70 for (casadi_int d=0; d<fsens.size(); ++d) {
71 if (op_ == OP_IF_ELSE_ZERO) {
72 fsens[d][0] = if_else_zero(pd[1], fseed[d][1]);
74 fsens[d][0] = pd[0]*fseed[d][0] + pd[1]*fseed[d][1];
79 template<
bool ScX,
bool ScY>
80 void BinaryMX<ScX, ScY>::ad_reverse(
const std::vector<std::vector<MX> >& aseed,
81 std::vector<std::vector<MX> >& asens)
const {
84 casadi_math<MX>::der(op_, dep(0), dep(1), shared_from_this<MX>(), pd);
87 for (casadi_int d=0; d<aseed.size(); ++d) {
89 if (op_ == OP_IF_ELSE_ZERO) {
91 if (!s.
is_scalar() && dep(1).is_scalar()) {
92 asens[d][1] += dot(dep(0), s);
94 asens[d][1] += if_else_zero(dep(0), s);
98 for (casadi_int c=0; c<2; ++c) {
104 if (pd[c].size()!=s.
size()) pd[c] =
MX(s.sparsity(), pd[c]);
115 template<
bool ScX,
bool ScY>
116 void BinaryMX<ScX, ScY>::
118 const std::vector<casadi_int>& arg,
const std::vector<casadi_int>& res)
const {
120 if (nnz()==0)
return;
129 inplace = res[0]==arg[0];
137 std::string r = g.workel(res[0]);
138 std::string x = g.workel(arg[0]);
139 std::string y = g.workel(arg[1]);
142 if (op_==OP_DIV && g.codegen_scalars && dep(1).nnz()==1) {
149 g.local(
"rr",
"casadi_real",
"*");
150 g.local(
"i",
"casadi_int");
151 g <<
"for (i=0, " <<
"rr=" << g.work(res[0], nnz());
155 if (!ScX && !inplace) {
156 g.local(
"cr",
"const casadi_real",
"*");
157 g <<
", cr=" << g.work(arg[0], dep(0).nnz());
158 if (op_==OP_OR || op_==OP_AND) {
169 g.local(
"cs",
"const casadi_real",
"*");
170 g <<
", cs=" << g.work(arg[1], dep(1).nnz());
171 if (op_==OP_OR || op_==OP_AND || op_==OP_IF_ELSE_ZERO) {
180 g <<
"; i<" << nnz() <<
"; ++i) ";
186 g << casadi_math<double>::sep(op_) <<
"= " << y;
188 g <<
" = " << g.print_op(op_, x, y);
193 template<
bool ScX,
bool ScY>
194 int BinaryMX<ScX, ScY>::
195 eval(
const double** arg,
double** res, casadi_int* iw,
double* w)
const {
196 return eval_gen<double>(arg, res, iw, w);
199 template<
bool ScX,
bool ScY>
200 int BinaryMX<ScX, ScY>::
202 return eval_gen<SXElem>(arg, res, iw, w);
205 template<
bool ScX,
bool ScY>
207 int BinaryMX<ScX, ScY>::
208 eval_gen(
const T*
const* arg, T*
const* res, casadi_int* iw, T* w)
const {
211 const T* input0 = arg[0];
212 const T* input1 = arg[1];
215 casadi_math<T>::fun(op_, input0, input1, output0, nnz());
217 casadi_math<T>::fun(op_, *input0, input1, output0, nnz());
219 casadi_math<T>::fun(op_, input0, *input1, output0, nnz());
224 template<
bool ScX,
bool ScY>
225 int BinaryMX<ScX, ScY>::
226 sp_forward(
const bvec_t** arg, bvec_t** res, casadi_int* iw, bvec_t* w)
const {
227 const bvec_t *a0=arg[0], *a1=arg[1];
230 for (casadi_int i=0; i<n; ++i) {
233 else if (ScX && !ScY)
235 else if (!ScX && ScY)
238 *r++ = *a0++ | *a1++;
243 template<
bool ScX,
bool ScY>
244 int BinaryMX<ScX, ScY>::
245 sp_reverse(bvec_t** arg, bvec_t** res, casadi_int* iw, bvec_t* w)
const {
246 bvec_t *a0=arg[0], *a1=arg[1], *r = res[0];
248 for (casadi_int i=0; i<n; ++i) {
263 template<
bool ScX,
bool ScY>
264 MX BinaryMX<ScX, ScY>::get_unary(casadi_int op)
const {
273 template<
bool ScX,
bool ScY>
274 MX BinaryMX<ScX, ScY>::_get_binary(casadi_int op,
const MX& y,
bool scX,
bool scY)
const {
275 if (!GlobalOptions::simplification_on_the_fly)
return MXNode::_get_binary(op, y, scX, scY);
279 if (op==OP_SUB && MX::is_equal(y, dep(0), maxDepth()))
return dep(1);
280 if (op==OP_SUB && MX::is_equal(y, dep(1), maxDepth()))
return dep(0);
283 if (op==OP_SUB && MX::is_equal(y, dep(0), maxDepth()))
return -dep(1);
284 if (op==OP_ADD && MX::is_equal(y, dep(1), maxDepth()))
return dep(0);
293 template<
bool ScX,
bool ScY>
296 s.
pack(
"BinaryMX::op",
static_cast<int>(op_));
299 template<
bool ScX,
bool ScY>
304 char type = type_x | (type_y << 1);
305 s.
pack(
"BinaryMX::scalar_flags", type);
308 template<
bool ScX,
bool ScY>
311 s.
unpack(
"BinaryMX::scalar_flags", t);
316 if (scY)
return new BinaryMX<true, true>(s);
317 return new BinaryMX<true, false>(s);
319 if (scY)
return new BinaryMX<false, true>(s);
320 return new BinaryMX<false, false>(s);
324 template<
bool ScX,
bool ScY>
327 s.unpack(
"BinaryMX::op", op);
331 template<
bool ScX,
bool ScY>
332 MX BinaryMX<ScX, ScY>::get_solve_triu(
const MX& r,
bool tr)
const {
334 if (!ScX && !ScY && op_ == OP_SUB) {
336 if (dep(0).is_op(OP_PROJECT) && dep(0).dep(0).is_eye()) {
338 if (dep(1).is_op(OP_PROJECT) && dep(1).dep(0).sparsity().is_triu(
true)) {
339 return dep(1).
dep(0)->get_solve_triu_unity(r, tr);
347 template<
bool ScX,
bool ScY>
348 MX BinaryMX<ScX, ScY>::get_solve_tril(
const MX& r,
bool tr)
const {
350 if (!ScX && !ScY && op_ == OP_SUB) {
352 if (dep(0).is_op(OP_PROJECT) && dep(0).dep(0).is_eye()) {
354 if (dep(1).is_op(OP_PROJECT) && dep(1).dep(0).sparsity().is_tril(
true)) {
355 return dep(1).
dep(0)->get_solve_tril_unity(r, tr);
Helper class for C code generation.
Helper class for Serialization.
void unpack(Sparsity &e)
Reconstruct an object from the input stream.
Sparsity sparsity() const
Get the sparsity pattern.
std::pair< casadi_int, casadi_int > size() const
Get the shape.
bool is_scalar(bool scalar_and_dense=false) const
Check if the matrix expression is scalar.
Node class for MX objects.
virtual void serialize_type(SerializingStream &s) const
Serialize type information.
virtual void serialize_body(SerializingStream &s) const
Serialize an object without type information.
virtual MX _get_binary(casadi_int op, const MX &y, bool scX, bool scY) const
Get a binary operation (matrix-matrix).
virtual MX get_unary(casadi_int op) const
Get a unary operation.
virtual MX get_solve_triu(const MX &r, bool tr) const
Solve a system of linear equations, upper triangular A.
virtual MX get_solve_tril(const MX &r, bool tr) const
Solve a system of linear equations, lower triangular A.
MX dep(casadi_int ch=0) const
Get the nth dependency as MX.
Helper class for Serialization.
void pack(const Sparsity &e)
Serializes an object to the output stream.