List of all members | Public Member Functions | Static Public Member Functions | Public Attributes
casadi::Factory< MatType > Class Template Reference

#include <factory.hpp>

Detailed Description

template<typename MatType>
class casadi::Factory< MatType >

Definition at line 52 of file factory.hpp.

Public Member Functions

Block block (const std::string &s1, const std::string &s) const
 
HBlock hblock (const std::string &s1, const std::string &s) const
 
void add_input (const std::string &s, const MatType &e, bool is_diff)
 
void add_output (const std::string &s, const MatType &e, bool is_diff)
 
void add_dual (const Function::AuxOut &aux)
 
std::string request_input (const std::string &s)
 
std::string request_output (const std::string &s)
 
void calculate_fwd (const Dict &opts)
 
void calculate_adj (const Dict &opts)
 
std::vector< Block >::iterator find_jac (size_t f, size_t x)
 
std::vector< HBlock >::iterator find_hess (size_t f, size_t x1, size_t x2)
 
void calculate_jac (const Dict &opts)
 
void calculate_grad (const Dict &opts)
 
void calculate_hess (const Dict &opts, size_t f)
 
void calculate_hess (const Dict &opts)
 
void calculate (const Dict &opts=Dict())
 
size_t imap (const std::string &s) const
 
size_t omap (const std::string &s) const
 
MatType get_input (const std::string &s)
 
MatType get_output (const std::string &s)
 
bool has_in (const std::string &s) const
 
bool has_out (const std::string &s) const
 
const std::vector< std::string > & iname () const
 
std::vector< std::string > iname (const std::vector< size_t > &ind) const
 
const std::vector< std::string > & oname () const
 
std::vector< std::string > oname (const std::vector< size_t > &ind) const
 

Static Public Member Functions

static bool has_prefix (const std::string &s)
 
static std::pair< std::string, std::string > split_prefix (const std::string &s)
 

Public Attributes

std::vector< MatType > in_
 
std::vector< MatType > out_
 
std::vector< std::string > iname_
 
std::vector< std::string > oname_
 
std::map< std::string, size_t > imap_
 
std::map< std::string, size_t > omap_
 
std::vector< bool > is_diff_in_
 
std::vector< bool > is_diff_out_
 
std::vector< size_t > fwd_in_
 
std::vector< size_t > fwd_out_
 
std::vector< size_t > adj_in_
 
std::vector< size_t > adj_out_
 
std::vector< Block > jac_
 
std::vector< Block > grad_
 
std::vector< HBlock > hess_
 

Member Function Documentation

◆ add_dual()

template<typename MatType >
void casadi::Factory< MatType >::add_dual ( const Function::AuxOut &  aux)

Definition at line 541 of file factory.hpp.

541  {
542  // Dual variables
543  for (size_t k = 0; k < out_.size(); ++k) {
544  Sparsity sp = is_diff_out_[k] ? out_.at(k).sparsity() : Sparsity(out_.at(k).size());
545  add_input("lam:" + oname_[k], MatType::sym("lam_" + oname_[k], sp), true);
546  }
547  // Add linear combinations
548  for (auto i : aux) {
549  MatType lc = 0;
550  for (auto j : i.second) {
551  lc += dot(in_.at(imap_.at("lam:" + j)), out_.at(omap_.at(j)));
552  }
553  add_output(i.first, lc, true);
554  }
555  }
std::map< std::string, size_t > imap_
Definition: factory.hpp:62
std::vector< MatType > out_
Definition: factory.hpp:56
std::vector< bool > is_diff_out_
Definition: factory.hpp:63
std::vector< MatType > in_
Definition: factory.hpp:56
std::vector< std::string > oname_
Definition: factory.hpp:59
std::map< std::string, size_t > omap_
Definition: factory.hpp:62
void add_input(const std::string &s, const MatType &e, bool is_diff)
Definition: factory.hpp:160
void add_output(const std::string &s, const MatType &e, bool is_diff)
Definition: factory.hpp:171
T dot(const std::vector< T > &a, const std::vector< T > &b)

References casadi::dot().

Referenced by casadi::XFunction< DerivedType, MatType, NodeType >::factory().

◆ add_input()

template<typename MatType >
void casadi::Factory< MatType >::add_input ( const std::string &  s,
const MatType &  e,
bool  is_diff 
)

Definition at line 159 of file factory.hpp.

