Preserve dtype for numpy/torch/tf/jax arrays #2361
Hi @lhoestq,
lhoestq
left a comment
Thanks ! To fix this I added a comment to actually keep the numpy array unchanged until it is passed to a TypedSequence. This way we don't have to deal with the ListArray issue
Brought down the failing tests from 7 to 4. Let me know if that part looks good. The remaining failures look quite similar. In datasets/tests/test_arrow_dataset.py (line 1039 in 3d46bc3) and in test_map_tf in datasets/tests/test_arrow_dataset.py (line 1056 in 3d46bc3), they're expecting float64. Shouldn't that be float32 now?
It's normal: pytorch and tensorflow use float32 by default.
I think that we should always keep the precision of the original tensor (torch/tf/numpy). This is a breaking change, but in my opinion the fact that we had Value("float64") for torch.float32 tensors was an issue already.
Let me know what you think. Cc @albertvillanova if you have an opinion on this.
If we agree on doing this breaking change, we can just change the test.
Hi @lhoestq,
Yes feel free to update those tests :) It would be nice to have the same test for JAX as well
Added the same test for JAX too. Also, I saw that I missed changing
lhoestq
left a comment
It looks all good !
Thanks a lot :)
Preserve dtype for numpy/torch/tf/jax arrays (huggingface#2361)
Fixes #625. This lets the user preserve the dtype when a numpy array is converted to a pyarrow array. Previously the dtype was lost because the conversion went numpy array -> list -> pyarrow array.