26 #ifndef CASADI_BINARY_MX_HPP
27 #define CASADI_BINARY_MX_HPP
29 #include "mx_node.hpp"
40 template<
bool ScX,
bool ScY>
41 class CASADI_EXPORT BinaryMX :
public MXNode {
46 BinaryMX(Operation op,
const MX& x,
const MX& y);
56 std::string disp(
const std::vector<std::string>& arg)
const override;
61 casadi_int op()
const override {
return op_;}
66 bool is_binary()
const override {
return true;}
71 void eval_mx(
const std::vector<MX>& arg, std::vector<MX>& res)
const override;
76 void ad_forward(
const std::vector<std::vector<MX> >& fseed,
77 std::vector<std::vector<MX> >& fsens)
const override;
82 void ad_reverse(
const std::vector<std::vector<MX> >& aseed,
83 std::vector<std::vector<MX> >& asens)
const override;
88 int sp_forward(
const bvec_t** arg, bvec_t** res, casadi_int* iw, bvec_t* w)
const override;
93 int sp_reverse(bvec_t** arg, bvec_t** res, casadi_int* iw, bvec_t* w)
const override;
97 int eval_gen(
const T*
const* arg, T*
const* res, casadi_int* iw, T* w)
const;
100 int eval(
const double** arg,
double** res, casadi_int* iw,
double* w)
const override;
103 int eval_sx(
const SXElem** arg, SXElem** res, casadi_int* iw, SXElem* w)
const override;
106 casadi_int n_inplace()
const override {
return 2;}
111 void generate(CodeGenerator& g,
112 const std::vector<casadi_int>& arg,
113 const std::vector<casadi_int>& res)
const override;
118 void serialize_body(SerializingStream& s)
const override;
123 void serialize_type(SerializingStream& s)
const override;
128 static MXNode* deserialize(DeserializingStream& s);
131 MX get_unary(casadi_int op)
const override;
134 MX _get_binary(casadi_int op,
const MX& y,
bool scX,
bool scY)
const override;
139 bool is_equal(
const MXNode* node, casadi_int depth)
const override {
140 if (op_==node->op()) {
141 if (MX::is_equal(dep(0), node->dep(0), depth-1)
142 && MX::is_equal(dep(1), node->dep(1), depth-1)) {
147 return operation_checker<CommChecker>(op_)
148 && MX::is_equal(dep(1), node->dep(0), depth-1)
149 && MX::is_equal(dep(0), node->dep(1), depth-1);
157 MX get_solve_triu(
const MX& r,
bool tr)
const override;
160 MX get_solve_tril(
const MX& r,
bool tr)
const override;
168 explicit BinaryMX(DeserializingStream& s);