160  {
161  size_t ind = in_.size();
162  auto it = imap_.insert(std::make_pair(s, ind));
163  casadi_assert(it.second, "Duplicate input expression \"" + s + "\"");
164  is_diff_in_.push_back(is_diff);
165  in_.push_back(e);
166  iname_.push_back(s);
167  }
std::vector< std::string > iname_
Definition: factory.hpp:59
std::vector< bool > is_diff_in_
Definition: factory.hpp:63

Referenced by casadi::XFunction< DerivedType, MatType, NodeType >::factory().

◆ add_output()

template<typename MatType >
void casadi::Factory< MatType >::add_output ( const std::string &  s,
const MatType &  e,
bool  is_diff 
)

Definition at line 170 of file factory.hpp.

171  {
172  size_t ind = out_.size();
173  auto it = omap_.insert(std::make_pair(s, ind));
174  casadi_assert(it.second, "Duplicate output expression \"" + s + "\"");
175  is_diff_out_.push_back(is_diff);
176  out_.push_back(e);
177  oname_.push_back(s);
178  }

Referenced by casadi::XFunction< DerivedType, MatType, NodeType >::factory().

◆ block()

template<typename MatType >
Block casadi::Factory< MatType >::block ( const std::string &  s1,
const std::string &  s 
) const

Definition at line 681 of file factory.hpp.

681  {
682  Block b;
683  b.s = s;
684  size_t pos = s2.find(':');
685  if (pos < s2.size()) {
686  b.f = omap(s2.substr(0, pos));
687  b.x = imap(s2.substr(pos+1, std::string::npos));
688  }
689  return b;
690  }
size_t imap(const std::string &s) const
Definition: factory.hpp:665
size_t omap(const std::string &s) const
Definition: factory.hpp:673

References casadi::Block::f, casadi::Block::s, and casadi::Block::x.

◆ calculate()

template<typename MatType >
void casadi::Factory< MatType >::calculate ( const Dict &  opts = Dict())

Definition at line 558 of file factory.hpp.

558  {
559  // Forward mode directional derivatives
560  try {
561  calculate_fwd(opts);
562  } catch (std::exception& e) {
563  casadi_error("Forward mode AD failed:\n" + str(e.what()));
564  }
565 
566  // Reverse mode directional derivatives
567  try {
568  calculate_adj(opts);
569  } catch (std::exception& e) {
570  casadi_error("Reverse mode AD failed:\n" + str(e.what()));
571  }
572 
573  // Jacobian blocks
574  try {
575  calculate_jac(opts);
576  } catch (std::exception& e) {
577  casadi_error("Jacobian generation failed:\n" + str(e.what()));
578  }
579 
580  // Gradient blocks
581  try {
582  calculate_grad(opts);
583  } catch (std::exception& e) {
584  casadi_error("Gradient generation failed:\n" + str(e.what()));
585  }
586 
587  // Hessian blocks
588  try {
589  calculate_hess(opts);
590  } catch (std::exception& e) {
591  casadi_error("Hessian generation failed:\n" + str(e.what()));
592  }
593  }
void calculate_fwd(const Dict &opts)
Definition: factory.hpp:237
void calculate_adj(const Dict &opts)
Definition: factory.hpp:267
void calculate_grad(const Dict &opts)
Definition: factory.hpp:389
void calculate_jac(const Dict &opts)
Definition: factory.hpp:313
void calculate_hess(const Dict &opts, size_t f)
Definition: factory.hpp:404
std::string str(const T &v)
String representation, any type.

References casadi::str().

Referenced by casadi::XFunction< DerivedType, MatType, NodeType >::factory().

◆ calculate_adj()

template<typename MatType >
void casadi::Factory< MatType >::calculate_adj ( const Dict &  opts)

Definition at line 267 of file factory.hpp.

