diff --git a/tests/common/plugins/ptfadapter/__init__.py b/tests/common/plugins/ptfadapter/__init__.py index 24482162b56..d1186f8ee77 100644 --- a/tests/common/plugins/ptfadapter/__init__.py +++ b/tests/common/plugins/ptfadapter/__init__.py @@ -3,12 +3,51 @@ import pytest from ptfadapter import PtfTestAdapter +import ptf.testutils + DEFAULT_PTF_NN_PORT = 10900 DEFAULT_DEVICE_NUM = 0 ETH_PFX = 'eth' +def pytest_addoption(parser): + parser.addoption("--keep_payload", action="store_true", default=False, + help="Keep the original packet payload, do not update payload to default pattern") + + +def override_ptf_functions(): + # Below code is to override the 'send' function in the ptf.testutils module. Purpose of this change is to insert + # code for updating the packet pattern before send it out. Generally we want to make the payload part of injected + # packet to have string of current test module and case name. While inspecting the captured packets, it is easier + # to fiture out which packets are injected by which test case. + def _send(test, port_id, pkt, count=1): + update_payload = getattr(test, "update_payload", None) + if update_payload and callable(update_payload): + pkt = test.update_payload(pkt) + + return ptf.testutils.send_packet(test, port_id, pkt, count=count) + setattr(ptf.testutils, "send", _send) + + + # Below code is to override the 'dp_poll' function in the ptf.testutils module. This function is called by all + # the other functions for receiving packets in the ptf.testutils module. Purpose of this overriding is to update + # the payload of received packet using the same method to match the updated injected packets. + def _dp_poll(test, device_number=0, port_number=None, timeout=-1, exp_pkt=None): + update_payload = getattr(test, "update_payload", None) + if update_payload and callable(update_payload): + exp_pkt = test.update_payload(exp_pkt) + + result = test.dataplane.poll( + device_number=device_number, port_number=port_number, + timeout=timeout, exp_pkt=exp_pkt, filters=ptf.testutils.FILTERS + ) + if isinstance(result, test.dataplane.PollSuccess): + test.at_receive(result.packet, device_number=result.device, port_number=result.port) + return result + setattr(ptf.testutils, "dp_poll", _dp_poll) + + def get_ifaces(netdev_output): """ parse /proc/net/dev content :param netdev_output: content of /proc/net/dev @@ -33,7 +72,7 @@ def get_ifaces(netdev_output): @pytest.fixture(scope='module') -def ptfadapter(ptfhost, testbed): +def ptfadapter(ptfhost, testbed, request): """return ptf test adapter object. The fixture is module scope, because usually there is not need to restart PTF nn agent and reinitialize data plane thread on every @@ -63,5 +102,13 @@ def ptfadapter(ptfhost, testbed): ptfhost.command('supervisorctl reread') ptfhost.command('supervisorctl update') + # Force a restart of ptf_nn_agent to ensure that it is in good status. + ptfhost.command('supervisorctl restart ptf_nn_agent') + with PtfTestAdapter(testbed['ptf_ip'], DEFAULT_PTF_NN_PORT, 0, ifaces_map.keys()) as adapter: + if not request.config.option.keep_payload: + override_ptf_functions() + node_id = request.module.__name__ + adapter.payload_pattern = node_id + " " + yield adapter diff --git a/tests/common/plugins/ptfadapter/ptfadapter.py b/tests/common/plugins/ptfadapter/ptfadapter.py index 3fa346932e5..28fc4f3f82c 100644 --- a/tests/common/plugins/ptfadapter/ptfadapter.py +++ b/tests/common/plugins/ptfadapter/ptfadapter.py @@ -3,6 +3,8 @@ from ptf.dataplane import DataPlane import ptf.platforms.nn as nn import ptf.ptfutils as ptfutils +import ptf.packet as scapy +import ptf.mask as mask class PtfTestAdapter(BaseTest): @@ -22,6 +24,7 @@ def __init__(self, ptf_ip, ptf_nn_port, device_num, ptf_port_set): """ self.runTest = lambda : None # set a no op runTest attribute to satisfy BaseTest interface super(PtfTestAdapter, self).__init__() + self.payload_pattern = "" self._init_ptf_dataplane(ptf_ip, ptf_nn_port, device_num, ptf_port_set) def __enter__(self): @@ -88,3 +91,53 @@ def reinit(self, ptf_config=None): """ self.kill() self._init_ptf_dataplane(self.ptf_ip, self.ptf_nn_port, self.device_num, self.ptf_port_set, ptf_config) + + def update_payload(self, pkt): + """Update the payload of packet to the default pattern when certain conditions are met. + + The packet passed in could be a regular scapy packet or a masked packet. If it is a regular scapy packet and + has UDP or TCP header, then update its TCP or UDP payload. + + If it is a masked packet, then its 'exp_pkt' is the regular scapy packet. Update the payload of its 'exp_pkt' + properly. + + Args: + pkt [scapy packet or masked packet]: The packet to be updated. + + Returns: + [scapy packet or masked packet]: Returns the packet with payload part updated. + """ + if isinstance(pkt, scapy.Ether): + for proto in (scapy.UDP, scapy.TCP): + if proto in pkt: + pkt[proto].load = self._update_payload(pkt[proto].load) + elif isinstance(pkt, mask.Mask): + for proto in (scapy.UDP, scapy.TCP): + if proto in pkt.exp_pkt: + pkt.exp_pkt[proto].load = self._update_payload(pkt.exp_pkt[proto].load) + return pkt + + def _update_payload(self, payload): + """Update payload to the default_pattern if default_pattern is set. + + If length of the payload_pattern is longer payload, truncate payload_pattern to the length of payload. + Otherwise, repeat the payload_pattern to reach the length of payload. Keep length of updated payload same + as the original payload. + + Args: + payload [string]: The payload to be updated. + + Returns: + [string]: The updated payload. + """ + if self.payload_pattern: + len_old = len(payload) + len_new = len(self.payload_pattern) + if len_new >= len_old: + return self.payload_pattern[:len_old] + else: + factor = len_old/len_new + 1 + new_payload = self.payload_pattern * factor + return new_payload[:len_old] + else: + return payload