sx_node.cpp
1 /*
2  * This file is part of CasADi.
3  *
4  * CasADi -- A symbolic framework for dynamic optimization.
5  * Copyright (C) 2010-2023 Joel Andersson, Joris Gillis, Moritz Diehl,
6  * KU Leuven. All rights reserved.
7  * Copyright (C) 2011-2014 Greg Horn
8  *
9  * CasADi is free software; you can redistribute it and/or
10  * modify it under the terms of the GNU Lesser General Public
11  * License as published by the Free Software Foundation; either
12  * version 3 of the License, or (at your option) any later version.
13  *
14  * CasADi is distributed in the hope that it will be useful,
15  * but WITHOUT ANY WARRANTY; without even the implied warranty of
16  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
17  * Lesser General Public License for more details.
18  *
19  * You should have received a copy of the GNU Lesser General Public
20  * License along with CasADi; if not, write to the Free Software
21  * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
22  *
23  */
24 
25 
26 #include "sx_node.hpp"
27 #include "serializing_stream.hpp"
28 #include "unary_sx.hpp"
29 #include "binary_sx.hpp"
30 #include "constant_sx.hpp"
31 #include "symbolic_sx.hpp"
32 #include "call_sx.hpp"
33 #include "output_sx.hpp"
34 
35 #include <limits>
36 #include <stack>
37 
38 namespace casadi {
39 
41  count = 0;
42  temp = 0;
43  }
44 
46  #ifdef WITH_REFCOUNT_WARNINGS
47  // Make sure that this is there are no scalar expressions pointing to it when it is destroyed
48  if (count!=0) {
49  // Note that casadi_assert_warning cannot be used in destructors
50  std::cerr << "Reference counting failure." <<
51  "Possible cause: Circular dependency in user code." << std::endl;
52  }
53  #endif // WITH_REFCOUNT_WARNINGS
54  }
55 
56  double SXNode::to_double() const {
57  return std::numeric_limits<double>::quiet_NaN();
58  }
59 
60  casadi_int SXNode::to_int() const {
61  casadi_error("to_int not defined for " + class_name());
62  }
63 
65  casadi_error("'which_function' not defined for class " + class_name());
66  }
67 
68  casadi_int SXNode::which_output() const {
69  casadi_error("'which_output' not defined for class " + class_name());
70  }
71 
72  bool SXNode::is_equal(const SXNode* node, casadi_int depth) const {
73  return false;
74  }
75 
76  const std::string& SXNode::name() const {
77  casadi_error("'name' not defined for " + class_name());
78  }
79 
80  const SXElem& SXNode::dep(casadi_int i) const {
81  casadi_error("'dep' not defined for " + class_name());
82  }
83 
84  SXElem& SXNode::dep(casadi_int i) {
85  casadi_error("'dep' not defined for " + class_name());
86  }
87 
88  void SXNode::disp(std::ostream& stream, bool more) const {
89  // Find out which noded can be inlined
90  std::map<const SXNode*, casadi_int> nodeind;
91  can_inline(nodeind);
92 
93  // Print expression
94  std::vector<std::string> intermed;
95  std::string s = print_compact(nodeind, intermed);
96 
97  // Print intermediate expressions
98  for (casadi_int i=0; i<intermed.size(); ++i)
99  stream << "@" << (i+1) << "=" << intermed[i] << ", ";
100 
101  // Print this
102  stream << s;
103  }
104 
105  bool SXNode::marked() const {
106  return temp<0;
107  }
108 
109  void SXNode::mark() const {
110  temp = -temp-1;
111  }
112 
113  void SXNode::can_inline(std::map<const SXNode*, casadi_int>& nodeind) const {
114  // Add or mark node in map
115  std::map<const SXNode*, casadi_int>::iterator it=nodeind.find(this);
116  if (it==nodeind.end()) {
117  // First time encountered, mark inlined
118  nodeind.insert(it, std::make_pair(this, 0));
119 
120  // Handle dependencies with recursion
121  for (casadi_int i=0; i<n_dep(); ++i) {
122  dep(i)->can_inline(nodeind);
123  }
124  } else if (it->second==0 && op()!=OP_PARAMETER) {
125  // Node encountered before, do not inline (except if symbolic primitive)
126  it->second = -1;
127  }
128  }
129 
130  std::string SXNode::print_compact(std::map<const SXNode*, casadi_int>& nodeind,
131  std::vector<std::string>& intermed) const {
132  // Get reference to node index
133  casadi_int& ind = nodeind[this];
134 
135  // If positive, already in intermediate expressions
136  if (ind>0) {
137  std::stringstream ss;
138  ss << "@" << ind;
139  return ss.str();
140  }
141 
142  std::string s;
143  if (op()==OP_CALL) {
144  const Function& f = which_function();
145  // Get expressions for dependencies
146  s = which_function().name() + "(";
147 
148  casadi_int k = 0;
149  for (casadi_int i=0; i<f.n_in(); ++i) {
150  if (f.nnz_in(i)>1) s += "[";
151  for (casadi_int j=0; j<f.nnz_in(i); ++j) {
152  s += dep(k++)->print_compact(nodeind, intermed);
153  if (j<f.nnz_in(i)-1) s+=",";
154  }
155  if (f.nnz_in(i)>1) s += "]";
156  if (i<f.n_in()-1) s+=",";
157  }
158  s += ")";
159  } else {
160  // Get expressions for dependencies
161  std::string arg[2];
162  for (casadi_int i=0; i<n_dep(); ++i) {
163  arg[i] = dep(i)->print_compact(nodeind, intermed);
164  }
165 
166  // Get expression for this
167  s = print(arg[0], arg[1]);
168  }
169 
170  // Decide what to do with the expression
171  if (ind==0) {
172  // Inline expression
173  return s;
174  } else {
175  // Add to list of intermediate expressions and return reference
176  intermed.push_back(s);
177  ind = intermed.size(); // For subsequent references
178  std::stringstream ss;
179  ss << "@" << ind;
180  return ss.str();
181  }
182  }
183 
185  // Quick return if more owners
186  if (n->count>0) return;
187  // Delete straight away if it doesn't have any dependencies
188  if (!n->n_dep()) {
189  delete n;
190  return;
191  }
192  // Stack of expressions to be deleted
193  std::stack<SXNode*> deletion_stack;
194  // Add the node to the deletion stack
195  deletion_stack.push(n);
196  // Process stack
197  while (!deletion_stack.empty()) {
198  // Top element
199  SXNode *t = deletion_stack.top();
200  // Check if the top element has dependencies with dependencies
201  bool added_to_stack = false;
202  for (casadi_int c2=0; c2<t->n_dep(); ++c2) { // for all dependencies of the dependency
203  // Get the node of the dependency of the top element
204  // and remove it from the smart pointer
206  // Check if this is the only reference to the element
207  if (n2->count == 0) {
208  // Check if unary or binary
209  if (!n2->n_dep()) {
210  // Delete straight away if not binary
211  delete n2;
212  } else {
213  // Add to deletion stack
214  deletion_stack.push(n2);
215  added_to_stack = true;
216  }
217  }
218  }
219  // Delete and pop from stack if nothing added to the stack
220  if (!added_to_stack) {
221  delete deletion_stack.top();
222  deletion_stack.pop();
223  }
224  }
225  }
226 
227  SXElem SXNode::get_output(casadi_int oind) const {
228  casadi_assert(oind==0, "Output index out of bounds");
229  return shared_from_this();
230  }
231 
232  casadi_int SXNode::eq_depth_ = 1;
233 
235  casadi_error("'serialize_node' not defined for class " + class_name());
236  }
237 
239  s.pack("SXNode::op", op());
240  serialize_node(s);
241  }
242 
244  casadi_int op;
245  s.unpack("SXNode::op", op);
246 
248  return BinarySX::deserialize(s, op);
249  } else if (casadi_math<MX>::is_unary(op)) {
250  return UnarySX::deserialize(s, op);
251  }
252 
253  auto it = SXNode::deserialize_map.find(op);
254  if (it==SXNode::deserialize_map.end()) {
255  casadi_error("Not implemented op " + str(casadi_int(op)));
256  } else {
257  return it->second(s);
258  }
259  }
260 
262  return SXElem(this, false);
263  }
264 
266  return SXElem(const_cast<SXNode*>(this), false);
267  }
268 
269  // Note: binary/unary operations are omitted here
270  std::map<casadi_int, SXNode* (*)(DeserializingStream&)> SXNode::deserialize_map = {
274  {-1, OutputSX::deserialize}};
275 
276 
277 } // namespace casadi
static SXNode * deserialize(DeserializingStream &s, casadi_int op)
Deserialize without type information.
Definition: binary_sx.hpp:143
static SXNode * deserialize(DeserializingStream &s)
Definition: call_sx.hpp:161
Helper class for Serialization.
void unpack(Sparsity &e)
Reconstruct an object from the input stream.
Function object.
Definition: function.hpp:60
const std::string & name() const
Name of the function.
Definition: function.cpp:1307
casadi_int n_in() const
Get the number of function inputs.
Definition: function.cpp:819
casadi_int nnz_in() const
Get number of input nonzeros.
Definition: function.cpp:851
static SXNode * deserialize(DeserializingStream &s)
Definition: output_sx.hpp:152
The basic scalar symbolic class of CasADi.
Definition: sx_elem.hpp:75
SXNode * assignNoDelete(const SXElem &scalar)
Assign the node to something, without invoking the deletion of the node,.
Definition: sx_elem.cpp:118
Internal node class for SX.
Definition: sx_node.hpp:49
void serialize(SerializingStream &s) const
Serialize an object.
Definition: sx_node.cpp:238
virtual SXElem get_output(casadi_int oind) const
Get an output.
Definition: sx_node.cpp:227
virtual const SXElem & dep(casadi_int i) const
get the reference of a child
Definition: sx_node.cpp:80
virtual Function which_function() const
Get called function.
Definition: sx_node.cpp:64
SXElem shared_from_this()
Get a shared object from the current internal object.
Definition: sx_node.cpp:261
static std::map< casadi_int, SXNode *(*)(DeserializingStream &)> deserialize_map
Definition: sx_node.hpp:209
static SXNode * deserialize(DeserializingStream &s)
Definition: sx_node.cpp:243
virtual casadi_int n_dep() const
Number of dependencies.
Definition: sx_node.hpp:124
static void safe_delete(SXNode *n)
Non-recursive delete.
Definition: sx_node.cpp:184
virtual casadi_int to_int() const
Get value of a constant node.
Definition: sx_node.cpp:60
unsigned int count
Definition: sx_node.hpp:197
virtual bool is_equal(const SXNode *node, casadi_int depth) const
Check if two nodes are equivalent up to a given depth.
Definition: sx_node.cpp:72
virtual const std::string & name() const
Definition: sx_node.cpp:76
virtual std::string print(const std::string &arg1, const std::string &arg2) const =0
Print expression.
virtual ~SXNode()
destructor
Definition: sx_node.cpp:45
virtual double to_double() const
Get value of a constant node.
Definition: sx_node.cpp:56
static casadi_int eq_depth_
Definition: sx_node.hpp:179
void mark() const
Definition: sx_node.cpp:109
bool marked() const
Definition: sx_node.cpp:105
virtual casadi_int op() const =0
get the operation
virtual void serialize_node(SerializingStream &s) const
Definition: sx_node.cpp:234
std::string print_compact(std::map< const SXNode *, casadi_int > &nodeind, std::vector< std::string > &intermed) const
Print compact.
Definition: sx_node.cpp:130
void can_inline(std::map< const SXNode *, casadi_int > &nodeind) const
Find out which nodes can be inlined.
Definition: sx_node.cpp:113
virtual void disp(std::ostream &stream, bool more) const
print
Definition: sx_node.cpp:88
friend class SXElem
Definition: sx_node.hpp:50
virtual casadi_int which_output() const
Get function output.
Definition: sx_node.cpp:68
SXNode()
constructor
Definition: sx_node.cpp:40
virtual std::string class_name() const =0
Get type name.
Helper class for Serialization.
void pack(const Sparsity &e)
Serializes an object to the output stream.
static SXNode * deserialize(DeserializingStream &s)
Definition: symbolic_sx.hpp:77
static SXNode * deserialize(DeserializingStream &s, casadi_int op)
Definition: unary_sx.hpp:127
casadi_limits class
The casadi namespace.
Definition: archiver.cpp:28
std::string str(const T &v)
String representation, any type.
SXNode * ConstantSX_deserialize(DeserializingStream &s)
@ OP_CONST
Definition: calculus.hpp:79
@ OP_PARAMETER
Definition: calculus.hpp:85
@ OP_CALL
Definition: calculus.hpp:88
Easy access to all the functions for a particular type.
Definition: calculus.hpp:1125