267  {
268  if (adj_out_.empty()) return;
269  casadi_assert_dev(!adj_in_.empty());
270  std::vector<MatType> arg, res;
271  std::vector<std::vector<MatType>> seed(1), sens(1);
272  // Inputs
273  for (size_t ind : adj_out_) arg.push_back(in_[ind]);
274  // Outputs and reverse mode seeds
275  for (size_t ind : adj_in_) {
276  res.push_back(out_.at(ind));
277  Sparsity sp = is_diff_out_.at(ind) ? res.back().sparsity() : Sparsity(res.back().size());
278  seed[0].push_back(MatType::sym("adj_" + oname_[ind], sp));
279  add_input("adj:" + oname_[ind], seed[0].back(), true);
280  }
281  // Calculate directional derivatives
282  Dict local_opts;
283  local_opts["always_inline"] = true;
284  sens = reverse(res, arg, seed, local_opts);
285 
286  // Get directional derivatives
287  for (size_t i=0; i < adj_out_.size(); ++i) {
288  std::string s = iname_[adj_out_[i]];
289  Sparsity sp = is_diff_in_.at(adj_out_[i]) ? arg.at(i).sparsity() : Sparsity(arg.at(i).size());
290  add_output("adj:" + s, project(sens[0].at(i), sp), is_diff_in_.at(adj_out_[i]));
291  }
292  }
std::vector< size_t > adj_in_
Definition: factory.hpp:69
std::vector< size_t > adj_out_
Definition: factory.hpp:69
GenericType::Dict Dict
C++ equivalent of Python's dict or MATLAB's struct.
std::vector< T > reverse(const std::vector< T > &v)
Reverse a list.

References casadi::reverse().

◆ calculate_fwd()

template<typename MatType >
void casadi::Factory< MatType >::calculate_fwd ( const Dict &  opts)

Definition at line 237 of file factory.hpp.

237  {
238  if (fwd_out_.empty()) return;
239  casadi_assert_dev(!fwd_in_.empty());
240 
241  std::vector<MatType> arg, res;
242  std::vector<std::vector<MatType>> seed(1), sens(1);
243  // Inputs and forward mode seeds
244  for (size_t iind : fwd_in_) {
245  arg.push_back(in_[iind]);
246  Sparsity sp = is_diff_in_.at(iind) ? arg.back().sparsity() : Sparsity(arg.back().size());
247  seed[0].push_back(MatType::sym("fwd_" + iname_[iind], sp));
248  add_input("fwd:" + iname_[iind], seed[0].back(), true);
249  }
250  // Outputs
251  for (size_t oind : fwd_out_) res.push_back(out_.at(oind));
252  // Calculate directional derivatives
253  Dict local_opts = opts;
254  local_opts["always_inline"] = true;
255  sens = forward(res, arg, seed, local_opts);
256 
257  // Get directional derivatives
258  for (size_t i = 0; i < fwd_out_.size(); ++i) {
259  std::string s = oname_.at(fwd_out_[i]);
260  Sparsity sp = is_diff_out_.at(fwd_out_[i]) ? res.at(i).sparsity()
261  : Sparsity(res.at(i).size());
262  add_output("fwd:" + s, project(sens[0].at(i), sp), is_diff_out_.at(fwd_out_[i]));
263  }
264  }
std::vector< size_t > fwd_out_
Definition: factory.hpp:66
std::vector< size_t > fwd_in_
Definition: factory.hpp:66

◆ calculate_grad()

template<typename MatType >
void casadi::Factory< MatType >::calculate_grad ( const Dict &  opts)

Definition at line 389 of file factory.hpp.

389  {
390  for (auto &&b : grad_) {
391  const MatType& ex = out_.at(b.f);
392  const MatType& arg = in_[b.x];
393  if (is_diff_out_.at(b.f) && is_diff_in_.at(b.x)) {
394  add_output("grad:" + oname_[b.f] + ":" + iname_[b.x],
395  project(gradient(ex, arg, opts), arg.sparsity()), true);
396  } else {
397  casadi_assert(ex.is_scalar(), "Can only take gradient of scalar expression.");
398  add_output("grad:" + oname_[b.f] + ":" + iname_[b.x], MatType(1, arg.numel()), false);
399  }
400  }
401  }
std::vector< Block > grad_
Definition: factory.hpp:72

◆ calculate_hess() [1/2]

template<typename MatType >
void casadi::Factory< MatType >::calculate_hess ( const Dict &  opts)

Definition at line 518 of file factory.hpp.

518  {
519  // Calculate blocks for all non-differentiable inputs and outputs
520  for (auto &&b : hess_) {
521  if (is_diff_out_.at(b.f) && is_diff_in_.at(b.x1) && is_diff_in_.at(b.x2)) {
522  b.calculated = false;
523  } else {
524  add_output(b.s, MatType(in_[b.x1].numel(), in_[b.x2].numel()), false);
525  b.calculated = true;
526  }
527  // Consistency check
528  casadi_assert(out_.at(b.f).is_scalar(),
529  "Can only take Hessian of scalar expression.");
530  }
531  // Calculate regular blocks
532  for (auto &&b : hess_) {
533  // Skip if already calculated
534  if (b.calculated) continue;
535  // Calculate all Hessian blocks for b.f
536  calculate_hess(opts, b.f);
537  }
538  }
std::vector< HBlock > hess_
Definition: factory.hpp:75

