Skip to content

Commit

Permalink
Implement Meta Set operation (parsing, interpretation, evaluation, te…
Browse files Browse the repository at this point in the history
…sts).

The set of meta fields for meta set is much more limited than meta load.
Currently supports setting a limited set of meta fields: pkttype.
It should also be able to support: mark, priority, nftrace, and secmark,
but these have yet to be implemented.

PiperOrigin-RevId: 674426635
  • Loading branch information
Jayden Nyamiaka authored and gvisor-bot committed Sep 13, 2024
1 parent 485a520 commit b2340af
Show file tree
Hide file tree
Showing 4 changed files with 320 additions and 48 deletions.
62 changes: 62 additions & 0 deletions pkg/tcpip/nftables/nftables.go
Original file line number Diff line number Diff line change
Expand Up @@ -1744,6 +1744,68 @@ func (op metaLoad) evaluate(regs *registerSet, pkt *stack.PacketBuffer, rule *Ru
copy(dst, target)
}

// metaSet is an operation that sets specific meta data into to the value in a
// register.
// Note: meta operations are not supported for the verdict register.
// TODO(b/345684870): Support setting more meta fields for Meta Set.
type metaSet struct {
key metaKey // Meta key specifying what data to set.
sreg uint8 // Number of the source register.
}

// checkMetaKeySetCompatable checks that the meta key is valid for meta set.
func checkMetaKeySetCompatable(key metaKey) error {
switch key {
// Supported meta keys.
case linux.NFT_META_PKTTYPE:
return nil
// Should be supported but not yet implemented.
case linux.NFT_META_MARK, linux.NFT_META_PRIORITY,
linux.NFT_META_NFTRACE, linux.NFT_META_SECMARK:
return fmt.Errorf("meta key %v is not supported for meta set", key)
// All other keys cannot be used with meta set (strictly for loading).
default:
return fmt.Errorf("meta key %v is not compatible with meta set", key)
}
}

// newMetaSet creates a new metaSet operation.
func newMetaSet(key metaKey, sreg uint8) (*metaSet, error) {
if isVerdictRegister(sreg) {
return nil, fmt.Errorf("meta set operation cannot use verdict register as destination")
}
if err := validateMetaKey(key); err != nil {
return nil, err
}
if err := checkMetaKeySetCompatable(key); err != nil {
return nil, err
}
if metaDataLengths[key] > 4 && !is16ByteRegister(sreg) {
return nil, fmt.Errorf("meta load operation cannot use 4-byte register as destination for key %s", key)
}

return &metaSet{key: key, sreg: sreg}, nil
}

// evaluate for metaSet sets specific meta data to the value in the source
// register.
func (op metaSet) evaluate(regs *registerSet, pkt *stack.PacketBuffer, rule *Rule) {
// Gets the data from the source register.
src := getRegisterBuffer(regs, op.sreg)[:metaDataLengths[op.key]]

// Sets the meta data of the appropriate field.
switch op.key {
// Only Packet Type is supported for now.
case linux.NFT_META_PKTTYPE:
pkt.PktType = tcpip.PacketType(src[0])
return
}

// Breaks if could not set the meta data.
regs.verdict = Verdict{Code: VC(linux.NFT_BREAK)}
return
}

//
// Register and Register-Related Implementations.
// Note: Registers are represented by type uint8 for the register number.
Expand Down
199 changes: 152 additions & 47 deletions pkg/tcpip/nftables/nftables_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2103,53 +2103,8 @@ func TestEvaluatePayloadSet(t *testing.T) {
}
}

