26 #ifndef CASADI_BINARY_MX_IMPL_HPP
27 #define CASADI_BINARY_MX_IMPL_HPP
29 #include "binary_mx.hpp"
30 #include "casadi_misc.hpp"
31 #include "global_options.hpp"
32 #include "serializing_stream.hpp"
38 template<
bool ScX,
bool ScY>
39 BinaryMX<ScX, ScY>::BinaryMX(Operation op,
const MX& x,
const MX& y) : op_(op) {
42 set_sparsity(y.sparsity());
44 set_sparsity(x.sparsity());
48 template<
bool ScX,
bool ScY>
49 BinaryMX<ScX, ScY>::~BinaryMX() {
52 template<
bool ScX,
bool ScY>
53 std::string BinaryMX<ScX, ScY>::disp(
const std::vector<std::string>& arg)
const {
54 return casadi_math<double>::print(op_, arg.at(0), arg.at(1));
57 template<
bool ScX,
bool ScY>
58 void BinaryMX<ScX, ScY>::eval_mx(
const std::vector<MX>& arg, std::vector<MX>& res)
const {
59 casadi_math<MX>::fun(op_, arg[0], arg[1], res[0]);
62 template<
bool ScX,
bool ScY>
63 void BinaryMX<ScX, ScY>::eval_linear(
const std::vector<std::array<MX, 3> >& arg,
64 std::vector<std::array<MX, 3> >& res)
const {
65 casadi_math<MX>::fun_linear(op_, arg[0].data(), arg[1].data(), res[0].data());
68 template<
bool ScX,
bool ScY>
69 void BinaryMX<ScX, ScY>::ad_forward(
const std::vector<std::vector<MX> >& fseed,
70 std::vector<std::vector<MX> >& fsens)
const {
73 casadi_math<MX>::der(op_, dep(0), dep(1), shared_from_this<MX>(), pd);
76 for (casadi_int d=0; d<fsens.size(); ++d) {
77 if (op_ == OP_IF_ELSE_ZERO) {
78 fsens[d][0] = if_else_zero(pd[1], fseed[d][1]);
80 fsens[d][0] = pd[0]*fseed[d][0] + pd[1]*fseed[d][1];
85 template<
bool ScX,
bool ScY>
86 void BinaryMX<ScX, ScY>::ad_reverse(
const std::vector<std::vector<MX> >& aseed,
87 std::vector<std::vector<MX> >& asens)
const {
90 casadi_math<MX>::der(op_, dep(0), dep(1), shared_from_this<MX>(), pd);
93 for (casadi_int d=0; d<aseed.size(); ++d) {
95 if (op_ == OP_IF_ELSE_ZERO) {
97 if (!s.is_scalar() && dep(1).is_scalar()) {
98 asens[d][1] += dot(dep(0), s);
100 asens[d][1] += if_else_zero(dep(0), s);
104 for (casadi_int c=0; c<2; ++c) {
109 if (!t.is_scalar() && t.size() != dep(c).size()) {
110 if (pd[c].size()!=s.size()) pd[c] = MX(s.sparsity(), pd[c]);
121 template<
bool ScX,
bool ScY>
122 void BinaryMX<ScX, ScY>::
123 generate(CodeGenerator& g,
124 const std::vector<casadi_int>& arg,
const std::vector<casadi_int>& res,
125 const std::vector<bool>& arg_is_ref, std::vector<bool>& res_is_ref)
const {
127 if (nnz()==0)
return;
136 inplace = res[0]==arg[0] && !arg_is_ref[0];
144 std::string r = g.workel(res[0]);
145 std::string x = g.workel(arg[0]);
146 std::string y = g.workel(arg[1]);
149 if (op_==OP_DIV && g.codegen_scalars && dep(1).nnz()==1) {
156 g.local(
"rr",
"casadi_real",
"*");
157 g.local(
"i",
"casadi_int");
158 g <<
"for (i=0, " <<
"rr=" << g.work(res[0], nnz(),
false);
162 if (!ScX && !inplace) {
163 g.local(
"cr",
"const casadi_real",
"*");
164 g <<
", cr=" << g.work(arg[0], dep(0).nnz(), arg_is_ref[0]);
165 if (op_==OP_OR || op_==OP_AND) {
176 g.local(
"cs",
"const casadi_real",
"*");
177 g <<
", cs=" << g.work(arg[1], dep(1).nnz(), arg_is_ref[1]);
178 if (op_==OP_OR || op_==OP_AND || op_==OP_IF_ELSE_ZERO) {
187 g <<
"; i<" << nnz() <<
"; ++i) ";
193 g << casadi_math<double>::sep(op_) <<
"= " << y;
195 g <<
" = " << g.print_op(op_, x, y);
200 template<
bool ScX,
bool ScY>
201 int BinaryMX<ScX, ScY>::
202 eval(
const double** arg,
double** res, casadi_int* iw,
double* w)
const {
203 return eval_gen<double>(arg, res, iw, w);
206 template<
bool ScX,
bool ScY>
207 int BinaryMX<ScX, ScY>::
208 eval_sx(
const SXElem** arg, SXElem** res, casadi_int* iw, SXElem* w)
const {
209 return eval_gen<SXElem>(arg, res, iw, w);
212 template<
bool ScX,
bool ScY>
214 int BinaryMX<ScX, ScY>::
215 eval_gen(
const T*
const* arg, T*
const* res, casadi_int* iw, T* w)
const {
218 const T* input0 = arg[0];
219 const T* input1 = arg[1];
222 casadi_math<T>::fun(op_, input0, input1, output0, nnz());
224 casadi_math<T>::fun(op_, *input0, input1, output0, nnz());
226 casadi_math<T>::fun(op_, input0, *input1, output0, nnz());
231 template<
bool ScX,
bool ScY>
232 int BinaryMX<ScX, ScY>::
233 sp_forward(
const bvec_t** arg, bvec_t** res, casadi_int* iw, bvec_t* w)
const {
234 const bvec_t *a0=arg[0], *a1=arg[1];
237 for (casadi_int i=0; i<n; ++i) {
240 else if (ScX && !ScY)
242 else if (!ScX && ScY)
245 *r++ = *a0++ | *a1++;
250 template<
bool ScX,
bool ScY>
251 int BinaryMX<ScX, ScY>::
252 sp_reverse(bvec_t** arg, bvec_t** res, casadi_int* iw, bvec_t* w)
const {
253 bvec_t *a0=arg[0], *a1=arg[1], *r = res[0];
255 for (casadi_int i=0; i<n; ++i) {
270 template<
bool ScX,
bool ScY>
271 MX BinaryMX<ScX, ScY>::get_unary(casadi_int op)
const {
277 return MXNode::get_unary(op);
280 template<
bool ScX,
bool ScY>
281 MX BinaryMX<ScX, ScY>::_get_binary(casadi_int op,
const MX& y,
bool scX,
bool scY)
const {
282 if (!GlobalOptions::simplification_on_the_fly)
return MXNode::_get_binary(op, y, scX, scY);
286 if (op==OP_SUB && MX::is_equal(y, dep(0), maxDepth()))
return dep(1);
287 if (op==OP_SUB && MX::is_equal(y, dep(1), maxDepth()))
return dep(0);
290 if (op==OP_SUB && MX::is_equal(y, dep(0), maxDepth()))
return -dep(1);
291 if (op==OP_ADD && MX::is_equal(y, dep(1), maxDepth()))
return dep(0);
297 return MXNode::_get_binary(op, y, scX, scY);
300 template<
bool ScX,
bool ScY>
301 void BinaryMX<ScX, ScY>::serialize_body(SerializingStream& s)
const {
302 MXNode::serialize_body(s);
303 s.pack(
"BinaryMX::op",
static_cast<int>(op_));
306 template<
bool ScX,
bool ScY>
307 void BinaryMX<ScX, ScY>::serialize_type(SerializingStream& s)
const {
308 MXNode::serialize_type(s);
311 char type = type_x | (type_y << 1);
312 s.pack(
"BinaryMX::scalar_flags", type);
315 template<
bool ScX,
bool ScY>
316 MXNode* BinaryMX<ScX, ScY>::deserialize(DeserializingStream& s) {
318 s.unpack(
"BinaryMX::scalar_flags", t);
323 if (scY)
return new BinaryMX<true, true>(s);
324 return new BinaryMX<true, false>(s);
326 if (scY)
return new BinaryMX<false, true>(s);
327 return new BinaryMX<false, false>(s);
331 template<
bool ScX,
bool ScY>
332 BinaryMX<ScX, ScY>::BinaryMX(DeserializingStream& s) : MXNode(s) {
334 s.unpack(
"BinaryMX::op", op);
338 template<
bool ScX,
bool ScY>
339 MX BinaryMX<ScX, ScY>::get_solve_triu(
const MX& r,
bool tr)
const {
341 if (!ScX && !ScY && op_ == OP_SUB) {
343 if (dep(0).is_op(OP_PROJECT) && dep(0).dep(0).is_eye()) {
345 if (dep(1).is_op(OP_PROJECT) && dep(1).dep(0).sparsity().is_triu(
true)) {
346 return dep(1).dep(0)->get_solve_triu_unity(r, tr);
351 return MXNode::get_solve_triu(r, tr);
354 template<
bool ScX,
bool ScY>
355 MX BinaryMX<ScX, ScY>::get_solve_tril(
const MX& r,
bool tr)
const {
357 if (!ScX && !ScY && op_ == OP_SUB) {
359 if (dep(0).is_op(OP_PROJECT) && dep(0).dep(0).is_eye()) {
361 if (dep(1).is_op(OP_PROJECT) && dep(1).dep(0).sparsity().is_tril(
true)) {
362 return dep(1).dep(0)->get_solve_tril_unity(r, tr);
367 return MXNode::get_solve_tril(r, tr);