◆ calculate_hess() [2/2]

template<typename MatType >
void casadi::Factory< MatType >::calculate_hess ( const Dict &  opts,
size_t  f 
)

Definition at line 404 of file factory.hpp.

404  {
405  // Handle all blocks for this expression
406  for (auto &&b : hess_) {
407  if (b.f != f) continue;
408  // Skip if already calculated
409  if (b.calculated) continue;
410  // Find other blocks with one of the arguments matching
411  std::vector<size_t> all_x1;
412  for (auto &&b1 : hess_) {
413  if (b1.f == b.f && !b1.calculated) {
414  if (b1.x1 == b.x1) {
415  // Block found
416  all_x1.push_back(b1.x2);
417  } else if (b1.x2 == b.x1) {
418  // Opposite block found
419  all_x1.push_back(b1.x1);
420  }
421  }
422  }
423  // Find other blocks with both of the arguments matching
424  std::vector<size_t> all_x2;
425  for (auto &&b1 : hess_) {
426  if (b1.f != f || b1.calculated) continue;
427  // Can either b1.x1 or b1.x2 be added to all_x2?
428  for (bool test_x1 : {false, true}) {
429  size_t cand = test_x1 ? b1.x1 : b1.x2;
430  size_t other = test_x1 ? b1.x2 : b1.x1;
431  bool cand_ok = true;
432  bool other_ok = false;
433  // Skip if already in all_x2
434  if (std::count(all_x2.begin(), all_x2.end(), cand)) continue;
435  // Loop over existing entries in x1
436  for (size_t a : all_x1) {
437  // The other argument must already be in all_x1
438  if (other == a) other_ok = true;
439  // Is block not requested?
440  auto it = find_hess(f, a, cand);
441  if (it == hess_.end() || it->calculated) {
442  // Also check mirror block, if there is one
443  if (a != cand) {
444  it = find_hess(f, cand, a);
445  if (it != hess_.end() && !it->calculated) continue;
446  }
447  // Not a candidate
448  cand_ok = false;
449  break;
450  }
451  }
452  // Keep candidate
453  if (cand_ok && other_ok) all_x2.push_back(cand);
454  }
455  }
456  // Calculate Hessian blocks
457  try {
458  if (all_x1.size() == 1 && all_x2.size() == 1) {
459  // Single block
460  MatType H = b.x1 == b.x2 ? hessian(out_.at(f), in_[b.x1], opts)
461  : jacobian(gradient(out_.at(f), in_[b.x1]), in_[b.x2]);
462  add_output(b.s, H, true);
463  b.calculated = true;
464  } else {
465  // Sort blocks
466  std::sort(all_x1.begin(), all_x1.end());
467  std::sort(all_x2.begin(), all_x2.end());
468  // Symmetric extended Hessian?
469  bool symmetric = all_x1 == all_x2;
470  // Collect components
471  std::vector<MatType> x1(all_x1.size()), x2(all_x2.size());
472  for (size_t i = 0; i < x1.size(); ++i) x1[i] = in_.at(all_x1[i]);
473  for (size_t i = 0; i < x2.size(); ++i) x2[i] = in_.at(all_x2[i]);
474  // Calculate extended Hessian
475  MatType H;
476  if (symmetric) {
477  H = hessian(out_.at(f), vertcat(x1));
478  } else {
479  H = jacobian(gradient(out_.at(f), vertcat(x1)), vertcat(x2));
480  }
481  // Split into blocks
482  std::vector<std::vector<MatType>> H_all = blocksplit(H, offset(x1), offset(x2));
483  // Collect Hessian blocks
484  for (auto &&b1 : hess_) {
485  if (b1.f == f && !b1.calculated) {
486  // Find arguments in all_x1 and all_x2
487  auto it_x1 = std::find(all_x1.begin(), all_x1.end(), b1.x1);
488  auto it_x2 = std::find(all_x2.begin(), all_x2.end(), b1.x2);
489  if (it_x1 != all_x1.end() && it_x2 != all_x2.end()) {
490  // Block located
491  const MatType& Hb = H_all.at(it_x1 - all_x1.begin()).at(it_x2 - all_x2.begin());
492  add_output(b1.s, Hb, true);
493  b1.calculated = true;
494  } else if (!symmetric) {
495  // Check mirror block
496  it_x1 = std::find(all_x1.begin(), all_x1.end(), b1.x2);
497  it_x2 = std::find(all_x2.begin(), all_x2.end(), b1.x1);
498  if (it_x1 != all_x1.end() && it_x2 != all_x2.end()) {
499  // Transpose of block located
500  const MatType& Hb = H_all.at(it_x1 - all_x1.begin()).at(it_x2 - all_x2.begin());
501  add_output(b1.s, Hb.T(), true);
502  b1.calculated = true;
503  }
504  }
505  }
506  }
507  }
508  } catch (std::exception& e) {
509  std::stringstream ss;
510  ss << "Calculating Hessian of " << oname_.at(f) << " w.r.t. " << iname(all_x1) << " and "
511  << iname(all_x2) << ": " << e.what();
512  casadi_error(ss.str());
513  }
514  }
515  }
std::vector< HBlock >::iterator find_hess(size_t f, size_t x1, size_t x2)
Definition: factory.hpp:304
const std::vector< std::string > & iname() const
Definition: factory.hpp:150

