|
22 | 22 | from tensordict import TensorDict |
23 | 23 |
|
24 | 24 | from verl import DataProto |
25 | | -from verl.protocol import union_numpy_dict, union_tensor_dict |
| 25 | +from verl.protocol import ( |
| 26 | + deserialize_single_tensor, |
| 27 | + deserialize_tensordict, |
| 28 | + serialize_single_tensor, |
| 29 | + serialize_tensordict, |
| 30 | + union_numpy_dict, |
| 31 | + union_tensor_dict, |
| 32 | +) |
26 | 33 |
|
27 | 34 |
|
28 | 35 | def test_union_tensor_dict(): |
@@ -614,3 +621,194 @@ def test_to_tensordict(): |
614 | 621 | assert torch.all(torch.eq(output["obs"], obs)).item() |
615 | 622 | assert output["labels"] == labels |
616 | 623 | assert output["name"] == "abdce" |
| 624 | + |
| 625 | + |
def test_serialize_deserialize_single_tensor():
    """Round-trip one dense tensor through the single-tensor (de)serializer.

    The tensor is serialized, the three returned fields are handed back to the
    deserializer as a tuple, and the result is compared with the input on
    values, shape and dtype.
    """
    source = torch.randn(3, 4, 5)

    # The serialized form is unpacked as three fields and re-packed unchanged.
    dtype_field, shape_field, data_field = serialize_single_tensor(source)
    restored = deserialize_single_tensor((dtype_field, shape_field, data_field))

    assert torch.allclose(source, restored)
    assert source.shape == restored.shape
    assert source.dtype == restored.dtype
| 641 | + |
| 642 | + |
def test_serialize_deserialize_tensordict_regular_tensors():
    """Round-trip a TensorDict that holds only regular (non-nested) tensors.

    Checks that batch size, key set, and each entry's values, shape and dtype
    are the same before and after serialization.
    """
    dims = (5, 3)
    entries = {
        "tensor1": torch.randn(*dims, 4),
        "tensor2": torch.randint(0, 10, (*dims, 2)),
    }
    source_td = TensorDict(entries, batch_size=dims)

    # Serialize, then feed the same three fields back as one tuple.
    batch_field, device_field, items_field = serialize_tensordict(source_td)
    restored_td = deserialize_tensordict((batch_field, device_field, items_field))

    assert source_td.batch_size == restored_td.batch_size
    assert set(source_td.keys()) == set(restored_td.keys())

    # Iterate via .keys() explicitly so entries are looked up by name.
    for name in source_td.keys():
        expected = source_td[name]
        actual = restored_td[name]

        assert torch.allclose(expected, actual)
        assert expected.shape == actual.shape
        assert expected.dtype == actual.dtype
| 670 | + |
| 671 | + |
def test_serialize_deserialize_tensordict_nested_tensors():
    """Round-trip a TensorDict holding one nested tensor and one regular tensor.

    The regular entry is compared directly. The nested entry is checked for
    nested-ness and layout, then compared component by component after
    ``unbind()``.
    """
    pieces = [torch.randn(2, 3), torch.randn(3, 4), torch.randn(1, 5)]
    nested_entry = torch.nested.as_nested_tensor(pieces)
    regular_entry = torch.randn(3, 4, 5)

    source_td = TensorDict({"nested": nested_entry, "regular": regular_entry}, batch_size=(3,))

    # Serialize, then feed the same three fields back as one tuple.
    batch_field, device_field, items_field = serialize_tensordict(source_td)
    restored_td = deserialize_tensordict((batch_field, device_field, items_field))

    assert source_td.batch_size == restored_td.batch_size
    assert set(source_td.keys()) == set(restored_td.keys())

    # --- regular entry ---
    expected_regular = source_td["regular"]
    actual_regular = restored_td["regular"]

    assert torch.allclose(expected_regular, actual_regular)
    assert expected_regular.shape == actual_regular.shape
    assert expected_regular.dtype == actual_regular.dtype

    # --- nested entry ---
    expected_nested = source_td["nested"]
    actual_nested = restored_td["nested"]

    assert expected_nested.is_nested
    assert actual_nested.is_nested
    assert expected_nested.layout == actual_nested.layout

    expected_parts = expected_nested.unbind()
    actual_parts = actual_nested.unbind()

    assert len(expected_parts) == len(actual_parts)

    # Lengths were asserted equal above, so the pairing covers every component.
    for expected, actual in zip(expected_parts, actual_parts, strict=False):
        assert torch.allclose(expected, actual)
        assert expected.shape == actual.shape
        assert expected.dtype == actual.dtype
