1use ndarray::prelude::*;
66use ndarray::{Data, ShapeBuilder};
67
68use crate::prelude::{c64, dim_symbol, Rcplx, Rfloat, Rint};
69use crate::*;
70
/// Generates `TryFrom<&Robj>` and `TryFrom<Robj>` impls that expose an R vector
/// whose elements are `$type` as a one-dimensional [`ArrayView1`], borrowing the
/// data returned by `as_typed_slice` instead of copying it.
///
/// When `as_typed_slice` yields `None` (presumably because the object's storage
/// type does not match `$type`), the error built by `$error_fn` is returned; it
/// receives a clone of the offending `Robj`.
macro_rules! make_array_view_1 {
    ($type: ty, $error_fn: expr) => {
        impl<'a> TryFrom<&'_ Robj> for ArrayView1<'a, $type> {
            type Error = crate::Error;

            // NOTE(review): `'a` is not tied to the lifetime of `robj`, so the
            // compiler does not prevent the returned view from outliving the
            // `Robj` it was taken from — confirm the underlying R memory stays
            // protected for as long as the view is used.
            fn try_from(robj: &Robj) -> Result<Self> {
                match robj.as_typed_slice() {
                    Some(elements) => Ok(ArrayView1::<'a, $type>::from(elements)),
                    None => Err($error_fn(robj.clone())),
                }
            }
        }

        impl<'a> TryFrom<Robj> for ArrayView1<'a, $type> {
            type Error = crate::Error;

            // Delegates to the borrowing conversion above.
            // NOTE(review): `robj` is dropped when this function returns, while
            // the view produced from it is handed back to the caller.
            fn try_from(robj: Robj) -> Result<Self> {
                <Self as TryFrom<&Robj>>::try_from(&robj)
            }
        }
    };
}
94
/// Generates `TryFrom<&Robj>` and `TryFrom<Robj>` impls that expose an R matrix
/// whose elements are `$type` as a two-dimensional [`ArrayView2`], borrowing the
/// data returned by `as_typed_slice` instead of copying it.
///
/// The view is built with a column-major (`.f()`) shape of `nrows x ncols`.
/// Objects that are not matrices are rejected with `Error::ExpectedMatrix`;
/// matrices for which `as_typed_slice` yields `None` are rejected with the error
/// built by `$error_fn`. A shape/length mismatch surfaces as
/// `Error::NDArrayShapeError`.
///
/// `$error_str` is accepted but not referenced by the generated code.
macro_rules! make_array_view_2 {
    ($type: ty, $error_str: expr, $error_fn: expr) => {
        impl<'a> TryFrom<&'_ Robj> for ArrayView2<'a, $type> {
            type Error = crate::Error;

            // NOTE(review): `'a` is not tied to the lifetime of `robj`, so the
            // returned view is not statically prevented from outliving it.
            fn try_from(robj: &Robj) -> Result<Self> {
                // Guard clause: anything without matrix dimensions is rejected
                // before its element type is inspected.
                if !robj.is_matrix() {
                    return Err(Error::ExpectedMatrix(robj.clone()));
                }

                let (rows, cols) = (robj.nrows(), robj.ncols());
                match robj.as_typed_slice() {
                    Some(elements) => {
                        let shape = (rows, cols).into_shape_with_order().f();
                        ArrayView2::from_shape(shape, elements)
                            .map_err(Error::NDArrayShapeError)
                    }
                    None => Err($error_fn(robj.clone())),
                }
            }
        }

        impl<'a> TryFrom<Robj> for ArrayView2<'a, $type> {
            type Error = crate::Error;

            // Delegates to the borrowing conversion above.
            fn try_from(robj: Robj) -> Result<Self> {
                <Self as TryFrom<&Robj>>::try_from(&robj)
            }
        }
    };
}
124make_array_view_1!(Rbool, Error::ExpectedLogical);
125make_array_view_1!(Rint, Error::ExpectedInteger);
126make_array_view_1!(i32, Error::ExpectedInteger);
127make_array_view_1!(Rfloat, Error::ExpectedReal);
128make_array_view_1!(f64, Error::ExpectedReal);
129make_array_view_1!(Rcplx, Error::ExpectedComplex);
130make_array_view_1!(c64, Error::ExpectedComplex);
131make_array_view_1!(Rstr, Error::ExpectedString);
132
133make_array_view_2!(Rbool, "Not a logical matrix.", Error::ExpectedLogical);
134make_array_view_2!(Rint, "Not an integer matrix.", Error::ExpectedInteger);
135make_array_view_2!(i32, "Not an integer matrix.", Error::ExpectedInteger);
136make_array_view_2!(Rfloat, "Not a floating point matrix.", Error::ExpectedReal);
137make_array_view_2!(f64, "Not a floating point matrix.", Error::ExpectedReal);
138make_array_view_2!(
139 Rcplx,
140 "Not a complex number matrix.",
141 Error::ExpectedComplex
142);
143make_array_view_2!(c64, "Not a complex number matrix.", Error::ExpectedComplex);
144make_array_view_2!(Rstr, "Not a string matrix.", Error::ExpectedString);
145
146impl<A, S, D> TryFrom<&ArrayBase<S, D>> for Robj
147where
148 S: Data<Elem = A>,
149 A: Copy + ToVectorValue,
150 D: Dimension,
151{
152 type Error = Error;
153
154 fn try_from(value: &ArrayBase<S, D>) -> Result<Self> {
157 let mut result = value
164 .t()
165 .iter()
166 .copied()
168 .collect_robj();
169 result.set_attrib(
170 dim_symbol(),
171 value
172 .shape()
173 .iter()
174 .map(|x| i32::try_from(*x))
175 .collect::<std::result::Result<Vec<i32>, <i32 as TryFrom<usize>>::Error>>()
176 .map_err(|_err| {
177 Error::Other(String::from(
178 "One or more array dimensions were too large to be handled by R.",
179 ))
180 })?,
181 )?;
182 Ok(result)
183 }
184}
185
186impl<A, S, D> TryFrom<ArrayBase<S, D>> for Robj
187where
188 S: Data<Elem = A>,
189 A: Copy + ToVectorValue,
190 D: Dimension,
191{
192 type Error = Error;
193
194 fn try_from(value: ArrayBase<S, D>) -> Result<Self> {
197 Robj::try_from(&value)
198 }
199}
200
#[cfg(test)]
mod test {
    use super::*;
    use crate as extendr_api;
    use ndarray::array;
    use rstest::rstest;