◆ calculate_jac()

template<typename MatType >
void casadi::Factory< MatType >::calculate_jac ( const Dict &  opts)

Definition at line 313 of file factory.hpp.

313  {
314  // Calculate blocks for all non-differentiable inputs and outputs
315  for (auto &&b : jac_) {
316  if (is_diff_out_.at(b.f) && is_diff_in_.at(b.x)) {
317  b.calculated = false;
318  } else {
319  add_output(b.s, MatType(out_[b.f].numel(), in_[b.x].numel()), false);
320  b.calculated = true;
321  }
322  }
323  // Calculate regular blocks
324  for (auto &&b : jac_) {
325  // Skip if already calculated
326  if (b.calculated) continue;
327  // Find other blocks with the same input, but different (not yet calculated) outputs
328  std::vector<size_t> all_f;
329  for (auto &&b1 : jac_) {
330  if (b1.x == b.x && !b1.calculated) all_f.push_back(b1.f);
331  }
332  // Now find other blocks with *all* the same outputs, but different inputs
333  std::vector<size_t> all_x{b.x};
334  for (auto &&b1 : jac_) {
335  // Candidate b1.arg: Check if already added
336  if (std::count(all_x.begin(), all_x.end(), b1.x)) continue;
337  // Skip if all block are not requested or any block has already been calculated
338  bool skip = false;
339  for (size_t f1 : all_f) {
340  auto it = find_jac(f1, b1.x);
341  if (it == jac_.end() || it->calculated) {
342  skip = true;
343  break;
344  }
345  }
346  if (skip) continue;
347  // Keep candidate
348  all_x.push_back(b1.x);
349  }
350  try {
351  // Calculate Jacobian block(s)
352  if (all_f.size() == 1 && all_x.size() == 1) {
353  // Single block
354  add_output(b.s, MatType::jacobian(out_[b.f], in_[b.x], opts), true);
355  b.calculated = true;
356  } else {
357  // Sort blocks
358  std::sort(all_x.begin(), all_x.end());
359  std::sort(all_f.begin(), all_f.end());
360  // Collect components
361  std::vector<MatType> x(all_x.size()), f(all_f.size());
362  for (size_t i = 0; i < x.size(); ++i) x[i] = in_.at(all_x[i]);
363  for (size_t i = 0; i < f.size(); ++i) f[i] = out_.at(all_f[i]);
364  // Calculate Jacobian of all outputs with respect to all inputs
365  MatType J = MatType::jacobian(veccat(f), veccat(x), opts);
366  // Split Jacobian into blocks
367  std::vector<std::vector<MatType>> J_all = blocksplit(J, offset(f), offset(x));
368  // Save blocks
369  for (size_t i = 0; i < all_f.size(); ++i) {
370  for (size_t j = 0; j < all_x.size(); ++j) {
371  auto J_it = find_jac(all_f[i], all_x[j]);
372  if (J_it != jac_.end()) {
373  add_output(J_it->s, J_all.at(i).at(j), true);
374  J_it->calculated = true;
375  }
376  }
377  }
378  }
379  } catch (std::exception& e) {
380  std::stringstream ss;
381  ss << "Calculating Jacobian of " << oname(all_f) << " w.r.t. " << iname(all_x)
382  << ": " << e.what();
383  casadi_error(ss.str());
384  }
385  }
386  }
std::vector< Block >::iterator find_jac(size_t f, size_t x)
Definition: factory.hpp:295
std::vector< Block > jac_
Definition: factory.hpp:72
const std::vector< std::string > & oname() const
Definition: factory.hpp:154

