26 #include "constant_mx.hpp"
29 #include "casadi_misc.hpp"
30 #include "serializing_stream.hpp"
67 split_primitives_gen<MX>(x, it);
71 split_primitives_gen<SX>(x, it);
75 split_primitives_gen<DM>(x, it);
89 return shared_from_this<MX>();
96 return join_primitives_gen<SX>(it);
100 return join_primitives_gen<DM>(it);
104 res[0] = shared_from_this<MX>();
108 std::vector<std::vector<MX> >& fsens)
const {
110 for (casadi_int d=0; d<fsens.size(); ++d) {
111 fsens[d][0] = zero_sens;
116 std::vector<std::vector<MX> >& asens)
const {
120 std::fill_n(res[0],
nnz(), 0);
125 std::fill_n(res[0],
nnz(), 0);
130 const std::vector<casadi_int>& arg,
131 const std::vector<casadi_int>& res,
132 const std::vector<bool>& arg_is_ref,
133 std::vector<bool>& res_is_ref)
const {
138 g << g.
work(res[0],
nnz(),
true) <<
" = " <<
ind <<
";\n";
139 res_is_ref[0] =
true;
147 if (
numel()!=1) casadi_error(
"Can only determine truth value of scalar MX.");
148 if (
nnz()!=1) casadi_error(
"Can only determine truth value of dense scalar MX.");
169 casadi_int intval =
static_cast<casadi_int
>(val);
170 if (
static_cast<double>(intval)-val==0) {
171 return create(sp, intval);
185 const std::vector<double> vdata = val.
nonzeros();
187 for (
auto&& i : vdata) {
252 if (n==
nullptr)
return false;
258 if (!std::equal(
x_->begin(),
x_->end(), n->
x_->begin()))
return false;
268 return shared_from_this<MX>();
272 casadi_assert_dev(nz.empty());
277 return shared_from_this<MX>();
281 return shared_from_this<MX>();
285 return shared_from_this<MX>();
289 return shared_from_this<MX>();
299 s.
pack(
"ConstantMX::type",
'a');
309 std::vector<double> v;
310 s.
unpack(
"ConstantMX::nonzeros", v);
316 s.
pack(
"ConstantMX::type",
'z');
325 s.
unpack(
"ConstantMX::type", t);
343 casadi_error(
"Error deserializing");
349 s.
pack(
"ConstantFile::type",
'f');
355 s.
pack(
"ConstantFile::x",
x_);
367 if (ret==1) casadi_error(
"Cannot open file '" +
str(fname) +
"'.");
368 if (ret==2) casadi_error(
"Failed to read a double from '" +
str(fname) +
"'. "
369 "Expected " +
str(sp.
nnz()) +
" doubles.");
377 casadi_error(
"Not defined for ConstantFile");
381 casadi_error(
"Not defined for ConstantFile");
393 const std::vector<casadi_int>& arg,
394 const std::vector<casadi_int>& res,
395 const std::vector<bool>& arg_is_ref,
396 std::vector<bool>& res_is_ref)
const {
401 res_is_ref[0] =
true;
408 ConstantMX(x.sparsity()), name_(name), x_(x.nonzeros()) {
416 casadi_error(
"Not defined for ConstantPool");
420 casadi_error(
"Not defined for ConstantPool");
424 const std::vector<casadi_int>& arg,
425 const std::vector<casadi_int>& res,
426 const std::vector<bool>& arg_is_ref,
427 std::vector<bool>& res_is_ref)
const {
432 res_is_ref[0] =
true;
446 s.
pack(
"ConstantPool::x",
x_);
451 s.
pack(
"ConstantPool::type",
'p');
Helper class for C code generation.
void define_pool_double(const std::string &name, const std::vector< double > &def)
Allocate file scope double writeable memory.
std::string work(casadi_int n, casadi_int sz, bool is_ref) const
std::string pool_double(const std::string &name) const
Access file scope double writeable memory.
std::string copy(const std::string &arg, std::size_t n, const std::string &res)
Create a copy operation.
std::string constant(const std::vector< casadi_int > &v)
Represent an array constant; adding it when new.
std::string rom_double(const void *id) const
Access file scope double read-only memory.
std::string workel(casadi_int n) const
void define_rom_double(const void *id, casadi_int size)
Allocate file scope double read-only memory.
void add_include(const std::string &new_include, bool relative_path=false, const std::string &use_ifdef=std::string())
Add an include file, optionally using a relative path "..." instead of an absolute path <...>.
std::string file_slurp(const std::string &fname, casadi_int n, const std::string &a)
Slurp a file.
bool elide_copy(casadi_int sz)
A constant given as a DM.
void serialize_type(SerializingStream &s) const override
Serialize type information.
bool is_equal(const MXNode *node, casadi_int depth) const override
Check if two nodes are equivalent up to a given depth.
void generate(CodeGenerator &g, const std::vector< casadi_int > &arg, const std::vector< casadi_int > &res, const std::vector< bool > &arg_is_ref, std::vector< bool > &res_is_ref) const override
Generate code for the operation.
bool is_one() const override
Check if identically one.
Matrix< double > x_
data member
void serialize_body(SerializingStream &s) const override
Serialize an object without type information.
bool is_zero() const override
Check if identically zero.
ConstantDM(const Matrix< double > &x)
Constructor.
bool is_eye() const override
Check if identity matrix.
bool is_minus_one() const override
Check if identically minus one.
Matrix< double > get_DM() const override
Get the value (only for constant nodes)
A constant to be read from a file.
void serialize_body(SerializingStream &s) const override
Serialize an object without type information.
void generate(CodeGenerator &g, const std::vector< casadi_int > &arg, const std::vector< casadi_int > &res, const std::vector< bool > &arg_is_ref, std::vector< bool > &res_is_ref) const override
Generate code for the operation.
Matrix< double > get_DM() const override
Get the value (only for constant nodes)
double to_double() const override
Get the value (only for scalar constant nodes)
std::string disp(const std::vector< std::string > &arg) const override
Print expression.
void add_dependency(CodeGenerator &g) const override
Add a dependent function.
void serialize_type(SerializingStream &s) const override
Serialize type information.
ConstantFile(const Sparsity &x, const std::string &fname)
Constructor.
std::vector< double > x_
nonzeros
void codegen_incref(CodeGenerator &g, std::set< void * > &added) const override
Codegen incref.
std::string fname_
file to read from
Represents an MX that is only composed of a constant.
T join_primitives_gen(typename std::vector< T >::const_iterator &it) const
Join an expression along symbolic primitives (template)
void primitives(std::vector< MX >::iterator &it) const override
Get symbolic primitives.
int sp_forward(const bvec_t **arg, bvec_t **res, casadi_int *iw, bvec_t *w) const override
Propagate sparsity forward.
Matrix< double > get_DM() const override=0
Get the value (only for constant nodes)
MX join_primitives(std::vector< MX >::const_iterator &it) const override
Join an expression along symbolic primitives.
void split_primitives(const MX &x, std::vector< MX >::iterator &it) const override
Split up an expression along symbolic primitives.
~ConstantMX() override=0
Destructor.
void ad_reverse(const std::vector< std::vector< MX > > &aseed, std::vector< std::vector< MX > > &asens) const override
Calculate reverse mode directional derivatives.
int sp_reverse(bvec_t **arg, bvec_t **res, casadi_int *iw, bvec_t *w) const override
Propagate sparsity backwards.
casadi_int n_primitives() const override
Get the number of symbolic primitives.
static MXNode * deserialize(DeserializingStream &s)
Deserialize with type disambiguation.
void eval_mx(const std::vector< MX > &arg, std::vector< MX > &res) const override
Evaluate symbolically (MX)
static ConstantMX * create(const Sparsity &sp, casadi_int val)
MX get_dot(const MX &y) const override
Inner product.
ConstantMX(const Sparsity &sp)
Constructor.
void ad_forward(const std::vector< std::vector< MX > > &fseed, std::vector< std::vector< MX > > &fsens) const override
Calculate forward mode directional derivatives.
bool __nonzero__() const override
Return truth value of an MX.
void split_primitives_gen(const T &x, typename std::vector< T >::iterator &it) const
Split up an expression along primitives (template)
bool is_valid_input() const override
Check if valid function input.
A constant to be managed by a pool.
std::vector< double > x_
nonzeros
Matrix< double > get_DM() const override
Get the value (only for constant nodes)
void generate(CodeGenerator &g, const std::vector< casadi_int > &arg, const std::vector< casadi_int > &res, const std::vector< bool > &arg_is_ref, std::vector< bool > &res_is_ref) const override
Generate code for the operation.
std::string disp(const std::vector< std::string > &arg) const override
Print expression.
double to_double() const override
Get the value (only for scalar constant nodes)
void add_dependency(CodeGenerator &g) const override
Add a dependent function.
ConstantPool(const DM &x, const std::string &name)
Constructor.
void serialize_type(SerializingStream &s) const override
Serialize type information.
std::string name_
pool identifier
void serialize_body(SerializingStream &s) const override
Serialize an object without type information.
A constant with all entries identical.
Helper class for Serialization.
void unpack(Sparsity &e)
Reconstruct an object from the input stream.
casadi_int nnz() const
Get the number of (structural) non-zero elements.
static MX zeros(casadi_int nrow=1, casadi_int ncol=1)
Create a dense matrix or a matrix with specified sparsity with all entries zero.
bool is_scalar(bool scalar_and_dense=false) const
Check if the matrix expression is scalar.
Node class for MX objects.
virtual void serialize_type(SerializingStream &s) const
Serialize type information.
virtual const std::string & name() const
Get the name.
virtual casadi_int n_primitives() const
Get the number of symbolic primitives.
virtual bool is_zero() const
Check if identically zero.
virtual DM get_DM() const
Get the value (only for constant nodes)
virtual casadi_int ind() const
virtual MX join_primitives(std::vector< MX >::const_iterator &it) const
Join an expression along symbolic primitives.
virtual MX get_dot(const MX &y) const
Inner product.
Sparsity sparsity_
The sparsity pattern.
casadi_int numel() const
Get the number of elements.
const Sparsity & sparsity() const
Get the sparsity.
casadi_int nnz(casadi_int i=0) const
virtual void serialize_body(SerializingStream &s) const
Serialize an object without type information.
virtual void primitives(std::vector< MX >::iterator &it) const
Get symbolic primitives.
void set_sparsity(const Sparsity &sparsity)
Set the sparsity.
virtual void split_primitives(const MX &x, std::vector< MX >::iterator &it) const
Split up an expression along symbolic primitives.
bool is_constant() const
Check if constant.
Sparse matrix class. SX and DM are specializations.
std::vector< Scalar > & nonzeros()
bool is_one() const
check if the matrix is 1 (note that false negative answers are possible)
const Sparsity & sparsity() const
Const access the sparsity - reference to data member.
bool is_minus_one() const
check if the matrix is -1 (note that false negative answers are possible)
bool is_eye() const
check if the matrix is an identity matrix (note that false negative answers are possible)
bool is_zero() const
check if the matrix is 0 (note that false negative answers are possible)
std::string get_str(bool more=false) const
Get string representation.
const Scalar scalar() const
Convert to scalar type.
Helper class for Serialization.
void pack(const Sparsity &e)
Serializes an object to the output stream.
casadi_int nnz() const
Get the number of (structural) non-zeros.
bool is_empty(bool both=false) const
Check if the sparsity is empty.
MX get_nzassign(const MX &y, const std::vector< casadi_int > &nz) const override
Assign the nonzeros of a matrix to another matrix.
MX get_transpose() const override
Transpose.
MX get_reshape(const Sparsity &sp) const override
Reshape.
static ZeroByZero * getInstance()
Get a pointer to the singleton.
MX _get_binary(casadi_int op, const MX &y, bool ScX, bool ScY) const override
Get a binary operation.
void serialize_body(SerializingStream &s) const override
Serialize an object without type information.
MX get_project(const Sparsity &sp) const override
Get densification.
MX get_unary(casadi_int op) const override
Get a unary operation.
std::string disp(const std::vector< std::string > &arg) const override
Print expression.
void serialize_type(SerializingStream &s) const override
Serialize type information.
MX get_nzref(const Sparsity &sp, const std::vector< casadi_int > &nz) const override
Get the nonzeros of matrix.
unsigned long long bvec_t
std::string str(const T &v)
String representation, any type.
T dot(const std::vector< T > &a, const std::vector< T > &b)
T * get_ptr(std::vector< T > &v)
Get a pointer to the data contained in the vector.
static CompiletimeConst deserialize(DeserializingStream &s)
static RuntimeConst deserialize(DeserializingStream &s)