map.hpp
1 /*
2  * This file is part of CasADi.
3  *
4  * CasADi -- A symbolic framework for dynamic optimization.
5  * Copyright (C) 2010-2023 Joel Andersson, Joris Gillis, Moritz Diehl,
6  * KU Leuven. All rights reserved.
7  * Copyright (C) 2011-2014 Greg Horn
8  *
9  * CasADi is free software; you can redistribute it and/or
10  * modify it under the terms of the GNU Lesser General Public
11  * License as published by the Free Software Foundation; either
12  * version 3 of the License, or (at your option) any later version.
13  *
14  * CasADi is distributed in the hope that it will be useful,
15  * but WITHOUT ANY WARRANTY; without even the implied warranty of
16  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
17  * Lesser General Public License for more details.
18  *
19  * You should have received a copy of the GNU Lesser General Public
20  * License along with CasADi; if not, write to the Free Software
21  * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
22  *
23  */
24 
25 
26 #ifndef CASADI_MAP_HPP
27 #define CASADI_MAP_HPP
28 
29 #include "function_internal.hpp"
30 
32 
33 namespace casadi {
34 
39  class CASADI_EXPORT Map : public FunctionInternal {
40  public:
41  // Create function (use instead of constructor)
42  static Function create(const std::string& parallelization,
43  const Function& f, casadi_int n);
44 
48  ~Map() override;
49 
53  std::string class_name() const override {return "Map";}
54 
58  bool is_a(const std::string& type, bool recursive) const override;
59 
60  // Get list of dependency functions
61  virtual std::vector<std::string> get_function() const override;
62 
63  // Get a dependency function
64  const Function& get_function(const std::string &name) const override;
65 
66  // Check if a particular dependency exists
67  bool has_function(const std::string& fname) const override;
68 
70 
73  Sparsity get_sparsity_in(casadi_int i) override {
74  return repmat(f_.sparsity_in(i), 1, n_);
75  }
76  Sparsity get_sparsity_out(casadi_int i) override {
77  return repmat(f_.sparsity_out(i), 1, n_);
78  }
80 
84  double get_default_in(casadi_int ind) const override { return f_.default_in(ind);}
85 
87 
90  size_t get_n_in() override { return f_.n_in();}
91  size_t get_n_out() override { return f_.n_out();}
93 
95 
98  std::string get_name_in(casadi_int i) override { return f_.name_in(i);}
99  std::string get_name_out(casadi_int i) override { return f_.name_out(i);}
101 
105  template<typename T>
106  int eval_gen(const T** arg, T** res, casadi_int* iw, T* w, int mem=0) const;
107 
109  int eval(const double** arg, double** res, casadi_int* iw, double* w, void* mem) const override;
110 
112  virtual std::string parallelization() const { return "serial"; }
113 
117  int eval_sx(const SXElem** arg, SXElem** res,
118  casadi_int* iw, SXElem* w, void* mem) const override;
119 
123  int sp_forward(const bvec_t** arg, bvec_t** res,
124  casadi_int* iw, bvec_t* w, void* mem) const override;
125 
129  int sp_reverse(bvec_t** arg, bvec_t** res, casadi_int* iw, bvec_t* w, void* mem) const override;
130 
133  bool has_spfwd() const override { return true;}
134  bool has_sprev() const override { return true;}
136 
140  bool has_codegen() const override { return true;}
141 
145  void codegen_declarations(CodeGenerator& g) const override;
146 
150  void codegen_body(CodeGenerator& g) const override;
151 
155  void init(const Dict& opts) override;
156 
158 
161  bool has_forward(casadi_int nfwd) const override { return true;}
162  Function get_forward(casadi_int nfwd, const std::string& name,
163  const std::vector<std::string>& inames,
164  const std::vector<std::string>& onames,
165  const Dict& opts) const override;
167 
169 
172  bool has_reverse(casadi_int nadj) const override { return true;}
173  Function get_reverse(casadi_int nadj, const std::string& name,
174  const std::vector<std::string>& inames,
175  const std::vector<std::string>& onames,
176  const Dict& opts) const override;
178 
180  Dict info() const override { return {{"f", f_}, {"n", n_}}; }
181 
185  void serialize_body(SerializingStream &s) const override;
189  void serialize_type(SerializingStream &s) const override;
190 
194  std::string serialize_base_function() const override { return "Map"; }
195 
199  static ProtoFunction* deserialize(DeserializingStream& s);
200 
201  protected:
205  explicit Map(DeserializingStream& s);
206 
207  // Constructor (protected, use create function)
208  Map(const std::string& name, const Function& f, casadi_int n);
209 
210  // The function which is to be evaluated in parallel
211  Function f_;
212 
213  // Number of times to evaluate this function
214  casadi_int n_;
215  };
216 
224  class CASADI_EXPORT OmpMap : public Map {
225  friend class Map;
226  public:
227  // Constructor (protected, use create function in Map)
228  OmpMap(const std::string& name, const Function& f, casadi_int n) : Map(name, f, n) {}
229 
233  ~OmpMap() override;
234 
238  std::string class_name() const override {return "OmpMap";}
239 
243  bool is_a(const std::string& type, bool recursive) const override;
244 
246  int eval(const double** arg, double** res, casadi_int* iw, double* w, void* mem) const override;
247 
251  void init(const Dict& opts) override;
252 
254  std::string parallelization() const override { return "openmp"; }
255 
259  void codegen_body(CodeGenerator& g) const override;
260 
261  protected:
265  explicit OmpMap(DeserializingStream& s) : Map(s) {}
266  };
267 
275  class CASADI_EXPORT ThreadMap : public Map {
276  friend class Map;
277  public:
278  // Constructor (protected, use create function in Map)
279  ThreadMap(const std::string& name, const Function& f, casadi_int n) : Map(name, f, n) {}
280 
284  ~ThreadMap() override;
285 
289  std::string class_name() const override {return "ThreadMap";}
290 
294  bool is_a(const std::string& type, bool recursive) const override;
295 
297  int eval(const double** arg, double** res, casadi_int* iw, double* w, void* mem) const override;
298 
302  void init(const Dict& opts) override;
303 
305  std::string parallelization() const override { return "thread"; }
306 
310  void codegen_body(CodeGenerator& g) const override;
311 
312  protected:
316  explicit ThreadMap(DeserializingStream& s) : Map(s) {}
317  };
318 
319 } // namespace casadi
321 
322 #endif // CASADI_MAP_HPP
The casadi namespace.
GenericType::Dict Dict
C++ equivalent of Python's dict or MATLAB's struct.