◆ find_hess()

template<typename MatType >
std::vector< HBlock >::iterator casadi::Factory< MatType >::find_hess ( size_t  f,
size_t  x1,
size_t  x2 
)

Definition at line 304 of file factory.hpp.

304  {
305  for (std::vector<HBlock>::iterator it = hess_.begin(); it != hess_.end(); ++it) {
306  if (it->f == f && it->x1 == x1 && it->x2 == x2) return it;
307  }
308  // Not in list
309  return hess_.end();
310  }

◆ find_jac()

template<typename MatType >
std::vector< Block >::iterator casadi::Factory< MatType >::find_jac ( size_t  f,
size_t  x 
)

Definition at line 295 of file factory.hpp.

295  {
296  for (std::vector<Block>::iterator it = jac_.begin(); it != jac_.end(); ++it) {
297  if (it->f == f && it->x == x) return it;
298  }
299  // Not in list
300  return jac_.end();
301  }

◆ get_input()

template<typename MatType >
MatType casadi::Factory< MatType >::get_input ( const std::string &  s)

Definition at line 596 of file factory.hpp.

596  {
597  auto it = imap_.find(s);
598  casadi_assert(it!=imap_.end(), "Cannot retrieve \"" + s + "\"");
599  return in_.at(it->second);
600  }

Referenced by casadi::XFunction< DerivedType, MatType, NodeType >::factory().

◆ get_output()

template<typename MatType >
MatType casadi::Factory< MatType >::get_output ( const std::string &  s)

Definition at line 603 of file factory.hpp.

603  {
604  // Quick return if output
605  auto it = omap_.find(s);
606  if (it!=omap_.end()) return out_.at(it->second);
607 
608  // Assume attribute
609  casadi_assert(has_prefix(s), "Cannot process \"" + s + "\"");
610  std::pair<std::string, std::string> ss = split_prefix(s);
611  std::string a = ss.first;
612  MatType r = get_output(ss.second);
613 
614  // Process attributes
615  if (a=="transpose") {
616  return r.T();
617  } else if (a=="triu") {
618  return triu(r);
619  } else if (a=="tril") {
620  return tril(r);
621  } else if (a=="densify") {
622  return densify(r);
623  } else if (a=="sym") {
624  casadi_warning("Attribute 'sym' has been deprecated. Hessians are symmetric by default.");
625  return r;
626  } else if (a=="withdiag") {
627  return project(r, r.sparsity() + Sparsity::diag(r.size1()));
628  } else {
629  casadi_error("Cannot process attribute \"" + a + "\"");
630  return MatType();
631  }
632  }
static std::pair< std::string, std::string > split_prefix(const std::string &s)
Definition: factory.hpp:642
MatType get_output(const std::string &s)
Definition: factory.hpp:603
static bool has_prefix(const std::string &s)
Definition: factory.hpp:636
static Sparsity diag(casadi_int nrow)
Create diagonal sparsity pattern.
Definition: sparsity.hpp:190

References casadi::Sparsity::diag(), and casadi::has_prefix().

Referenced by casadi::XFunction< DerivedType, MatType, NodeType >::factory().

◆ has_in()

template<typename MatType >
bool casadi::Factory< MatType >::has_in ( const std::string &  s) const
inline

Definition at line 144 of file factory.hpp.

144 { return imap_.find(s)!=imap_.end();}

References casadi::Factory< MatType >::imap_.

◆ has_out()

template<typename MatType >
bool casadi::Factory< MatType >::has_out ( const std::string &  s) const
inline

Definition at line 147 of file factory.hpp.

147 { return omap_.find(s) != omap_.end();}

References casadi::Factory< MatType >::omap_.

◆ has_prefix()

template<typename MatType >
bool casadi::Factory< MatType >::has_prefix ( const std::string &  s)
static

Definition at line 635 of file factory.hpp.

636  {
637  return s.find(':') < s.size();
638  }

◆ hblock()

template<typename MatType >
HBlock casadi::Factory< MatType >::hblock ( const std::string &  s1,
const std::string &  s 
) const

Definition at line 693 of file factory.hpp.

