diff --git a/pkg/tcpip/nftables/nftables.go b/pkg/tcpip/nftables/nftables.go index 154d2271fa..52177a46f3 100644 --- a/pkg/tcpip/nftables/nftables.go +++ b/pkg/tcpip/nftables/nftables.go @@ -1744,6 +1744,68 @@ func (op metaLoad) evaluate(regs *registerSet, pkt *stack.PacketBuffer, rule *Ru copy(dst, target) } +// metaSet is an operation that sets specific meta data into to the value in a +// register. +// Note: meta operations are not supported for the verdict register. +// TODO(b/345684870): Support setting more meta fields for Meta Set. +type metaSet struct { + key metaKey // Meta key specifying what data to set. + sreg uint8 // Number of the source register. +} + +// checkMetaKeySetCompatable checks that the meta key is valid for meta set. +func checkMetaKeySetCompatable(key metaKey) error { + switch key { + // Supported meta keys. + case linux.NFT_META_PKTTYPE: + return nil + // Should be supported but not yet implemented. + case linux.NFT_META_MARK, linux.NFT_META_PRIORITY, + linux.NFT_META_NFTRACE, linux.NFT_META_SECMARK: + return fmt.Errorf("meta key %v is not supported for meta set", key) + // All other keys cannot be used with meta set (strictly for loading). + default: + return fmt.Errorf("meta key %v is not compatible with meta set", key) + } +} + +// newMetaSet creates a new metaSet operation. +func newMetaSet(key metaKey, sreg uint8) (*metaSet, error) { + if isVerdictRegister(sreg) { + return nil, fmt.Errorf("meta set operation cannot use verdict register as destination") + } + if err := validateMetaKey(key); err != nil { + return nil, err + } + if err := checkMetaKeySetCompatable(key); err != nil { + return nil, err + } + if metaDataLengths[key] > 4 && !is16ByteRegister(sreg) { + return nil, fmt.Errorf("meta load operation cannot use 4-byte register as destination for key %s", key) + } + + return &metaSet{key: key, sreg: sreg}, nil +} + +// evaluate for metaSet sets specific meta data to the value in the source +// register. +func (op metaSet) evaluate(regs *registerSet, pkt *stack.PacketBuffer, rule *Rule) { + // Gets the data from the source register. + src := getRegisterBuffer(regs, op.sreg)[:metaDataLengths[op.key]] + + // Sets the meta data of the appropriate field. + switch op.key { + // Only Packet Type is supported for now. + case linux.NFT_META_PKTTYPE: + pkt.PktType = tcpip.PacketType(src[0]) + return + } + + // Breaks if could not set the meta data. + regs.verdict = Verdict{Code: VC(linux.NFT_BREAK)} + return +} + // // Register and Register-Related Implementations. // Note: Registers are represented by type uint8 for the register number. diff --git a/pkg/tcpip/nftables/nftables_test.go b/pkg/tcpip/nftables/nftables_test.go index b695c51e48..23ebc8eb65 100644 --- a/pkg/tcpip/nftables/nftables_test.go +++ b/pkg/tcpip/nftables/nftables_test.go @@ -2103,53 +2103,8 @@ func TestEvaluatePayloadSet(t *testing.T) { } } - // Compares checksums first for resulting and expected packet. - if test.outPkt.NetworkProtocolNumber != test.pkt.NetworkProtocolNumber { - t.Fatalf("expected network protocol number %d for resulting packet, got %d", test.outPkt.NetworkProtocolNumber, test.pkt.NetworkProtocolNumber) - } - if test.pkt.NetworkHeader().View() != nil && test.outPkt.Network().Checksum() != test.pkt.Network().Checksum() { - t.Fatalf("expected network checksum %d for resulting packet, got %d", test.outPkt.Network().Checksum(), test.pkt.Network().Checksum()) - } - if test.pkt.TransportProtocolNumber != test.outPkt.TransportProtocolNumber { - t.Fatalf("expected transport protocol number %d for resulting packet, got %d", test.outPkt.TransportProtocolNumber, test.pkt.TransportProtocolNumber) - } - if test.pkt.TransportProtocolNumber != 0 { - var transport header.Transport - var transportOut header.Transport - switch tBytes, tOutBytes := test.pkt.TransportHeader().Slice(), - test.outPkt.TransportHeader().Slice(); test.pkt.TransportProtocolNumber { - case header.TCPProtocolNumber: - transport = header.TCP(tBytes) - transportOut = header.TCP(tOutBytes) - case header.UDPProtocolNumber: - transport = header.UDP(tBytes) - transportOut = header.UDP(tOutBytes) - case header.ICMPv4ProtocolNumber: - transport = header.ICMPv4(tBytes) - transportOut = header.ICMPv4(tOutBytes) - case header.ICMPv6ProtocolNumber: - transport = header.ICMPv6(tBytes) - transportOut = header.ICMPv6(tOutBytes) - case header.IGMPProtocolNumber: - transport = header.IGMP(tBytes) - transportOut = header.IGMP(tOutBytes) - } - if transport != nil && transport.Checksum() != transportOut.Checksum() { - t.Fatalf("expected transport checksum %d for resulting packet, got %d", transport.Checksum(), transportOut.Checksum()) - } - } - - // Compares raw packet data in bytes for resulting and expected packet. - actual := test.pkt.AsSlices() - expected := test.outPkt.AsSlices() - if len(actual) != len(expected) { - t.Fatalf("expected %d slices of data for the resulting packet, got %d", len(expected), len(actual)) - } - for i := range actual { - if !slices.Equal(actual[i], expected[i]) { - t.Fatalf("packet data does not match expected packet data (for slice %d)", i) - } - } + // Checks if the packet are equal. + checkPacketEquality(t, test.outPkt, test.pkt) }) } } @@ -3018,6 +2973,89 @@ func TestEvaluateMetaLoad(t *testing.T) { } } +// TestEvaluateMetaSet tests that the Meta Set operation correctly sets specific +// packet meta data to the value in the source register. +func TestEvaluateMetaSet(t *testing.T) { + // Packet type to set anc test for. + testPktType := tcpip.PacketMulticast + for _, test := range []struct { + tname string + pkt *stack.PacketBuffer + outPkt *stack.PacketBuffer + op1 operation // Immediate operation to load source register. + op2 operation // Meta set operation to test. + }{ + // cmd: nft --debug=netlink add rule ip tab ch meta pkttype set 34 + { + tname: "meta set pkttype 4-byte reg test", + pkt: makeIPv4Packet(header.IPv4MinimumSize, arbitraryIPv4Fields()), + outPkt: func() *stack.PacketBuffer { + pkt := makeIPv4Packet(header.IPv4MinimumSize, arbitraryIPv4Fields()) + pkt.PktType = testPktType + return pkt + }(), + op1: mustCreateImmediate(t, linux.NFT_REG32_06, newBytesData([]byte{uint8(testPktType)})), + op2: mustCreateMetaSet(t, linux.NFT_META_PKTTYPE, linux.NFT_REG32_06), + }, + { + tname: "meta set pkttype 16-byte reg test", + pkt: makeIPv4Packet(header.IPv4MinimumSize, arbitraryIPv4Fields()), + outPkt: func() *stack.PacketBuffer { + pkt := makeIPv4Packet(header.IPv4MinimumSize, arbitraryIPv4Fields()) + pkt.PktType = testPktType + return pkt + }(), + op1: mustCreateImmediate(t, linux.NFT_REG_3, newBytesData([]byte{uint8(testPktType)})), + op2: mustCreateMetaSet(t, linux.NFT_META_PKTTYPE, linux.NFT_REG_3), + }, + } { + t.Run(test.tname, func(t *testing.T) { + // Sets up an NFTables object with a single table, chain, and rule. + nf := newNFTablesStd() + tab, err := nf.AddTable(arbitraryFamily, "test", "test table", false) + if err != nil { + t.Fatalf("unexpected error for AddTable: %v", err) + } + bc, err := tab.AddChain("base_chain", nil, "test chain", false) + if err != nil { + t.Fatalf("unexpected error for AddChain: %v", err) + } + bc.SetBaseChainInfo(arbitraryInfoPolicyAccept) + rule := &Rule{} + + // Adds testing operations. + if test.op1 != nil { + rule.addOperation(test.op1) + } + if test.op2 != nil { + rule.addOperation(test.op2) + } + + // Adds drop operation, to be final verdict if evaluation is successful. + rule.addOperation(mustCreateImmediate(t, linux.NFT_REG_VERDICT, newVerdictData(Verdict{Code: VC(linux.NF_DROP)}))) + + // Registers the rule to the base chain. + if err := bc.RegisterRule(rule, -1); err != nil { + t.Fatalf("unexpected error for RegisterRule: %v", err) + } + + // Runs evaluation. + v, err := nf.EvaluateHook(arbitraryFamily, arbitraryHook, test.pkt) + if err != nil { + t.Fatalf("unexpected error for EvaluateHook: %v", err) + } + + // Evaluation should be successful and result in Drop verdict. + if v.Code != VC(linux.NF_DROP) { + t.Fatalf("expected verdict Drop for successful evaluation, got %v", v) + } + + // Checks if the packet are equal. + checkPacketEquality(t, test.outPkt, test.pkt) + }) + } +} + // TestLoopCheckOnRegisterAndUnregister tests the loop checking and accompanying // logic on registering and unregistering rules. func TestLoopCheckOnRegisterAndUnregister(t *testing.T) { @@ -3590,6 +3628,64 @@ func TestMaxNestedJumps(t *testing.T) { } } +// checkPacketEquality checks that the given packets are equal for all fields +// and data relevant to our testing. This is not an exhaustive check. +func checkPacketEquality(t *testing.T, expected, actual *stack.PacketBuffer) { + if expected.PktType != actual.PktType { + t.Fatalf("expected packet type %d for resulting packet, got %d", int(expected.PktType), int(actual.PktType)) + } + + // Compares checksums first for the expected and actual packet. + if expected.NetworkProtocolNumber != actual.NetworkProtocolNumber { + t.Fatalf("expected network protocol number %d for resulting packet, got %d", expected.NetworkProtocolNumber, actual.NetworkProtocolNumber) + } + if actualHasNetwork, expectedHasNetwork := actual.NetworkHeader().View() != nil, expected.NetworkHeader().View() != nil; actualHasNetwork != expectedHasNetwork { + t.Fatalf("expected network header is present to be %t for resulting packet, got %v", actualHasNetwork, expectedHasNetwork) + } + if actual.NetworkHeader().View() != nil && expected.Network().Checksum() != actual.Network().Checksum() { + t.Fatalf("expected network checksum %d for resulting packet, got %d", expected.Network().Checksum(), actual.Network().Checksum()) + } + if actual.TransportProtocolNumber != expected.TransportProtocolNumber { + t.Fatalf("expected transport protocol number %d for resulting packet, got %d", expected.TransportProtocolNumber, actual.TransportProtocolNumber) + } + if actual.TransportProtocolNumber != 0 { + var transport header.Transport + var transportExpected header.Transport + switch tBytes, tOutBytes := actual.TransportHeader().Slice(), expected.TransportHeader().Slice(); actual.TransportProtocolNumber { + case header.TCPProtocolNumber: + transport = header.TCP(tBytes) + transportExpected = header.TCP(tOutBytes) + case header.UDPProtocolNumber: + transport = header.UDP(tBytes) + transportExpected = header.UDP(tOutBytes) + case header.ICMPv4ProtocolNumber: + transport = header.ICMPv4(tBytes) + transportExpected = header.ICMPv4(tOutBytes) + case header.ICMPv6ProtocolNumber: + transport = header.ICMPv6(tBytes) + transportExpected = header.ICMPv6(tOutBytes) + case header.IGMPProtocolNumber: + transport = header.IGMP(tBytes) + transportExpected = header.IGMP(tOutBytes) + } + if transport != nil && transport.Checksum() != transportExpected.Checksum() { + t.Fatalf("expected transport checksum %d for resulting packet, got %d", transport.Checksum(), transportExpected.Checksum()) + } + } + + // Compares raw packet data in bytes for resulting and expected packet. + actualSlices := actual.AsSlices() + expectedSlices := expected.AsSlices() + if len(actualSlices) != len(expectedSlices) { + t.Fatalf("expected %d slices of data for the resulting packet, got %d", len(expectedSlices), len(actualSlices)) + } + for i := range actualSlices { + if !slices.Equal(actualSlices[i], expectedSlices[i]) { + t.Fatalf("packet data does not match expected packet data (for slice %d)", i) + } + } +} + // numToBE converts an n-byte int to Big Endian where n is in [1, 8]. // Assumes the given number can be represented in n bytes. func numToBE(v int, n int) []byte { @@ -3710,3 +3806,12 @@ func mustCreateMetaLoad(t *testing.T, key metaKey, dreg uint8) *metaLoad { } return mtload } + +// mustCreateMetaSet wraps the newMetaSet function for brevity. +func mustCreateMetaSet(t *testing.T, key metaKey, sreg uint8) *metaSet { + mtset, err := newMetaSet(key, sreg) + if err != nil { + t.Fatalf("failed to create meta set: %v", err) + } + return mtset +} diff --git a/pkg/tcpip/nftables/nftinterp.go b/pkg/tcpip/nftables/nftinterp.go index 80fdcc35ad..dbcd023479 100644 --- a/pkg/tcpip/nftables/nftinterp.go +++ b/pkg/tcpip/nftables/nftinterp.go @@ -158,7 +158,13 @@ func InterpretOperation(line string, lnIdx int) (operation, error) { case "byteorder": return InterpretByteorder(line, lnIdx) case "meta": - return InterpretMetaLoad(line, lnIdx) + switch tokens[2] { + case "load": + return InterpretMetaLoad(line, lnIdx) + case "set": + return InterpretMetaSet(line, lnIdx) + } + return nil, &SyntaxError{lnIdx, 2, fmt.Sprintf("unrecognized operation type: meta %s", tokens[2])} default: return nil, &SyntaxError{lnIdx, 1, fmt.Sprintf("unrecognized operation type: %s", tokens[1])} } @@ -883,6 +889,69 @@ func InterpretMetaLoad(line string, lnIdx int) (operation, error) { return mtLoad, nil } +// InterpretMetaSet creates a new MetaSet operation from the given string. +func InterpretMetaSet(line string, lnIdx int) (operation, error) { + tokens := strings.Fields(line) + + // Requires exactly 8 tokens: + // "[", "meta", "set", meta key, "with", "reg", register index, "]". + if len(tokens) != 8 { + return nil, &SyntaxError{lnIdx, 0, fmt.Sprintf("incorrect number of tokens for meta operation, should be exactly 8, got %d", len(tokens))} + } + + if err := checkOperationBrackets(tokens, lnIdx); err != nil { + return nil, err + } + + tkIdx := 1 + + // First token should be "meta". + if err := consumeToken("meta", tokens, lnIdx, tkIdx); err != nil { + return nil, err + } + tkIdx++ + + // Second token should be "set". + if err := consumeToken("set", tokens, lnIdx, tkIdx); err != nil { + return nil, err + } + tkIdx++ + + // Third token should be the meta key. + key, err := parseMetaKey(tokens[tkIdx], lnIdx, tkIdx) + if err != nil { + return nil, err + } + tkIdx++ + + // Fourth token should be "with". + if err := consumeToken("with", tokens, lnIdx, tkIdx); err != nil { + return nil, err + } + tkIdx++ + + // Fifth token should be "reg". + if err := consumeToken("reg", tokens, lnIdx, tkIdx); err != nil { + return nil, err + } + tkIdx++ + + // Sixth token should be the uint8 representing the register index. + reg, err := parseRegister(tokens[tkIdx], lnIdx, tkIdx) + if err != nil { + return nil, err + } + tkIdx++ + + // Create the operation with the specified arguments. + mtSet, err := newMetaSet(key, reg) + if err != nil { + return nil, &LogicError{lnIdx, tkIdx, err} + } + + return mtSet, nil +} + // // Interpreter Helper Functions. // diff --git a/pkg/tcpip/nftables/nftinterp_test.go b/pkg/tcpip/nftables/nftinterp_test.go index f12660df4c..eee556f080 100644 --- a/pkg/tcpip/nftables/nftinterp_test.go +++ b/pkg/tcpip/nftables/nftinterp_test.go @@ -1052,6 +1052,42 @@ func checkMetaLoadOp(tname string, expected operation, actual operation) error { return nil } +// TestInterpretMetaSetOps tests interpretation of meta set operations. +func TestInterpretMetaSetOps(t *testing.T) { + for _, test := range []interpretOperationTestAction{ + // cmd: nft --debug=netlink add rule ip tab ch meta pkttype set 34 + { + tname: "meta set pkttype 4-byte reg test", + opStr: "[ meta set pkttype with reg 14 ]", + expected: mustCreateMetaSet(t, linux.NFT_META_PKTTYPE, linux.NFT_REG32_06), + }, + { + tname: "meta set pkttype 16-byte reg test", + opStr: "[ meta set pkttype with reg 3 ]", + expected: mustCreateMetaSet(t, linux.NFT_META_PKTTYPE, linux.NFT_REG_3), + }, + } { + t.Run(test.tname, func(t *testing.T) { checkOp(t, test, checkMetaSetOp) }) + } +} + +// checkMetaSetOp checks that the given operation is a meta set operation and +// that it matches the expected meta set operation. +func checkMetaSetOp(tname string, expected operation, actual operation) error { + expectedMtSet := expected.(*metaSet) + mtSet, ok := actual.(*metaSet) + if !ok { + return fmt.Errorf("expected operation type to be MetaLoad for %s, got %T", tname, actual) + } + if mtSet.key != expectedMtSet.key { + return fmt.Errorf("expected meta key to be %v for %s, got %v", expectedMtSet.key, tname, mtSet.key) + } + if mtSet.sreg != expectedMtSet.sreg { + return fmt.Errorf("expected destination register to be %d for %s, got %d", expectedMtSet.sreg, tname, mtSet.sreg) + } + return nil +} + // TestInterpretRule tests the interpretation of basic and general rules as a // list of operations. func TestInterpretRule(t *testing.T) {