// Compares checksums first for resulting and expected packet.
if test.outPkt.NetworkProtocolNumber != test.pkt.NetworkProtocolNumber {
t.Fatalf("expected network protocol number %d for resulting packet, got %d", test.outPkt.NetworkProtocolNumber, test.pkt.NetworkProtocolNumber)
}
if test.pkt.NetworkHeader().View() != nil && test.outPkt.Network().Checksum() != test.pkt.Network().Checksum() {
t.Fatalf("expected network checksum %d for resulting packet, got %d", test.outPkt.Network().Checksum(), test.pkt.Network().Checksum())
}
if test.pkt.TransportProtocolNumber != test.outPkt.TransportProtocolNumber {
t.Fatalf("expected transport protocol number %d for resulting packet, got %d", test.outPkt.TransportProtocolNumber, test.pkt.TransportProtocolNumber)
}
if test.pkt.TransportProtocolNumber != 0 {
var transport header.Transport
var transportOut header.Transport
switch tBytes, tOutBytes := test.pkt.TransportHeader().Slice(),
test.outPkt.TransportHeader().Slice(); test.pkt.TransportProtocolNumber {
case header.TCPProtocolNumber:
transport = header.TCP(tBytes)
transportOut = header.TCP(tOutBytes)
case header.UDPProtocolNumber:
transport = header.UDP(tBytes)
transportOut = header.UDP(tOutBytes)
case header.ICMPv4ProtocolNumber:
transport = header.ICMPv4(tBytes)
transportOut = header.ICMPv4(tOutBytes)
case header.ICMPv6ProtocolNumber:
transport = header.ICMPv6(tBytes)
transportOut = header.ICMPv6(tOutBytes)
case header.IGMPProtocolNumber:
transport = header.IGMP(tBytes)
transportOut = header.IGMP(tOutBytes)
}
if transport != nil && transport.Checksum() != transportOut.Checksum() {
t.Fatalf("expected transport checksum %d for resulting packet, got %d", transport.Checksum(), transportOut.Checksum())
}
}

// Compares raw packet data in bytes for resulting and expected packet.
actual := test.pkt.AsSlices()
expected := test.outPkt.AsSlices()
if len(actual) != len(expected) {
t.Fatalf("expected %d slices of data for the resulting packet, got %d", len(expected), len(actual))
}
for i := range actual {
if !slices.Equal(actual[i], expected[i]) {
t.Fatalf("packet data does not match expected packet data (for slice %d)", i)
}
}
// Checks if the packet are equal.
checkPacketEquality(t, test.outPkt, test.pkt)
})
}
}
Expand Down Expand Up @@ -3018,6 +2973,89 @@ func TestEvaluateMetaLoad(t *testing.T) {
}
}

// TestEvaluateMetaSet tests that the Meta Set operation correctly sets specific
// packet meta data to the value in the source register.
func TestEvaluateMetaSet(t *testing.T) {
// Packet type to set anc test for.
testPktType := tcpip.PacketMulticast
for _, test := range []struct {
tname string
pkt *stack.PacketBuffer
outPkt *stack.PacketBuffer
op1 operation // Immediate operation to load source register.
op2 operation // Meta set operation to test.
}{
// cmd: nft --debug=netlink add rule ip tab ch meta pkttype set 34
{
tname: "meta set pkttype 4-byte reg test",
pkt: makeIPv4Packet(header.IPv4MinimumSize, arbitraryIPv4Fields()),
outPkt: func() *stack.PacketBuffer {
pkt := makeIPv4Packet(header.IPv4MinimumSize, arbitraryIPv4Fields())
pkt.PktType = testPktType
return pkt
}(),
op1: mustCreateImmediate(t, linux.NFT_REG32_06, newBytesData([]byte{uint8(testPktType)})),
op2: mustCreateMetaSet(t, linux.NFT_META_PKTTYPE, linux.NFT_REG32_06),
},
{
tname: "meta set pkttype 16-byte reg test",
pkt: makeIPv4Packet(header.IPv4MinimumSize, arbitraryIPv4Fields()),
outPkt: func() *stack.PacketBuffer {
pkt := makeIPv4Packet(header.IPv4MinimumSize, arbitraryIPv4Fields())
pkt.PktType = testPktType
return pkt
}(),
op1: mustCreateImmediate(t, linux.NFT_REG_3, newBytesData([]byte{uint8(testPktType)})),
op2: mustCreateMetaSet(t, linux.NFT_META_PKTTYPE, linux.NFT_REG_3),
},
} {
t.Run(test.tname, func(t *testing.T) {
// Sets up an NFTables object with a single table, chain, and rule.
nf := newNFTablesStd()
tab, err := nf.AddTable(arbitraryFamily, "test", "test table", false)
if err != nil {
t.Fatalf("unexpected error for AddTable: %v", err)
}
bc, err := tab.AddChain("base_chain", nil, "test chain", false)
if err != nil {
t.Fatalf("unexpected error for AddChain: %v", err)
}
bc.SetBaseChainInfo(arbitraryInfoPolicyAccept)
rule := &Rule{}

// Adds testing operations.
if test.op1 != nil {
rule.addOperation(test.op1)
}
if test.op2 != nil {
rule.addOperation(test.op2)
}

// Adds drop operation, to be final verdict if evaluation is successful.
rule.addOperation(mustCreateImmediate(t, linux.NFT_REG_VERDICT, newVerdictData(Verdict{Code: VC(linux.NF_DROP)})))

// Registers the rule to the base chain.
if err := bc.RegisterRule(rule, -1); err != nil {
t.Fatalf("unexpected error for RegisterRule: %v", err)
}

// Runs evaluation.
v, err := nf.EvaluateHook(arbitraryFamily, arbitraryHook, test.pkt)
if err != nil {
t.Fatalf("unexpected error for EvaluateHook: %v", err)
}

// Evaluation should be successful and result in Drop verdict.
if v.Code != VC(linux.NF_DROP) {
t.Fatalf("expected verdict Drop for successful evaluation, got %v", v)
}

// Checks if the packet are equal.
checkPacketEquality(t, test.outPkt, test.pkt)
})
}
}

