26 #ifndef CASADI_BINARY_MX_HPP
27 #define CASADI_BINARY_MX_HPP
29 #include "mx_node.hpp"
/// Class template deriving publicly from MXNode, parameterized on two
/// compile-time booleans, ScX and ScY.
/// NOTE(review): ScX / ScY presumably flag whether the first / second operand
/// is a scalar -- confirm against the member definitions.
/// The class body continues past the lines visible in this section.
40 template<
bool ScX,
bool ScY>
41 class CASADI_EXPORT BinaryMX :
public MXNode {
/// Constructor taking an operation code and two MX operands, both by const
/// reference. Declaration only; the definition is not in this section.
/// NOTE(review): presumably `op` is stored in `op_` (the member returned by
/// op()) and `x`, `y` become dependencies 0 and 1 -- confirm in the definition.
46 BinaryMX(Operation op,
const MX& x,
const MX& y);
/// Override returning a std::string built from a vector of strings.
/// Declaration only.
/// NOTE(review): presumably `arg` holds the printed form of each dependency
/// and the result is the printed form of this node -- confirm.
56 std::string disp(
const std::vector<std::string>& arg)
const override;
/// Override returning the stored operation code `op_`, converted to casadi_int.
61 casadi_int op()
const override {
return op_;}
/// Override that unconditionally returns true for this node type.
66 bool is_binary()
const override {
return true;}
/// Override taking a read-only vector of MX inputs (`arg`) and a mutable
/// vector of MX outputs (`res`). Declaration only.
/// NOTE(review): presumably evaluates this node symbolically on MX
/// expressions and writes the result into `res` -- confirm.
71 void eval_mx(
const std::vector<MX>& arg, std::vector<MX>& res)
const override;
/// Override operating on vectors whose elements are std::array<MX, 3>:
/// `arg` is read-only, `res` is written. Declaration only.
/// NOTE(review): the meaning of the three array slots is not visible here
/// (presumably a split of each expression into three parts) -- confirm.
76 void eval_linear(
const std::vector<std::array<MX, 3> >& arg,
77 std::vector<std::array<MX, 3> >& res)
const override;
/// Override taking nested vectors of MX: read-only `fseed`, mutable `fsens`.
/// Declaration only.
/// NOTE(review): presumably forward-mode derivative propagation, with the
/// outer index being the direction and the inner index the input/output slot
/// -- confirm.
82 void ad_forward(
const std::vector<std::vector<MX> >& fseed,
83 std::vector<std::vector<MX> >& fsens)
const override;
/// Override taking nested vectors of MX: read-only `aseed`, mutable `asens`.
/// Declaration only.
/// NOTE(review): presumably reverse-mode (adjoint) derivative propagation;
/// whether `asens` is accumulated into or overwritten is not visible here
/// -- confirm.
88 void ad_reverse(
const std::vector<std::vector<MX> >& aseed,
89 std::vector<std::vector<MX> >& asens)
const override;
/// Override working on bvec_t buffers: `arg` points to const input buffers,
/// `res` to writable output buffers, with `iw` / `w` as extra integer and
/// bvec_t buffers. Returns int. Declaration only.
/// NOTE(review): presumably forward propagation of dependency bit patterns,
/// with 0 meaning success -- confirm.
94 int sp_forward(
const bvec_t** arg, bvec_t** res, casadi_int* iw, bvec_t* w)
const override;
/// Override working on bvec_t buffers. Unlike sp_forward, `arg` is non-const
/// here, so the input buffers may be written as well as `res`.
/// Returns int. Declaration only.
/// NOTE(review): presumably reverse propagation of dependency bit patterns
/// -- confirm.
99 int sp_reverse(bvec_t** arg, bvec_t** res, casadi_int* iw, bvec_t* w)
const override;
/// Non-virtual const member generic over the element type `T`: `arg` points
/// to const input buffers, `res` to writable output buffers, `iw` / `w` are
/// extra integer and T buffers. Returns int. Declaration only.
/// NOTE(review): `T` is not declared in the visible lines (the class template
/// parameters are the bools ScX and ScY); presumably a `template<typename T>`
/// header belongs directly above this declaration -- confirm.
103 int eval_gen(
const T*
const* arg, T*
const* res, casadi_int* iw, T* w)
const;
/// Override with the same buffer layout as eval_gen, using `double` as the
/// element type. Returns int. Declaration only.
/// NOTE(review): presumably forwards to eval_gen with T = double -- confirm.
106 int eval(
const double** arg,
double** res, casadi_int* iw,
double* w)
const override;
/// Override with the same buffer layout as eval_gen, using `SXElem` as the
/// element type. Returns int. Declaration only.
/// NOTE(review): presumably forwards to eval_gen with T = SXElem -- confirm.
109 int eval_sx(
const SXElem** arg, SXElem** res, casadi_int* iw, SXElem* w)
const override;
/// Override that always returns 2.
/// NOTE(review): presumably the number of inputs eligible for in-place
/// computation (both operands) -- confirm against the base-class contract.
112 casadi_int n_inplace()
const override {
return 2;}
/// Override taking a CodeGenerator by mutable reference, index vectors `arg`
/// and `res`, a read-only flag vector `arg_is_ref`, and a mutable flag vector
/// `res_is_ref`. Declaration only.
/// NOTE(review): presumably emits generated code for this operation, where
/// `arg` / `res` index work-vector entries and the *_is_ref flags mark
/// entries held by reference -- confirm.
117 void generate(CodeGenerator& g,
118 const std::vector<casadi_int>& arg,
119 const std::vector<casadi_int>& res,
120 const std::vector<bool>& arg_is_ref,
121 std::vector<bool>& res_is_ref)
const override;
/// Override writing to a SerializingStream. Declaration only.
/// NOTE(review): presumably writes this node's data members (e.g. `op_`)
/// -- confirm.
126 void serialize_body(SerializingStream& s)
const override;
/// Override writing to a SerializingStream. Declaration only.
/// NOTE(review): presumably writes the type information (such as the ScX /
/// ScY variant) that deserialize() needs to pick the right instantiation
/// -- confirm.
131 void serialize_type(SerializingStream& s)
const override;
/// Static function reading from a DeserializingStream and returning a raw
/// MXNode pointer. Declaration only.
/// NOTE(review): ownership of the returned pointer is not visible here;
/// presumably the caller takes it over -- confirm.
136 static MXNode* deserialize(DeserializingStream& s);
/// Override taking an operation code and returning an MX. Declaration only.
/// NOTE(review): presumably builds the result of applying unary operation
/// `op` to this node, possibly with simplification -- confirm.
139 MX get_unary(casadi_int op)
const override;
/// Override taking an operation code, a second operand `y`, and two boolean
/// flags; returns an MX. Declaration only.
/// NOTE(review): presumably builds `op(this, y)`, with scX / scY flagging a
/// scalar left / right operand (mirroring the ScX / ScY template parameters)
/// -- confirm.
142 MX _get_binary(casadi_int op,
const MX& y,
bool scX,
bool scY)
const override;
/// Override comparing this node with `node`, passing `depth-1` to each
/// nested MX::is_equal call on the dependencies.
/// NOTE(review): the visible text has unbalanced braces and shows no result
/// for the case where the first inner condition holds, so some of the body's
/// lines are missing from this view. Comments describe only what is visible.
147 bool is_equal(
const MXNode* node, casadi_int depth)
const override {
// Only nodes carrying the same operation code are compared further.
148 if (op_==node->op()) {
// Same-order comparison: dep(0) vs dep(0) and dep(1) vs dep(1).
149 if (MX::is_equal(dep(0), node->dep(0), depth-1)
150 && MX::is_equal(dep(1), node->dep(1), depth-1)) {
// Swapped-order comparison: dep(1) vs dep(0) and dep(0) vs dep(1), accepted
// only when operation_checker<CommChecker>(op_) is true.
// NOTE(review): CommChecker presumably tests commutativity -- confirm.
155 return operation_checker<CommChecker>(op_)
156 && MX::is_equal(dep(1), node->dep(0), depth-1)
157 && MX::is_equal(dep(0), node->dep(1), depth-1);
/// Override taking an MX `r` and a boolean `tr`; returns an MX.
/// Declaration only.
/// NOTE(review): presumably an upper-triangular solve with right-hand side
/// `r`, where `tr` selects the transposed system -- confirm.
165 MX get_solve_triu(
const MX& r,
bool tr)
const override;
/// Override taking an MX `r` and a boolean `tr`; returns an MX.
/// Declaration only.
/// NOTE(review): presumably a lower-triangular solve with right-hand side
/// `r`, where `tr` selects the transposed system -- confirm.
168 MX get_solve_tril(
const MX& r,
bool tr)
const override;
/// Constructor from a DeserializingStream; `explicit` prevents implicit
/// conversion from a stream to a node. Declaration only.
/// NOTE(review): presumably the counterpart of serialize_body(), reading the
/// fields that function wrote -- confirm.
176 explicit BinaryMX(DeserializingStream& s);