1 #ifndef MF_NDARRAY_VIEW_CAST_H_
2 #define MF_NDARRAY_VIEW_CAST_H_
7 #include "../elem_tuple.h"
8 #include "../masked_elem.h"
9 #include "../utility/misc.h"
/// Primary template of the functor that converts an Input_view into an
/// Output_view. It is only declared (no body); each supported
/// (Output_view, Input_view) combination is supplied by one of the partial
/// specializations that follow.
15 template<
typename Output_view,
typename Input_view>
16 struct ndarray_view_caster;
/// Casts a Dim-dimensional view of elem_tuple<Input_elems...> to a
/// Dim-dimensional view of one of the tuple's member types, Output_elem.
/// NOTE(review): when Output_elem is elem_tuple<Input_elems...> itself, this
/// specialization and the identity specialization (same Dim, same Elem) both
/// appear to match with neither more specialized, making that instantiation
/// ambiguous — verify.
19 template<
typename Output_elem, std::size_t Dim,
typename... Input_elems>
20 struct ndarray_view_caster<
21 ndarray_view<Dim, Output_elem>,
22 ndarray_view<Dim, elem_tuple<Input_elems...>>
24 using input_tuple_type = elem_tuple<Input_elems...>;
26 using output_view_type = ndarray_view<Dim, Output_elem>;
27 using input_view_type = ndarray_view<Dim, input_tuple_type>;
/// Returns a view onto the Output_elem member of each tuple element of arr.
29 output_view_type operator()(
const input_view_type& arr)
const {
// Compile-time position of Output_elem within the tuple type, and the offset
// of that member within the tuple (presumably in bytes — confirm against
// elem_tuple_offset).
30 constexpr std::ptrdiff_t index = elem_tuple_index<Output_elem, input_tuple_type>;
31 constexpr std::ptrdiff_t offset = elem_tuple_offset<index, input_tuple_type>;
// NOTE(review): the argument of this cast is not visible in this listing;
// presumably arr.start() advanced by `offset` bytes — confirm.
33 auto* start =
reinterpret_cast<Output_elem*
>(
// NOTE(review): the constructor arguments are not visible here; since Dim is
// unchanged, presumably arr's shape and strides are reused — confirm.
36 return output_view_type(
/// Casts a Dim-dimensional view of Input_elem to a (Dim + 1)-dimensional view
/// of its scalar type, elem_traits<Input_elem>::scalar_type.
/// NOTE(review): presumably the added dimension indexes the scalar components
/// inside each element — confirm against the (hidden) return statement.
45 template<std::
size_t Dim,
typename Input_elem>
46 struct ndarray_view_caster<
47 ndarray_view<Dim + 1, typename elem_traits<Input_elem>::scalar_type>,
48 ndarray_view<Dim, Input_elem>
50 using elem_traits_type = elem_traits<Input_elem>;
51 using elem_scalar_type =
typename elem_traits_type::scalar_type;
53 using output_view_type = ndarray_view<Dim + 1, elem_scalar_type>;
54 using input_view_type = ndarray_view<Dim, Input_elem>;
56 output_view_type operator()(
const input_view_type& arr)
const {
// The scalar view begins at the same address as arr; no offset is applied.
57 auto* start =
reinterpret_cast<elem_scalar_type*
>(arr.start());
// NOTE(review): the shape/strides passed for the Dim + 1 view are not
// visible in this listing.
58 return output_view_type(
/// Casts a view of masked_elem<Elem> to a view of the underlying Elem values,
/// keeping the same dimension, shape and strides.
67 template<std::
size_t Dim,
typename Elem>
68 struct ndarray_view_caster<
69 ndarray_view<Dim, Elem>,
70 ndarray_view<Dim, masked_elem<Elem>>
72 using input_view_type = ndarray_view<Dim, masked_elem<Elem>>;
73 using output_view_type = ndarray_view<Dim, Elem>;
75 output_view_type operator()(
const input_view_type& arr)
const {
// No offset is applied, so this relies on the Elem value being stored at the
// very start of masked_elem<Elem>.
// NOTE(review): confirm masked_elem places its value member first.
76 auto* start =
reinterpret_cast<Elem*
>(arr.start());
// The input strides are reused unchanged, so the result still steps over
// whole masked_elem<Elem> entries.
77 return output_view_type(start, arr.shape(), arr.strides());
/// Casts a view of masked_elem<Elem> to a bool view of each element's `mask`
/// member, keeping the same dimension, shape and strides.
/// NOTE(review): for Elem = bool, both this specialization and the
/// masked_elem<Elem> -> Elem one match ndarray_view<Dim, bool> from
/// ndarray_view<Dim, masked_elem<bool>>, and neither appears more specialized,
/// so that instantiation would be ambiguous — verify.
83 template<std::
size_t Dim,
typename Elem>
84 struct ndarray_view_caster<
85 ndarray_view<Dim, bool>,
86 ndarray_view<Dim, masked_elem<Elem>>
88 using input_view_type = ndarray_view<Dim, masked_elem<Elem>>;
89 using output_view_type = ndarray_view<Dim, bool>;
91 output_view_type operator()(
const input_view_type& arr)
const {
// Offset of the `mask` member within masked_elem<Elem>.
// NOTE(review): offsetof is only guaranteed for standard-layout types; this
// value could also be constexpr, like the offset in the elem_tuple caster.
92 std::ptrdiff_t offset = offsetof(masked_elem<Elem>, mask);
// NOTE(review): the cast argument is not visible in this listing; presumably
// arr.start() advanced by `offset` bytes — confirm.
93 auto* start =
reinterpret_cast<bool*
>(
96 return output_view_type(start, arr.shape(), arr.strides());
/// Identity cast: the input and output view types are the same.
102 template<std::
size_t Dim,
typename Elem>
103 struct ndarray_view_caster<
104 ndarray_view<Dim, Elem>,
105 ndarray_view<Dim, Elem>
107 using view_type = ndarray_view<Dim, Elem>;
// Returns a const reference rather than a new view.
// NOTE(review): the body is not visible in this listing; if it returns `arr`,
// the result aliases the argument and must be copied before `arr` goes away.
108 const view_type& operator()(
const view_type& arr)
const {
/// Converts a view to another view type by instantiating the matching
/// detail::ndarray_view_caster specialization.
/// Signature: Output_view ndarray_view_cast(const Input_view& view)
/// Only the (Output_view, Input_view) pairs covered by a specialization are
/// usable; other pairs presumably fail to compile because the primary
/// template is only declared.
115 template<
typename Output_view,
typename Input_view>
// NOTE(review): the rest of the body is not visible in this listing;
// presumably it returns caster(view).
117 detail::ndarray_view_caster<Output_view, Input_view> caster;
/// Reinterprets the elements of in_view as Output_view's value_type, without
/// any conversion, keeping the same start address, shape and strides.
/// Signature: Output_view ndarray_view_reinterpret_cast(const Input_view& in_view)
/// @throws std::invalid_argument if the innermost stride of in_view is smaller
///         than sizeof(out_elem_type), or is not a multiple of
///         alignof(out_elem_type).
122 template<
typename Output_view,
typename Input_view>
// NOTE(review): in_elem_type is not used in the visible lines.
124 using in_elem_type =
typename Input_view::value_type;
125 using out_elem_type =
typename Output_view::value_type;
126 static_assert(Output_view::dimension == Input_view::dimension,
"output and input view must have same dimension");
// Only the innermost stride is validated; the other strides and the alignment
// of in_view.start() are not checked.
// NOTE(review): in_stride is signed while sizeof()/alignof() yield
// std::size_t, so both comparisons convert a negative stride to a large
// unsigned value. A negative innermost stride therefore never triggers the
// "too large" error; consider comparing the stride's magnitude instead.
127 std::ptrdiff_t in_stride = in_view.strides().back();
128 if(in_stride <
sizeof(out_elem_type))
129 throw std::invalid_argument(
"output ndarray_view elem type is too large");
130 if(in_stride %
alignof(out_elem_type) != 0)
131 throw std::invalid_argument(
"output ndarray_view elem type has incompatible alignment");
133 auto new_start =
reinterpret_cast<out_elem_type*
>(in_view.start());
134 return Output_view(new_start, in_view.shape(), in_view.strides());
Output_view ndarray_view_cast(const Input_view &view)
Definition: ndarray_view_cast.h:116
auto make_ndsize(Components...c)
Definition: ndcoord.h:187
ndcoord< Dim1+Dim2, T > ndcoord_cat(const ndcoord< Dim1, T > &coord1, const ndcoord< Dim2, T > &coord2)
Definition: ndcoord.h:198
auto make_ndptrdiff(Components...c)
Definition: ndcoord.h:192
T * advance_raw_ptr(T *ptr, std::ptrdiff_t diff) noexcept
Advance a pointer ptr of any type by diff bytes.
Definition: misc.tcc:9
Output_view ndarray_view_reinterpret_cast(const Input_view &in_view)
Definition: ndarray_view_cast.h:123