// TestLoopCheckOnRegisterAndUnregister tests the loop checking and accompanying
// logic on registering and unregistering rules.
func TestLoopCheckOnRegisterAndUnregister(t *testing.T) {
Expand Down Expand Up @@ -3590,6 +3628,64 @@ func TestMaxNestedJumps(t *testing.T) {
}
}

// checkPacketEquality checks that the given packets are equal for all fields
// and data relevant to our testing. This is not an exhaustive check.
func checkPacketEquality(t *testing.T, expected, actual *stack.PacketBuffer) {
if expected.PktType != actual.PktType {
t.Fatalf("expected packet type %d for resulting packet, got %d", int(expected.PktType), int(actual.PktType))
}

// Compares checksums first for the expected and actual packet.
if expected.NetworkProtocolNumber != actual.NetworkProtocolNumber {
t.Fatalf("expected network protocol number %d for resulting packet, got %d", expected.NetworkProtocolNumber, actual.NetworkProtocolNumber)
}
if actualHasNetwork, expectedHasNetwork := actual.NetworkHeader().View() != nil, expected.NetworkHeader().View() != nil; actualHasNetwork != expectedHasNetwork {
t.Fatalf("expected network header is present to be %t for resulting packet, got %v", actualHasNetwork, expectedHasNetwork)
}
if actual.NetworkHeader().View() != nil && expected.Network().Checksum() != actual.Network().Checksum() {
t.Fatalf("expected network checksum %d for resulting packet, got %d", expected.Network().Checksum(), actual.Network().Checksum())
}
if actual.TransportProtocolNumber != expected.TransportProtocolNumber {
t.Fatalf("expected transport protocol number %d for resulting packet, got %d", expected.TransportProtocolNumber, actual.TransportProtocolNumber)
}
if actual.TransportProtocolNumber != 0 {
var transport header.Transport
var transportExpected header.Transport
switch tBytes, tOutBytes := actual.TransportHeader().Slice(), expected.TransportHeader().Slice(); actual.TransportProtocolNumber {
case header.TCPProtocolNumber:
transport = header.TCP(tBytes)
transportExpected = header.TCP(tOutBytes)
case header.UDPProtocolNumber:
transport = header.UDP(tBytes)
transportExpected = header.UDP(tOutBytes)
case header.ICMPv4ProtocolNumber:
transport = header.ICMPv4(tBytes)
transportExpected = header.ICMPv4(tOutBytes)
case header.ICMPv6ProtocolNumber:
transport = header.ICMPv6(tBytes)
transportExpected = header.ICMPv6(tOutBytes)
case header.IGMPProtocolNumber:
transport = header.IGMP(tBytes)
transportExpected = header.IGMP(tOutBytes)
}
if transport != nil && transport.Checksum() != transportExpected.Checksum() {
t.Fatalf("expected transport checksum %d for resulting packet, got %d", transport.Checksum(), transportExpected.Checksum())
}
}

// Compares raw packet data in bytes for resulting and expected packet.
actualSlices := actual.AsSlices()
expectedSlices := expected.AsSlices()
if len(actualSlices) != len(expectedSlices) {
t.Fatalf("expected %d slices of data for the resulting packet, got %d", len(expectedSlices), len(actualSlices))
}
for i := range actualSlices {
if !slices.Equal(actualSlices[i], expectedSlices[i]) {
t.Fatalf("packet data does not match expected packet data (for slice %d)", i)
}
}
}

