mf
Media Framework
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros
ndarray_view_cast.h
Go to the documentation of this file.
1 #ifndef MF_NDARRAY_VIEW_CAST_H_
2 #define MF_NDARRAY_VIEW_CAST_H_
3 
4 #include "ndarray_view.h"
5 #include "../common.h"
6 #include "../elem.h"
7 #include "../elem_tuple.h"
8 #include "../masked_elem.h"
9 #include "../utility/misc.h"
10 #include <cstddef>
11 
12 namespace mf {
13 
14 namespace detail {
// Functor type that converts an Input_view into an Output_view.
// Only declared here: every supported conversion is supplied by a partial specialization,
// so requesting an unsupported combination results in an incomplete type.
template<typename Output_view, typename Input_view> struct ndarray_view_caster;
17 
18  // single element from tuple
19  template<typename Output_elem, std::size_t Dim, typename... Input_elems>
20  struct ndarray_view_caster<
21  ndarray_view<Dim, Output_elem>, // out
22  ndarray_view<Dim, elem_tuple<Input_elems...>> // in
23  >{
24  using input_tuple_type = elem_tuple<Input_elems...>;
25 
26  using output_view_type = ndarray_view<Dim, Output_elem>;
27  using input_view_type = ndarray_view<Dim, input_tuple_type>;
28 
29  output_view_type operator()(const input_view_type& arr) const {
30  constexpr std::ptrdiff_t index = elem_tuple_index<Output_elem, input_tuple_type>;
31  constexpr std::ptrdiff_t offset = elem_tuple_offset<index, input_tuple_type>;
32 
33  auto* start = reinterpret_cast<Output_elem*>(
34  advance_raw_ptr(arr.start(), offset)
35  );
36  return output_view_type(
37  start,
38  arr.shape(),
39  arr.strides()
40  );
41  }
42  };
43 
44  // scalars from elem
45  template<std::size_t Dim, typename Input_elem>
46  struct ndarray_view_caster<
47  ndarray_view<Dim + 1, typename elem_traits<Input_elem>::scalar_type>, // out
48  ndarray_view<Dim, Input_elem> // in
49  >{
50  using elem_traits_type = elem_traits<Input_elem>;
51  using elem_scalar_type = typename elem_traits_type::scalar_type;
52 
53  using output_view_type = ndarray_view<Dim + 1, elem_scalar_type>;
54  using input_view_type = ndarray_view<Dim, Input_elem>;
55 
56  output_view_type operator()(const input_view_type& arr) const {
57  auto* start = reinterpret_cast<elem_scalar_type*>(arr.start());
58  return output_view_type(
59  start,
60  ndcoord_cat(arr.shape(), make_ndsize(elem_traits_type::components)),
61  ndcoord_cat(arr.strides(), make_ndptrdiff(elem_traits_type::stride))
62  );
63  }
64  };
65 
66  // masked to elem cast
67  template<std::size_t Dim, typename Elem>
68  struct ndarray_view_caster<
69  ndarray_view<Dim, Elem>, // out
70  ndarray_view<Dim, masked_elem<Elem>> // in
71  >{
72  using input_view_type = ndarray_view<Dim, masked_elem<Elem>>;
73  using output_view_type = ndarray_view<Dim, Elem>;
74 
75  output_view_type operator()(const input_view_type& arr) const {
76  auto* start = reinterpret_cast<Elem*>(arr.start());
77  return output_view_type(start, arr.shape(), arr.strides());
78  }
79  };
80 
81 
82  // masked to mask cast
83  template<std::size_t Dim, typename Elem>
84  struct ndarray_view_caster<
85  ndarray_view<Dim, bool>, // out
86  ndarray_view<Dim, masked_elem<Elem>> // in
87  >{
88  using input_view_type = ndarray_view<Dim, masked_elem<Elem>>;
89  using output_view_type = ndarray_view<Dim, bool>;
90 
91  output_view_type operator()(const input_view_type& arr) const {
92  std::ptrdiff_t offset = offsetof(masked_elem<Elem>, mask);
93  auto* start = reinterpret_cast<bool*>(
94  advance_raw_ptr(arr.start(), offset)
95  );
96  return output_view_type(start, arr.shape(), arr.strides());
97  }
98  };
99 
100 
101  // no-op cast
102  template<std::size_t Dim, typename Elem>
103  struct ndarray_view_caster<
104  ndarray_view<Dim, Elem>, // out
105  ndarray_view<Dim, Elem> // in
106  >{
107  using view_type = ndarray_view<Dim, Elem>;
108  const view_type& operator()(const view_type& arr) const {
109  return arr;
110  }
111  };
112 }
113 
114 
115 template<typename Output_view, typename Input_view>
116 Output_view ndarray_view_cast(const Input_view& view) {
117  detail::ndarray_view_caster<Output_view, Input_view> caster;
118  return caster(view);
119 }
120 
121 
// Reinterpret the elements of an ndarray view as another element type.
//
// The returned view has the same start address, shape and strides as `in_view`;
// only the element type changes to Output_view::value_type.
//
// Output_view and Input_view must have the same `dimension` (checked at compile time).
//
// Throws std::invalid_argument when
//  - the magnitude of the last stride of `in_view` is smaller than the size of the
//    output element type (the output elements would not fit), or
//  - that stride is not a multiple of the alignment of the output element type.
//
// Only the last stride is validated; the alignment of the start address itself is
// not checked here.
template<typename Output_view, typename Input_view>
Output_view ndarray_view_reinterpret_cast(const Input_view& in_view) {
    using out_elem_type = typename Output_view::value_type;
    static_assert(Output_view::dimension == Input_view::dimension, "output and input view must have same dimension");

    const std::ptrdiff_t in_stride = in_view.strides().back();

    // Strides are signed, whereas sizeof/alignof yield std::size_t. Comparing them directly
    // would convert a negative stride into a huge unsigned value and silently bypass the
    // size check, so the checks are done on the stride's magnitude instead.
    // (Unsigned negation is used so that no signed overflow can occur.)
    const std::size_t stride_bytes = (in_stride < 0)
        ? std::size_t(0) - static_cast<std::size_t>(in_stride)
        : static_cast<std::size_t>(in_stride);

    if(stride_bytes < sizeof(out_elem_type))
        throw std::invalid_argument("output ndarray_view elem type is too large");
    if(stride_bytes % alignof(out_elem_type) != 0)
        throw std::invalid_argument("output ndarray_view elem type has incompatible alignment");

    auto* new_start = reinterpret_cast<out_elem_type*>(in_view.start());
    return Output_view(new_start, in_view.shape(), in_view.strides());
}
136 
137 
138 
139 }
140 
141 #endif
142 
Output_view ndarray_view_cast(const Input_view &view)
Definition: ndarray_view_cast.h:116
auto make_ndsize(Components...c)
Definition: ndcoord.h:187
ndcoord< Dim1+Dim2, T > ndcoord_cat(const ndcoord< Dim1, T > &coord1, const ndcoord< Dim2, T > &coord2)
Definition: ndcoord.h:198
auto make_ndptrdiff(Components...c)
Definition: ndcoord.h:192
T * advance_raw_ptr(T *ptr, std::ptrdiff_t diff) noexcept
Advance a pointer ptr of any type by diff bytes.
Definition: misc.tcc:9
Output_view ndarray_view_reinterpret_cast(const Input_view &in_view)
Definition: ndarray_view_cast.h:123