693  {
694  HBlock b;
695  b.s = s;
696  size_t pos1 = s2.find(':');
697  if (pos1 < s2.size()) {
698  size_t pos2 = s2.find(':', pos1 + 1);
699  if (pos2 < s2.size()) {
700  b.f = omap(s2.substr(0, pos1));
701  b.x1 = imap(s2.substr(pos1 + 1, pos2 - pos1 - 1));
702  b.x2 = imap(s2.substr(pos2 + 1, std::string::npos));
703  }
704  }
705  return b;
706  }

References casadi::HBlock::f, casadi::HBlock::s, casadi::HBlock::x1, and casadi::HBlock::x2.

◆ imap()

template<typename MatType >
size_t casadi::Factory< MatType >::imap ( const std::string &  s) const

Definition at line 665 of file factory.hpp.

665  {
666  auto iind = imap_.find(s);
667  casadi_assert(iind != imap_.end(),
668  "Cannot process \"" + s + "\" as input. Available: " + join(oname()) + ".");
669  return iind->second;
670  }
std::string join(const std::vector< std::string > &l, const std::string &delim)

References casadi::join().

◆ iname() [1/2]

template<typename MatType >
const std::vector<std::string>& casadi::Factory< MatType >::iname ( ) const
inline

Definition at line 150 of file factory.hpp.

150 {return iname_;}

References casadi::Factory< MatType >::iname_.

◆ iname() [2/2]

template<typename MatType >
std::vector< std::string > casadi::Factory< MatType >::iname ( const std::vector< size_t > &  ind) const

Definition at line 651 of file factory.hpp.

651  {
652  std::vector<std::string> ret;
653  for (size_t i : ind) ret.push_back(iname_.at(i));
654  return ret;
655  }

◆ omap()

template<typename MatType >
size_t casadi::Factory< MatType >::omap ( const std::string &  s) const

Definition at line 673 of file factory.hpp.

673  {
674  auto oind = omap_.find(s);
675  casadi_assert(oind != omap_.end(),
676  "Cannot process \"" + s + "\" as output. Available: " + join(oname()) + ".");
677  return oind->second;
678  }

References casadi::join().

◆ oname() [1/2]

template<typename MatType >
const std::vector<std::string>& casadi::Factory< MatType >::oname ( ) const
inline

Definition at line 154 of file factory.hpp.

154 {return oname_;}

References casadi::Factory< MatType >::oname_.

◆ oname() [2/2]

template<typename MatType >
std::vector< std::string > casadi::Factory< MatType >::oname ( const std::vector< size_t > &  ind) const

Definition at line 658 of file factory.hpp.

658  {
659  std::vector<std::string> ret;
660  for (size_t i : ind) ret.push_back(oname_.at(i));
661  return ret;
662  }

◆ request_input()

template<typename MatType >
std::string casadi::Factory< MatType >::request_input ( const std::string &  s)

Definition at line 181 of file factory.hpp.

182  {
183  // Add input if not already available
184  if (!has_in(s)) {
185  // Get prefix
186  casadi_assert(has_prefix(s), "Cannot process \"" + s + "\" as input."
187  " Available: " + join(iname()) + ".");
188  std::pair<std::string, std::string> ss = split_prefix(s);
189  // Process specific prefixes
190  if (ss.first=="fwd") {
191  // Forward mode directional derivative
192  fwd_in_.push_back(imap(ss.second));
193  } else if (ss.first=="adj") {
194  // Reverse mode directional derivative
195  adj_in_.push_back(omap(ss.second));
196  }
197  }
198  // Replace colons with underscore
199  std::string ret = s;
200  std::replace(ret.begin(), ret.end(), ':', '_');
201  return ret;
202  }
bool has_in(const std::string &s) const
Definition: factory.hpp:144

References casadi::has_prefix(), and casadi::join().

Referenced by casadi::XFunction< DerivedType, MatType, NodeType >::factory().

◆ request_output()

template<typename MatType >
std::string casadi::Factory< MatType >::request_output ( const std::string &  s)

Definition at line 205 of file factory.hpp.