    /// Evaluates the R expression `left`, views the result as an `ArrayView`
    /// with the same element and dimension types as `right`, and checks that
    /// the two compare equal.
    #[rstest]
    #[case(
        // A length-one R vector becomes a one-element 1-D view.
        "1.0",
        ArrayView1::<f64>::from(&[1.][..])
    )]
    #[case(
        "1L",
        ArrayView1::<i32>::from(&[1][..])
    )]
    #[case(
        "TRUE",
        ArrayView1::<Rbool>::from(&[TRUE][..])
    )]
    #[case(
        // R fills matrices column by column, so the expected array is built
        // with a column-major (`.f()`) shape.
        "matrix(c(1, 2, 3, 4, 5, 6, 7, 8), ncol=2, nrow=4)",
        <Array2<f64>>::from_shape_vec((4, 2).f(), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]).unwrap()
    )]
    #[case(
        // Selecting a single column yields a plain vector, compared against
        // column 0 of the same matrix.
        "matrix(c(1, 2, 3, 4, 5, 6, 7, 8), ncol=2, nrow=4)[, 1]",
        <Array2<f64>>::from_shape_vec((4, 2).f(), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]).unwrap().column(0).to_owned()
    )]
    #[case(
        "matrix(c(1L, 2L, 3L, 4L, 5L, 6L, 7L, 8L), ncol=2, nrow=4)",
        <Array2<i32>>::from_shape_vec((4, 2).f(), vec![1, 2, 3, 4, 5, 6, 7, 8]).unwrap()
    )]
    #[case(
        "matrix(c(T, T, T, T, F, F, F, F), ncol=2, nrow=4)",
        <Array2<Rbool>>::from_shape_vec((4, 2).f(), vec![true.into(), true.into(), true.into(), true.into(), false.into(), false.into(), false.into(), false.into()]).unwrap()
    )]
    fn test_from_robj<DataType, DimType, E>(
        #[case] left: &'static str,
        #[case] right: ArrayBase<DataType, DimType>,
    ) where
        DataType: Data,
        DimType: Dimension,
        <DataType as ndarray::RawData>::Elem: PartialEq + std::fmt::Debug,
        // `E` is the conversion's error type; it only needs `Debug` so the
        // result can be unwrapped. Named `E` to avoid shadowing the crate's
        // `Error` type.
        E: std::fmt::Debug,
        for<'a> ArrayView<'a, <DataType as ndarray::RawData>::Elem, DimType>:
            TryFrom<&'a Robj, Error = E>,
    {
        test! {
            let left_robj = eval_string(left).unwrap();
            let left_array = <ArrayView<DataType::Elem, DimType>>::try_from(&left_robj).unwrap();
            assert_eq!( left_array, right );
        }
    }

    /// Converting `array` to an `Robj` (both by reference and by value) must
    /// produce the same object as evaluating the R expression `r_expr`.
    #[rstest]
    #[case(
        // An empty array keeps its zero-length axis in the `dim` attribute.
        Array4::<i32>::zeros((0, 1, 2, 3).f()),
        "array(integer(), c(0, 1, 2, 3))"
    )]
    #[case(
        // A 1-D array still gets a `dim` attribute.
        array![1., 2., 3.],
        "array(c(1, 2, 3))"
    )]
    #[case(
        // A row-major Rust array is compared with a matrix filled by row.
        Array::from_shape_vec((2, 3), vec![1., 2., 3., 4., 5., 6.]).unwrap(),
        "matrix(c(1, 2, 3, 4, 5, 6), nrow=2, byrow=TRUE)"
    )]
    #[case(
        // A column-major Rust array is compared with a matrix filled by column.
        Array::from_shape_vec((2, 3).f(), vec![1., 2., 3., 4., 5., 6.]).unwrap(),
        "matrix(c(1, 2, 3, 4, 5, 6), nrow=2, byrow=FALSE)"
    )]
    #[case(
        // Higher-dimensional column-major data maps directly onto `array()`.
        Array::from_shape_vec((1, 2, 3).f(), vec![1, 2, 3, 4, 5, 6]).unwrap(),
        "array(1:6, c(1, 2, 3))"
    )]
    #[case(
        // A row-major literal whose element [i][j][k] equals R's [i+1, j+1, k+1].
        array![[[1, 5], [3, 7]], [[2, 6], [4, 8]]],
        "array(1:8, dim=c(2, 2, 2))"
    )]
    fn test_to_robj<ElementType, DimType>(
        #[case] array: Array<ElementType, DimType>,
        #[case] r_expr: &str,
    ) where
        Robj: TryFrom<Array<ElementType, DimType>>,
        for<'a> Robj: TryFrom<&'a Array<ElementType, DimType>>,
        <Robj as TryFrom<Array<ElementType, DimType>>>::Error: std::fmt::Debug,
        for<'a> <Robj as TryFrom<&'a Array<ElementType, DimType>>>::Error: std::fmt::Debug,
    {
        test! {
            // Borrowing conversion.
            assert_eq!(
                &(Robj::try_from(&array).unwrap()),
                &eval_string(r_expr).unwrap()
            );
            // Owning conversion; consumes `array`, so it must come last.
            assert_eq!(
                &(Robj::try_from(array).unwrap()),
                &eval_string(r_expr).unwrap()
            );
        }
    }

    /// R matrix -> `ArrayView2` -> R matrix must reproduce the original object.
    #[test]
    fn test_round_trip() {
        test! {
            let rvals = [
                R!("matrix(c(1L, 2L, 3L, 4L, 5L, 6L), nrow=2)"),
                R!("array(1:8, c(4, 2))")
            ];
            for rval in rvals {
                let rval = rval.unwrap();
                let rust_arr= <ArrayView2<i32>>::try_from(&rval).unwrap();
                let r_arr: Robj = (&rust_arr).try_into().unwrap();
                assert_eq!(
                    rval,
                    r_arr
                );
            }
        }
    }
}