// numToBE converts an n-byte int to Big Endian where n is in [1, 8].
// Assumes the given number can be represented in n bytes.
func numToBE(v int, n int) []byte {
Expand Down Expand Up @@ -3710,3 +3806,12 @@ func mustCreateMetaLoad(t *testing.T, key metaKey, dreg uint8) *metaLoad {
}
return mtload
}

// mustCreateMetaSet wraps the newMetaSet function for brevity.
func mustCreateMetaSet(t *testing.T, key metaKey, sreg uint8) *metaSet {
mtset, err := newMetaSet(key, sreg)
if err != nil {
t.Fatalf("failed to create meta set: %v", err)
}
return mtset
}
71 changes: 70 additions & 1 deletion pkg/tcpip/nftables/nftinterp.go
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,13 @@ func InterpretOperation(line string, lnIdx int) (operation, error) {
case "byteorder":
return InterpretByteorder(line, lnIdx)
case "meta":
return InterpretMetaLoad(line, lnIdx)
switch tokens[2] {
case "load":
return InterpretMetaLoad(line, lnIdx)
case "set":
return InterpretMetaSet(line, lnIdx)
}
return nil, &SyntaxError{lnIdx, 2, fmt.Sprintf("unrecognized operation type: meta %s", tokens[2])}
default:
return nil, &SyntaxError{lnIdx, 1, fmt.Sprintf("unrecognized operation type: %s", tokens[1])}
}
Expand Down Expand Up @@ -883,6 +889,69 @@ func InterpretMetaLoad(line string, lnIdx int) (operation, error) {
return mtLoad, nil
}

// InterpretMetaSet creates a new MetaSet operation from the given string.
func InterpretMetaSet(line string, lnIdx int) (operation, error) {
tokens := strings.Fields(line)

// Requires exactly 8 tokens:
// "[", "meta", "set", meta key, "with", "reg", register index, "]".
if len(tokens) != 8 {
return nil, &SyntaxError{lnIdx, 0, fmt.Sprintf("incorrect number of tokens for meta operation, should be exactly 8, got %d", len(tokens))}
}

if err := checkOperationBrackets(tokens, lnIdx); err != nil {
return nil, err
}

tkIdx := 1

// First token should be "meta".
if err := consumeToken("meta", tokens, lnIdx, tkIdx); err != nil {
return nil, err
}
tkIdx++

// Second token should be "set".
if err := consumeToken("set", tokens, lnIdx, tkIdx); err != nil {
return nil, err
}
tkIdx++

// Third token should be the meta key.
key, err := parseMetaKey(tokens[tkIdx], lnIdx, tkIdx)
if err != nil {
return nil, err
}
tkIdx++

// Fourth token should be "with".
if err := consumeToken("with", tokens, lnIdx, tkIdx); err != nil {
return nil, err
}
tkIdx++

// Fifth token should be "reg".
if err := consumeToken("reg", tokens, lnIdx, tkIdx); err != nil {
return nil, err
}
tkIdx++

// Sixth token should be the uint8 representing the register index.
reg, err := parseRegister(tokens[tkIdx], lnIdx, tkIdx)
if err != nil {
return nil, err
}
tkIdx++

// Create the operation with the specified arguments.
mtSet, err := newMetaSet(key, reg)
if err != nil {
return nil, &LogicError{lnIdx, tkIdx, err}
}

return mtSet, nil
}

//
// Interpreter Helper Functions.
//
Expand Down
Loading

0 comments on commit b2340af

Please # to comment.