| 723 | + |
| 724 | + |
def test_serialize_deserialize_tensordict_mixed_types():
    """Round-trip a TensorDict mixing float, double, int, long, bool and nested entries.

    Every entry must come back with the same values, shape and dtype; the
    nested entry is compared component by component after ``unbind()``.
    """
    entries = {
        "float": torch.randn(2, 3).float(),
        "double": torch.randn(2, 3).double(),
        "int": torch.randint(0, 10, (2, 3)).int(),
        "long": torch.randint(0, 10, (2, 3)).long(),
        "bool": torch.tensor([[True, False], [False, True]]),
        "nested": torch.nested.as_nested_tensor([torch.randn(2, 3), torch.randn(3, 4)]),
    }
    source_td = TensorDict(entries, batch_size=(2,))

    # Serialize, then feed the same three fields back as one tuple.
    batch_field, device_field, items_field = serialize_tensordict(source_td)
    restored_td = deserialize_tensordict((batch_field, device_field, items_field))

    assert source_td.batch_size == restored_td.batch_size
    assert set(source_td.keys()) == set(restored_td.keys())

    for name in source_td.keys():
        expected = source_td[name]
        actual = restored_td[name]

        if not expected.is_nested:
            # Regular tensors can be compared directly.
            assert torch.allclose(expected, actual, equal_nan=True)
            assert expected.shape == actual.shape
            assert expected.dtype == actual.dtype
            continue

        # Nested tensors are compared one component at a time.
        expected_parts = expected.unbind()
        actual_parts = actual.unbind()

        assert len(expected_parts) == len(actual_parts)

        for want, got in zip(expected_parts, actual_parts, strict=False):
            assert torch.allclose(want, got, equal_nan=True)
            assert want.shape == got.shape
            assert want.dtype == got.dtype
| 784 | + |
| 785 | + |
def test_serialize_deserialize_tensordict_with_device():
    """Round-trip a TensorDict created with an explicit device.

    Uses "cuda" when ``torch.cuda.is_available()`` reports True and "cpu"
    otherwise, then checks that batch size, device (compared as strings),
    key set and every entry's values, shape and dtype match after the trip.
    """
    dims = (2, 3)
    first = torch.randn(*dims, 4)
    second = torch.randint(0, 10, (*dims, 2))

    if torch.cuda.is_available():
        target_device = "cuda"
    else:
        target_device = "cpu"

    source_td = TensorDict(
        {"tensor1": first, "tensor2": second},
        batch_size=dims,
        device=target_device,
    )

    # Serialize, then feed the same three fields back as one tuple.
    batch_field, device_field, items_field = serialize_tensordict(source_td)
    restored_td = deserialize_tensordict((batch_field, device_field, items_field))

    assert source_td.batch_size == restored_td.batch_size
    assert str(source_td.device) == str(restored_td.device)
    assert set(source_td.keys()) == set(restored_td.keys())

    for name in source_td.keys():
        expected = source_td[name]
        actual = restored_td[name]

        # Both sides are moved to CPU before the value comparison.
        assert torch.allclose(expected.cpu(), actual.cpu())
        assert expected.shape == actual.shape
        assert expected.dtype == actual.dtype
0 commit comments