206  {
207  // Quick return if already available
208  if (has_out(s)) return s;
209 
210  // Get prefix
211  casadi_assert(has_prefix(s), "Cannot process \"" + s + "\" as output."
212  " Available: " + join(oname()) + ".");
213  std::pair<std::string, std::string> ss = split_prefix(s);
214 
215  if (ss.first=="fwd") {
216  fwd_out_.push_back(omap(ss.second));
217  } else if (ss.first=="adj") {
218  adj_out_.push_back(imap(ss.second));
219  } else if (ss.first=="jac") {
220  jac_.push_back(block(ss.second, s));
221  } else if (ss.first=="grad") {
222  grad_.push_back(block(ss.second, s));
223  } else if (ss.first=="hess") {
224  hess_.push_back(hblock(ss.second, s));
225  } else {
226  // Assume attribute
227  request_output(ss.second);
228  }
229 
230  // Replace colons with underscore
231  std::string ret = s;
232  replace(ret.begin(), ret.end(), ':', '_');
233  return ret;
234  }
std::string request_output(const std::string &s)
Definition: factory.hpp:206
bool has_out(const std::string &s) const
Definition: factory.hpp:147
Block block(const std::string &s1, const std::string &s) const
Definition: factory.hpp:681
HBlock hblock(const std::string &s1, const std::string &s) const
Definition: factory.hpp:693
CASADI_EXPORT std::string replace(const std::string &s, const std::string &p, const std::string &r)
Replace all occurrences of p with r in s.

References casadi::has_prefix(), casadi::join(), and casadi::replace().

Referenced by casadi::XFunction< DerivedType, MatType, NodeType >::factory().

◆ split_prefix()

template<typename MatType >
std::pair< std::string, std::string > casadi::Factory< MatType >::split_prefix ( const std::string &  s)
static

Definition at line 641 of file factory.hpp.

642  {
643  // Get prefix
644  casadi_assert_dev(!s.empty());
645  size_t pos = s.find(':');
646  casadi_assert(pos<s.size(), "Cannot process \"" + s + "\"");
647  return std::make_pair(s.substr(0, pos), s.substr(pos+1, std::string::npos));
648  }

Member Data Documentation

◆ adj_in_

template<typename MatType >
std::vector<size_t> casadi::Factory< MatType >::adj_in_

Definition at line 69 of file factory.hpp.

◆ adj_out_

template<typename MatType >
std::vector<size_t> casadi::Factory< MatType >::adj_out_

Definition at line 69 of file factory.hpp.

◆ fwd_in_

template<typename MatType >
std::vector<size_t> casadi::Factory< MatType >::fwd_in_

Definition at line 66 of file factory.hpp.

◆ fwd_out_

template<typename MatType >
std::vector<size_t> casadi::Factory< MatType >::fwd_out_

Definition at line 66 of file factory.hpp.

◆ grad_

template<typename MatType >
std::vector<Block> casadi::Factory< MatType >::grad_

Definition at line 72 of file factory.hpp.

◆ hess_

template<typename MatType >
std::vector<HBlock> casadi::Factory< MatType >::hess_

Definition at line 75 of file factory.hpp.

◆ imap_

template<typename MatType >
std::map<std::string, size_t> casadi::Factory< MatType >::imap_

Definition at line 62 of file factory.hpp.

Referenced by casadi::Factory< MatType >::has_in().

◆ in_

template<typename MatType >
std::vector<MatType> casadi::Factory< MatType >::in_

Definition at line 56 of file factory.hpp.

◆ iname_

template<typename MatType >
std::vector<std::string> casadi::Factory< MatType >::iname_

Definition at line 59 of file factory.hpp.

Referenced by casadi::Factory< MatType >::iname().

◆ is_diff_in_

template<typename MatType >
std::vector<bool> casadi::Factory< MatType >::is_diff_in_

Definition at line 63 of file factory.hpp.

◆ is_diff_out_

template<typename MatType >
std::vector<bool> casadi::Factory< MatType >::is_diff_out_

Definition at line 63 of file factory.hpp.

◆ jac_

template<typename MatType >
std::vector<Block> casadi::Factory< MatType >::jac_

Definition at line 72 of file factory.hpp.

◆ omap_

template<typename MatType >
std::map<std::string, size_t> casadi::Factory< MatType >::omap_

Definition at line 62 of file factory.hpp.

Referenced by casadi::Factory< MatType >::has_out().

◆ oname_

template<typename MatType >
std::vector<std::string> casadi::Factory< MatType >::oname_

Definition at line 59 of file factory.hpp.

Referenced by casadi::Factory< MatType >::oname().

◆ out_

template<typename MatType >
std::vector<MatType> casadi::Factory< MatType >::out_

Definition at line 56 of file factory.hpp.


The documentation for this class was generated from the following file: