27 #include "casadi_misc.hpp"
28 #include "global_options.hpp"
29 #include "serializing_stream.hpp"
44 Split::Split(
const MX& x,
const std::vector<casadi_int>& offset) : offset_(offset) {
52 int Split::eval(
const double** arg,
double** res, casadi_int* iw,
double* w)
const {
53 return eval_gen<double>(arg, res, iw, w);
57 return eval_gen<SXElem>(arg, res, iw, w);
63 casadi_int nx =
offset_.size()-1;
65 for (casadi_int i=0; i<nx; ++i) {
66 casadi_int nz_first =
offset_[i];
67 casadi_int nz_last =
offset_[i+1];
68 if (res[i]!=
nullptr) {
69 std::copy(arg[0]+nz_first, arg[0]+nz_last, res[i]);
76 casadi_int nx =
offset_.size()-1;
77 for (casadi_int i=0; i<nx; ++i) {
78 if (res[i]!=
nullptr) {
81 bvec_t *res_i_ptr = res[i];
82 for (casadi_int k=0; k<n_i; ++k) {
83 *res_i_ptr++ = *arg_ptr++;
91 casadi_int nx =
offset_.size()-1;
92 for (casadi_int i=0; i<nx; ++i) {
93 if (res[i]!=
nullptr) {
96 bvec_t *res_i_ptr = res[i];
97 for (casadi_int k=0; k<n_i; ++k) {
98 *arg_ptr++ |= *res_i_ptr;
107 const std::vector<casadi_int>& arg,
108 const std::vector<casadi_int>& res,
109 const std::vector<bool>& arg_is_ref,
110 std::vector<bool>& res_is_ref)
const {
111 casadi_int nx =
nout();
112 for (casadi_int i=0; i<nx; ++i) {
113 casadi_int nz_first =
offset_[i];
114 casadi_int nz_last =
offset_[i+1];
115 casadi_int nz = nz_last-nz_first;
116 if (res[i]>=0 && nz>0) {
118 g << g.
workel(res[i]) <<
" = ";
121 casadi_assert_dev(nz_first==0);
122 g << g.
workel(arg[0]) <<
";\n";
125 g << g.
work(arg[0],
dep(0).
nnz(), arg_is_ref[0]) <<
"[" << nz_first <<
"];\n";
129 std::string r = g.
work(arg[0],
dep(0).
nnz(), arg_is_ref[0]);
130 if (nz_first!=0) r = r +
"+" +
str(nz_first);
132 g << g.
work(res[i],
nnz(i),
true) <<
" = " << r <<
";\n";
133 res_is_ref[i] =
true;
135 g << g.
copy(r, nz, g.
work(res[i],
nnz(i),
false)) <<
"\n";
145 arg.push_back(
MX::sym(
"x", sp));
146 Function output(
"output", std::vector<MX>{}, arg, {{
"allow_free",
true}});
147 return {{
"offset",
offset_}, {
"output", output}};
163 return "horzsplit(" + arg.at(0) +
")";
168 std::vector<casadi_int> col_offset;
169 col_offset.reserve(
offset_.size());
170 col_offset.push_back(0);
172 col_offset.push_back(col_offset.back() + s.size2());
175 res = horzsplit(arg[0], col_offset);
179 std::vector<std::vector<MX> >& fsens)
const {
180 casadi_int nfwd = fsens.size();
183 std::vector<casadi_int> col_offset;
184 col_offset.reserve(
offset_.size());
185 col_offset.push_back(0);
187 col_offset.push_back(col_offset.back() + s.size2());
191 for (casadi_int d=0; d<nfwd; ++d) {
192 fsens[d] = horzsplit(fseed[d][0], col_offset);
197 std::vector<std::vector<MX> >& asens)
const {
198 casadi_int nadj = aseed.size();
201 std::vector<casadi_int> col_offset;
202 col_offset.reserve(
offset_.size());
203 col_offset.push_back(0);
205 col_offset.push_back(col_offset.back() + s.size2());
208 for (casadi_int d=0; d<nadj; ++d) {
209 asens[d][0] += horzcat(aseed[d]);
214 const std::vector<casadi_int>& offset1,
215 const std::vector<casadi_int>& offset2) :
Split(x, offset1) {
227 "DiagSplit:: the presence of nonzeros outside the diagonal blocks in unsupported.");
231 return "diagsplit(" + arg.at(0) +
")";
236 std::vector<casadi_int> offset1;
237 offset1.reserve(
offset_.size());
238 offset1.push_back(0);
239 std::vector<casadi_int> offset2;
240 offset2.reserve(
offset_.size());
241 offset2.push_back(0);
243 offset1.push_back(offset1.back() + s.size1());
244 offset2.push_back(offset2.back() + s.size2());
247 res = diagsplit(arg[0], offset1, offset2);
251 std::vector<std::vector<MX> >& fsens)
const {
252 casadi_int nfwd = fsens.size();
254 std::vector<casadi_int> offset1;
255 offset1.reserve(
offset_.size());
256 offset1.push_back(0);
257 std::vector<casadi_int> offset2;
258 offset2.reserve(
offset_.size());
259 offset2.push_back(0);
261 offset1.push_back(offset1.back() + s.size1());
262 offset2.push_back(offset2.back() + s.size2());
266 for (casadi_int d=0; d<nfwd; ++d) {
267 fsens[d] = diagsplit(fseed[d][0], offset1, offset2);
272 std::vector<std::vector<MX> >& asens)
const {
273 casadi_int nadj = asens.size();
276 std::vector<casadi_int> offset1;
277 offset1.reserve(
offset_.size());
278 offset1.push_back(0);
279 std::vector<casadi_int> offset2;
280 offset2.reserve(
offset_.size());
281 offset2.push_back(0);
283 offset1.push_back(offset1.back() + s.size1());
284 offset2.push_back(offset2.back() + s.size2());
287 for (casadi_int d=0; d<nadj; ++d) {
288 asens[d][0] += diagcat(aseed[d]);
305 return "vertsplit(" + arg.at(0) +
")";
310 std::vector<casadi_int> row_offset;
311 row_offset.reserve(
offset_.size());
312 row_offset.push_back(0);
314 row_offset.push_back(row_offset.back() + s.size1());
317 res = vertsplit(arg[0], row_offset);
321 std::vector<std::vector<MX> >& fsens)
const {
322 casadi_int nfwd = fsens.size();
325 std::vector<casadi_int> row_offset;
326 row_offset.reserve(
offset_.size());
327 row_offset.push_back(0);
329 row_offset.push_back(row_offset.back() + s.size1());
332 for (casadi_int d=0; d<nfwd; ++d) {
333 fsens[d] = vertsplit(fseed[d][0], row_offset);
338 std::vector<std::vector<MX> >& asens)
const {
339 casadi_int nadj = aseed.size();
342 std::vector<casadi_int> row_offset;
343 row_offset.reserve(
offset_.size());
344 row_offset.push_back(0);
346 row_offset.push_back(row_offset.back() + s.size1());
349 for (casadi_int d=0; d<nadj; ++d) {
350 asens[d][0] += vertcat(aseed[d]);
357 if (e.nnz()!=0)
X.push_back(e);
361 if (
X.size()!=
nout()) {
366 for (casadi_int i=0; i<
X.size(); ++i) {
377 if (x.size()!=
nout()) {
382 for (casadi_int i=0; i<x.size(); ++i) {
394 if (x.size()!=
nout()) {
399 for (casadi_int i=0; i<x.size(); ++i) {
Helper class for C code generation.
std::string work(casadi_int n, casadi_int sz, bool is_ref) const
std::string copy(const std::string &arg, std::size_t n, const std::string &res)
Create a copy operation.
std::string workel(casadi_int n) const
Helper class for Serialization.
void unpack(Sparsity &e)
Reconstruct an object from the input stream.
void ad_forward(const std::vector< std::vector< MX > > &fseed, std::vector< std::vector< MX > > &fsens) const override
Calculate forward mode directional derivatives.
std::string disp(const std::vector< std::string > &arg) const override
Print expression.
MX get_diagcat(const std::vector< MX > &x) const override
Create a diagonal concatenation node.
Diagsplit(const MX &x, const std::vector< casadi_int > &offset1, const std::vector< casadi_int > &offset2)
Constructor.
void eval_mx(const std::vector< MX > &arg, std::vector< MX > &res) const override
Evaluate symbolically (MX).
void ad_reverse(const std::vector< std::vector< MX > > &aseed, std::vector< std::vector< MX > > &asens) const override
Calculate reverse mode directional derivatives.
casadi_int nnz() const
Get the number of (structural) non-zero elements.
static MX sym(const std::string &name, casadi_int nrow=1, casadi_int ncol=1)
Create an nrow-by-ncol symbolic primitive.
std::string disp(const std::vector< std::string > &arg) const override
Print expression.
Horzsplit(const MX &x, const std::vector< casadi_int > &offset)
Constructor.
MX get_horzcat(const std::vector< MX > &x) const override
Create a horizontal concatenation node.
void ad_forward(const std::vector< std::vector< MX > > &fseed, std::vector< std::vector< MX > > &fsens) const override
Calculate forward mode directional derivatives.
void ad_reverse(const std::vector< std::vector< MX > > &aseed, std::vector< std::vector< MX > > &asens) const override
Calculate reverse mode directional derivatives.
void eval_mx(const std::vector< MX > &arg, std::vector< MX > &res) const override
Evaluate symbolically (MX).
virtual MX get_diagcat(const std::vector< MX > &x) const
Create a diagonal concatenation node.
const Sparsity & sparsity() const
Get the sparsity.
casadi_int nnz(casadi_int i=0) const
virtual casadi_int which_output() const
Get function output.
const MX & dep(casadi_int ind=0) const
Dependencies: functions that have to be evaluated before this one.
virtual void serialize_body(SerializingStream &s) const
Serialize an object without type information.
void set_sparsity(const Sparsity &sparsity)
Set the sparsity.
virtual MX get_horzcat(const std::vector< MX > &x) const
Create a horizontal concatenation node.
virtual MX get_vertcat(const std::vector< MX > &x) const
Create a vertical concatenation node (vectors only).
void set_dep(const MX &dep)
Set unary dependency.
virtual bool is_output() const
Check if evaluation output.
const Sparsity & sparsity() const
Get the sparsity pattern.
The basic scalar symbolic class of CasADi.
Helper class for Serialization.
void pack(const Sparsity &e)
Serializes an object to the output stream.
casadi_int nnz() const
Get the number of (structural) non-zeros.
static Sparsity scalar(bool dense_scalar=true)
Create a scalar sparsity pattern.
Split: split an expression into multiple expressions by partitioning its nonzeros.
Split(const MX &x, const std::vector< casadi_int > &offset)
Constructor.
Dict info() const override
void generate(CodeGenerator &g, const std::vector< casadi_int > &arg, const std::vector< casadi_int > &res, const std::vector< bool > &arg_is_ref, std::vector< bool > &res_is_ref) const override
Generate code for the operation.
void serialize_body(SerializingStream &s) const override
Serialize an object without type information.
std::vector< Sparsity > output_sparsity_
int eval_gen(const T **arg, T **res, casadi_int *iw, T *w) const
Evaluate the function (template)
std::vector< casadi_int > offset_
int eval(const double **arg, double **res, casadi_int *iw, double *w) const override
Evaluate the function numerically.
int eval_sx(const SXElem **arg, SXElem **res, casadi_int *iw, SXElem *w) const override
Evaluate the function symbolically (SX)
casadi_int nout() const override
Number of outputs.
int sp_reverse(bvec_t **arg, bvec_t **res, casadi_int *iw, bvec_t *w) const override
Propagate sparsity backwards.
~Split() override=0
Destructor.
int sp_forward(const bvec_t **arg, bvec_t **res, casadi_int *iw, bvec_t *w) const override
Propagate sparsity forward.
void ad_forward(const std::vector< std::vector< MX > > &fseed, std::vector< std::vector< MX > > &fsens) const override
Calculate forward mode directional derivatives.
void ad_reverse(const std::vector< std::vector< MX > > &aseed, std::vector< std::vector< MX > > &asens) const override
Calculate reverse mode directional derivatives.
void eval_mx(const std::vector< MX > &arg, std::vector< MX > &res) const override
Evaluate symbolically (MX).
Vertsplit(const MX &x, const std::vector< casadi_int > &offset)
Constructor.
std::string disp(const std::vector< std::string > &arg) const override
Print expression.
MX get_vertcat(const std::vector< MX > &x) const override
Create a vertical concatenation node (vectors only).
unsigned long long bvec_t
std::string str(const T &v)
String representation of a value of any type.
GenericType::Dict Dict
C++ equivalent of Python's dict or MATLAB's struct.