diff --git a/config/main.py b/config/main.py index c0a4c500fd..b4f65da6df 100644 --- a/config/main.py +++ b/config/main.py @@ -804,21 +804,26 @@ def validate_mirror_session_config(config_db, session_name, dst_port, src_port, return True -def validate_ip_mask(ctx, ip_addr): +def is_valid_ip_interface(ctx, ip_addr): split_ip_mask = ip_addr.split("/") + if len(split_ip_mask) < 2: + return False + # Check if the IP address is correct or if there are leading zeros. ip_obj = ipaddress.ip_address(split_ip_mask[0]) - # Check if the mask is correct - mask_range = 33 if isinstance(ip_obj, ipaddress.IPv4Address) else 129 - # If mask is not specified - if len(split_ip_mask) < 2: - return 0 + if isinstance(ip_obj, ipaddress.IPv4Address): + # Since the IP address is used as a part of a key in Redis DB, + # do not tolerate extra zeros in IPv4. + if str(ip_obj) != split_ip_mask[0]: + return False - if not int(split_ip_mask[1]) in range(1, mask_range): - return 0 + # Check if the mask is correct + net = ipaddress.ip_network(ip_addr, strict=False) + if str(net.prefixlen) != split_ip_mask[1] or net.prefixlen == 0: + return False - return str(ip_obj) + '/' + str(int(split_ip_mask[1])) + return True def cli_sroute_to_config(ctx, command_str, strict_nh = True): if len(command_str) < 2 or len(command_str) > 9: @@ -3587,10 +3592,9 @@ def add(ctx, interface_name, ip_addr, gw): try: net = ipaddress.ip_network(ip_addr, strict=False) if '/' not in ip_addr: - ip_addr = str(net) + ip_addr += '/' + str(net.prefixlen) - ip_addr = validate_ip_mask(ctx, ip_addr) - if not ip_addr: + if not is_valid_ip_interface(ctx, ip_addr): raise ValueError('') if interface_name == 'eth0': @@ -3652,10 +3656,9 @@ def remove(ctx, interface_name, ip_addr): try: net = ipaddress.ip_network(ip_addr, strict=False) if '/' not in ip_addr: - ip_addr = str(net) - - ip_addr = validate_ip_mask(ctx, ip_addr) - if not ip_addr: + ip_addr += '/' + str(net.prefixlen) + + if not is_valid_ip_interface(ctx, ip_addr): raise ValueError('') if interface_name == 'eth0': diff --git a/tests/ip_config_test.py b/tests/ip_config_test.py index d08a03ca8f..24db7c319a 100644 --- a/tests/ip_config_test.py +++ b/tests/ip_config_test.py @@ -26,7 +26,7 @@ def test_add_del_interface_valid_ipv4(self): obj = {'config_db':db.cfgdb} # config int ip add Ethernet64 10.10.10.1/24 - result = runner.invoke(config.config.commands["interface"].commands["ip"].commands["add"], ["Ethernet64", "10.10.10.1/24"], obj=obj) + result = runner.invoke(config.config.commands["interface"].commands["ip"].commands["add"], ["Ethernet64", "10.10.10.1/24"], obj=obj) print(result.exit_code, result.output) assert result.exit_code == 0 assert ('Ethernet64', '10.10.10.1/24') in db.cfgdb.get_table('INTERFACE') @@ -59,7 +59,7 @@ def test_add_interface_ipv4_invalid_mask(self): assert result.exit_code != 0 assert ERROR_MSG in result.output - def test_add_del_interface_ipv4_with_leading_zeros(self): + def test_add_interface_ipv4_with_leading_zeros(self): db = Db() runner = CliRunner() obj = {'config_db':db.cfgdb} @@ -67,14 +67,8 @@ def test_add_del_interface_ipv4_with_leading_zeros(self): # config int ip add Ethernet68 10.10.10.002/24 result = runner.invoke(config.config.commands["interface"].commands["ip"].commands["add"], ["Ethernet68", "10.10.10.002/24"], obj=obj) print(result.exit_code, result.output) - assert result.exit_code == 0 - assert ('Ethernet68', '10.10.10.2/24') in db.cfgdb.get_table('INTERFACE') - - # config int ip remove Ethernet68 10.10.10.002/24 - result = runner.invoke(config.config.commands["interface"].commands["ip"].commands["remove"], ["Ethernet68", "10.10.10.002/24"], obj=obj) - print(result.exit_code, result.output) assert result.exit_code != 0 - assert ('Ethernet68', '10.10.10.2/24') not in db.cfgdb.get_table('INTERFACE') + assert ERROR_MSG in result.output ''' Tests for IPv6 ''' @@ -84,13 +78,13 @@ def test_add_del_interface_valid_ipv6(self): obj = {'config_db':db.cfgdb} # config int ip add Ethernet72 2001:1db8:11a3:19d7:1f34:8a2e:17a0:765d/34 - result = runner.invoke(config.config.commands["interface"].commands["ip"].commands["add"], ["Ethernet72", "2001:1db8:11a3:19d7:1f34:8a2e:17a0:765d/34"], obj=obj) + result = runner.invoke(config.config.commands["interface"].commands["ip"].commands["add"], ["Ethernet72", "2001:1db8:11a3:19d7:1f34:8a2e:17a0:765d/34"], obj=obj) print(result.exit_code, result.output) assert result.exit_code == 0 assert ('Ethernet72', '2001:1db8:11a3:19d7:1f34:8a2e:17a0:765d/34') in db.cfgdb.get_table('INTERFACE') # config int ip remove Ethernet72 2001:1db8:11a3:19d7:1f34:8a2e:17a0:765d/34 - result = runner.invoke(config.config.commands["interface"].commands["ip"].commands["remove"], ["Ethernet72", "2001:1db8:11a3:19d7:1f34:8a2e:17a0:765d/34"], obj=obj) + result = runner.invoke(config.config.commands["interface"].commands["ip"].commands["remove"], ["Ethernet72", "2001:1db8:11a3:19d7:1f34:8a2e:17a0:765d/34"], obj=obj) print(result.exit_code, result.output) assert result.exit_code != 0 assert ('Ethernet72', '2001:1db8:11a3:19d7:1f34:8a2e:17a0:765d/34') not in db.cfgdb.get_table('INTERFACE') @@ -122,34 +116,34 @@ def test_add_del_interface_ipv6_with_leading_zeros(self): runner = CliRunner() obj = {'config_db':db.cfgdb} - # config int ip del Ethernet68 2001:0db8:11a3:09d7:1f34:8a2e:07a0:765d/34 + # config int ip add Ethernet68 2001:0db8:11a3:09d7:1f34:8a2e:07a0:765d/34 result = runner.invoke(config.config.commands["interface"].commands["ip"].commands["add"], ["Ethernet68", "2001:0db8:11a3:09d7:1f34:8a2e:07a0:765d/34"], obj=obj) print(result.exit_code, result.output) assert result.exit_code == 0 - assert ('Ethernet68', '2001:db8:11a3:9d7:1f34:8a2e:7a0:765d/34') in db.cfgdb.get_table('INTERFACE') + assert ('Ethernet68', '2001:0db8:11a3:09d7:1f34:8a2e:07a0:765d/34') in db.cfgdb.get_table('INTERFACE') # config int ip remove Ethernet68 2001:0db8:11a3:09d7:1f34:8a2e:07a0:765d/34 result = runner.invoke(config.config.commands["interface"].commands["ip"].commands["remove"], ["Ethernet68", "2001:0db8:11a3:09d7:1f34:8a2e:07a0:765d/34"], obj=obj) print(result.exit_code, result.output) assert result.exit_code != 0 - assert ('Ethernet68', '2001:db8:11a3:9d7:1f34:8a2e:7a0:765d/34') not in db.cfgdb.get_table('INTERFACE') + assert ('Ethernet68', '2001:0db8:11a3:09d7:1f34:8a2e:07a0:765d/34') not in db.cfgdb.get_table('INTERFACE') def test_add_del_interface_shortened_ipv6_with_leading_zeros(self): db = Db() runner = CliRunner() obj = {'config_db':db.cfgdb} - # config int ip del Ethernet68 3000::001/64 + # config int ip add Ethernet68 3000::001/64 result = runner.invoke(config.config.commands["interface"].commands["ip"].commands["add"], ["Ethernet68", "3000::001/64"], obj=obj) print(result.exit_code, result.output) assert result.exit_code == 0 - assert ('Ethernet68', '3000::1/64') in db.cfgdb.get_table('INTERFACE') + assert ('Ethernet68', '3000::001/64') in db.cfgdb.get_table('INTERFACE') # config int ip remove Ethernet68 3000::001/64 - result = runner.invoke(config.config.commands["interface"].commands["ip"].commands["remove"], ["Ethernet68", "3000::001/64"], obj=obj) + result = runner.invoke(config.config.commands["interface"].commands["ip"].commands["remove"], ["Ethernet68", "3000::001/64"], obj=obj) print(result.exit_code, result.output) assert result.exit_code != 0 - assert ('Ethernet68', '3000::1/64') not in db.cfgdb.get_table('INTERFACE') + assert ('Ethernet68', '3000::001/64') not in db.cfgdb.get_table('INTERFACE') @classmethod def teardown_class(cls):