zeek/auxil/broker/tests/btest/scripts/wire_format.py
Patrick Kelley 8fd444092b initial
2025-05-07 15:35:15 -04:00

318 lines
8.9 KiB
Python

import socket
import struct
import sys
import time
import traceback
import uuid
from collections import namedtuple
from enum import IntEnum
from inspect import getframeinfo, stack
class MessageType(IntEnum):
    """Numeric message tags as they appear in the one-byte tag field."""

    # Tags 1-5: only PING and PONG have unpack functions in this file.
    DATA = 1
    COMMAND = 2
    ROUTING_UPDATE = 3
    PING = 4
    PONG = 5

    # Tags 6-12: handled by the handshake reader (phase 1 and phase 2).
    HELLO = 6
    PROBE = 7
    VERSION_SELECT = 8
    DROP_CONN = 9
    ORIGINATOR_SYN = 10
    RESPONDER_SYN_ACK = 11
    ORIGINATOR_ACK = 12
def magic_number():
    """Return Broker's magic number.

    The value is the ASCII string "ZEEK" interpreted as a big-endian
    unsigned 32-bit integer, i.e. 0x5A45454B.
    """
    return int.from_bytes(b"ZEEK", byteorder="big")
def behead(buf, size):
    """Split ``buf`` into its first ``size`` items and everything after.

    Returns a ``(head, tail)`` tuple. Plain slicing is used, so a buffer
    shorter than ``size`` yields the whole buffer as head and an empty tail.
    """
    head = buf[:size]
    tail = buf[size:]
    return (head, tail)
def vb_encode(n):
    """Return the non-negative integer ``n`` as a varbyte-encoded bytearray.

    The value is emitted in groups of 7 bits, least-significant group first.
    Every byte except the last has its high bit (0x80) set to signal that
    more bytes follow.

    Raises:
        ValueError: if ``n`` is negative. A negative value would otherwise
            be silently truncated to its low 7 bits and decode to a
            different number.
    """
    if n < 0:
        raise ValueError(f"vb_encode: cannot encode negative value {n}")
    buf = bytearray()
    x = n
    while x > 0x7F:
        # Low 7 bits plus the continuation flag.
        buf.append((x & 0x7F) | 0x80)
        x >>= 7
    # Final group: fits in 7 bits, continuation flag cleared.
    buf.append(x & 0x7F)
    return buf
def vb_decode(buf):
    """Decode one varbyte-encoded number from the front of ``buf``.

    Each byte contributes its low 7 bits, least-significant group first;
    a byte whose high bit (0x80) is clear terminates the number.

    Returns:
        A ``(value, remainder)`` tuple, where ``remainder`` is the part of
        ``buf`` after the encoded number.

    Raises:
        IndexError: if ``buf`` is empty or ends before a terminating byte.
    """
    x = 0
    # Walk the buffer by position instead of re-slicing it for every byte.
    for n, byte in enumerate(buf):
        x |= (byte & 0x7F) << (7 * n)
        if byte & 0x80 == 0:
            return (x, buf[n + 1 :])
    raise IndexError("vb_decode: buffer ended before the end of the varbyte")
def pack_subscriptions(subs):
    """Pack a sequence of subscription strings into a bytearray.

    Layout: the varbyte-encoded number of entries, followed by each entry
    as a varbyte-encoded byte length plus its encoded bytes.
    """
    buf = bytearray(vb_encode(len(subs)))
    for sub in subs:
        # The prefix must count encoded bytes, not characters; the two
        # differ as soon as a topic contains non-ASCII text.
        raw = sub.encode()
        buf.extend(vb_encode(len(raw)))
        buf.extend(raw)
    return buf
def pack_string(x):
    """Pack the string ``x`` into a bytearray.

    The result holds the size of the encoded string as varbyte-encoded
    prefix, followed by the encoded bytes.
    """
    # Encode first: the prefix must be the byte count, which is larger than
    # the character count for non-ASCII text.
    raw = x.encode()
    buf = vb_encode(len(raw))
    buf.extend(raw)
    return buf
def unpack_string(buf):
    """Unpack one length-prefixed string from the front of ``buf``.

    The buffer must start with the size as varbyte-encoded prefix, followed
    by that many bytes of encoded text.

    Returns:
        A ``(text, remainder)`` tuple with the decoded string and the bytes
        that follow it.
    """
    (length, rest) = vb_decode(buf)
    text = rest[:length].decode()
    return (text, rest[length:])
def unpack_subscriptions(buf):
    """Unpack a list of subscription strings from the front of ``buf``.

    Reads a varbyte-encoded entry count and then that many length-prefixed
    strings. The count is logged to stdout.

    Returns:
        A ``(subscriptions, remainder)`` tuple.
    """
    (size, remainder) = vb_decode(buf)
    print(f"unpack subscription of size {size}")
    result = []
    for _ in range(size):
        # Each call consumes one string and hands back what is left.
        (sub, remainder) = unpack_string(remainder)
        result.append(sub)
    return (result, remainder)
# -- pack and unpack functions for handshake messages --------------------------
def unpack_hello(buf):
    """Unpack a hello message from exactly 22 bytes.

    Layout (network byte order): 32-bit magic number, 16-byte sender UUID,
    one byte minimum version, one byte maximum version.

    Returns:
        A ``HelloMsg`` named tuple
        ``(magic, sender_id, min_version, max_version)``.
    """
    field_names = ["magic", "sender_id", "min_version", "max_version"]
    HelloMsg = namedtuple("HelloMsg", field_names)
    magic, raw_id, lowest, highest = struct.unpack("!I16sBB", buf)
    return HelloMsg(
        magic=magic,
        sender_id=uuid.UUID(bytes=raw_id),
        min_version=lowest,
        max_version=highest,
    )
def pack_hello(sender_id, vmin, vmax):
    """Pack a hello message.

    Layout (network byte order): Broker's magic number, the 16 bytes of
    ``sender_id``, then ``vmin`` and ``vmax`` as one byte each.
    """
    fields = (magic_number(), sender_id.bytes, vmin, vmax)
    return struct.pack("!I16sBB", *fields)
def unpack_probe(buf):
    """Unpack a probe message: a single 32-bit magic number.

    Returns:
        A ``ProbeMsg`` named tuple with the one field ``magic``.
    """
    ProbeMsg = namedtuple("ProbeMsg", ["magic"])
    (magic,) = struct.unpack("!I", buf)
    return ProbeMsg(magic=magic)
def pack_probe(msg):
    """Pack a probe message: a single 32-bit magic number.

    Args:
        msg: either the magic number itself or a one-element tuple such as
            the ``ProbeMsg`` returned by ``unpack_probe``.
    """
    # ``(magic) = msg`` is a plain assignment, not tuple unpacking, so a
    # ProbeMsg used to reach struct.pack unchanged and fail there.
    magic = msg[0] if isinstance(msg, tuple) else msg
    return struct.pack("!I", magic)
def pack_version_select(sender_id, version):
    """Pack a version-select message.

    Layout (network byte order): Broker's magic number, the 16 bytes of
    ``sender_id``, then ``version`` as one byte.
    """
    fields = (magic_number(), sender_id.bytes, version)
    return struct.pack("!I16sB", *fields)
def unpack_version_select(buf):
    """Unpack a version-select message from exactly 21 bytes.

    Layout (network byte order): 32-bit magic number, 16-byte sender UUID,
    one byte version.

    Returns:
        A ``VersionSelectMsg`` named tuple ``(magic, sender_id, version)``.
    """
    field_names = ["magic", "sender_id", "version"]
    VersionSelectMsg = namedtuple("VersionSelectMsg", field_names)
    magic, raw_id, version = struct.unpack("!I16sB", buf)
    return VersionSelectMsg(
        magic=magic, sender_id=uuid.UUID(bytes=raw_id), version=version
    )
def pack_drop_conn(sender_id, code, description):
    """Pack a drop-connection message into a bytearray.

    Layout: Broker's magic number, the 16 bytes of ``sender_id`` and the
    one-byte ``code`` (network byte order), followed by ``description`` as
    a length-prefixed string.
    """
    header = struct.pack("!I16sB", magic_number(), sender_id.bytes, code)
    msg = bytearray(header)
    msg += pack_string(description)
    return msg
def unpack_drop_conn(buf):
    """Unpack a drop-connection message.

    Layout: 32-bit magic number, 16-byte sender UUID and a one-byte code
    (network byte order), followed by a length-prefixed description string.

    Returns:
        A ``DropConnMsg`` named tuple
        ``(magic, sender_id, code, description)`` where ``description`` is
        the decoded string.
    """
    DropConnMsg = namedtuple(
        "DropConnMsg", ["magic", "sender_id", "code", "description"]
    )
    # Fixed-size prefix: 4 (magic) + 16 (UUID) + 1 (code) = 21 bytes.
    (head, tail) = behead(buf, 21)
    (magic, sender_id, code) = struct.unpack("!I16sB", head)
    # unpack_string returns (string, remainder); only the string belongs in
    # the message. Storing the whole tuple was the previous bug.
    (description, _remainder) = unpack_string(tail)
    return DropConnMsg(magic, uuid.UUID(bytes=sender_id), code, description)
def pack_originator_ack():
    """Pack an originator-ack message, which has no payload."""
    empty_payload = bytearray()
    return empty_payload
def unpack_originator_ack(buf):
    """Unpack an originator-ack message.

    The message carries no fields; ``buf`` is ignored and an empty tuple is
    returned.
    """
    return tuple()
def pack_originator_syn(subs):
    """Pack an originator-syn message; its payload is the packed ``subs``."""
    payload = pack_subscriptions(subs)
    return payload
def unpack_originator_syn(buf):
    """Unpack an originator-syn message.

    Returns:
        An ``OriginatorSynMsg`` named tuple with the field ``subscriptions``.

    Raises:
        RuntimeError: if bytes remain after the subscription list.
    """
    OriginatorSynMsg = namedtuple("OriginatorSynMsg", ["subscriptions"])
    (subscriptions, remainder) = unpack_subscriptions(buf)
    if remainder:
        raise RuntimeError("unpack_originator_syn: trailing bytes")
    return OriginatorSynMsg(subscriptions=subscriptions)
def pack_responder_syn_ack(subs):
    """Pack a responder-syn-ack message; its payload is the packed ``subs``."""
    payload = pack_subscriptions(subs)
    return payload
def unpack_responder_syn_ack(buf):
    """Unpack a responder-syn-ack message.

    Returns:
        A ``ResponderSynAck`` named tuple with the field ``subscriptions``.

    Raises:
        RuntimeError: if bytes remain after the subscription list.
    """
    ResponderSynAck = namedtuple("ResponderSynAck", ["subscriptions"])
    (subscriptions, remainder) = unpack_subscriptions(buf)
    if remainder:
        raise RuntimeError("unpack_responder_syn_ack: trailing bytes")
    return ResponderSynAck(subscriptions=subscriptions)
# -- utility functions for socket I/O on handshake messages --------------------
def enum_to_str(e):
    """Render an enum member as ``ClassName.MEMBER``.

    Built by hand for compatibility around Python 3.11, where the default
    string rendering changed for some enum types.
    """
    return f"{type(e).__name__}.{e.name}"
def read_hs_msg(fd):
    """Read one handshake message (phase 1 and phase 2) from socket ``fd``.

    Frame layout: a 4-byte big-endian length, a 1-byte message tag, then
    ``length - 1`` bytes of payload.

    Returns:
        A ``(tag, message)`` tuple, where ``message`` is whatever the
        unpack function for that tag returns.

    Raises:
        RuntimeError: if the peer closes the connection mid-message or the
            length field is too small to hold the tag.
        ValueError: if the tag is not a known ``MessageType``.
        KeyError: if the tag is not a handshake message type.
    """

    def recv_exact(size):
        # recv() may return fewer bytes than requested on a stream socket,
        # so keep reading until the full amount has arrived.
        data = bytearray()
        while len(data) < size:
            chunk = fd.recv(size - len(data))
            if not chunk:
                raise RuntimeError(
                    "read_hs_msg: connection closed while reading a message"
                )
            data.extend(chunk)
        return bytes(data)

    # -1 since we extract the tag right away
    msg_len = int.from_bytes(recv_exact(4), byteorder="big", signed=False) - 1
    if msg_len < 0:
        raise RuntimeError("read_hs_msg: message length too small")
    tag = MessageType(recv_exact(1)[0])
    tag_str = enum_to_str(tag)
    print(f"received a {tag_str} message with {msg_len} bytes")
    unpack_tbl = {
        MessageType.HELLO: unpack_hello,
        MessageType.PROBE: unpack_probe,
        MessageType.VERSION_SELECT: unpack_version_select,
        MessageType.DROP_CONN: unpack_drop_conn,
        MessageType.ORIGINATOR_SYN: unpack_originator_syn,
        MessageType.RESPONDER_SYN_ACK: unpack_responder_syn_ack,
        MessageType.ORIGINATOR_ACK: unpack_originator_ack,
    }
    return (tag, unpack_tbl[tag](recv_exact(msg_len)))
def write_hs_msg(fd, tag, buf):
    """Write one handshake message (phase 1 and phase 2) to socket ``fd``.

    Frame layout: a 4-byte big-endian length (payload plus one tag byte),
    the 1-byte message tag, then the payload ``buf``.
    """
    payload_len = len(buf) + 1
    frame = bytearray(payload_len.to_bytes(4, byteorder="big", signed=False))
    frame += int(tag).to_bytes(1, byteorder="big", signed=False)
    frame += buf
    # send() may transmit only part of the data; sendall() keeps going
    # until the whole frame is written or an error occurs.
    fd.sendall(frame)
    tag_str = enum_to_str(tag)
    print(f"sent {tag_str} message with {payload_len} bytes")
# -- pack and unpack functions for phase 3 messages ----------------------------
def pack_ping(buf):
    """Pack a ping payload; the payload is passed through unchanged."""
    return buf
def unpack_ping(buf):
    """Unpack a ping payload; the payload is passed through unchanged."""
    return buf
def pack_pong(buf):
    """Pack a pong payload; the payload is passed through unchanged."""
    return buf
def unpack_pong(buf):
    """Unpack a pong payload; the payload is passed through unchanged."""
    return buf
# -- utility functions for socket I/O on phase 3 messages ----------------------
def read_op_msg(fd):
    """Read one operation-mode message (phase 3) from socket ``fd``.

    Frame layout (big-endian): 4-byte length, 16-byte sender UUID, 16-byte
    receiver UUID, 1-byte tag, 2-byte TTL, 2-byte topic length, the topic,
    then the payload.

    Returns:
        A ``(tag, NodeMsg)`` tuple; ``NodeMsg`` is a named tuple
        ``(sender_id, receiver_id, ttl, topic, payload)``.

    Raises:
        RuntimeError: if the peer closes the connection mid-message or the
            length field is too small for the header and topic.
        ValueError: if the tag is not a known ``MessageType``.
        KeyError: if the tag is neither PING nor PONG.
    """

    def recv_exact(size):
        # recv() may return fewer bytes than requested on a stream socket,
        # so keep reading until the full amount has arrived.
        data = bytearray()
        while len(data) < size:
            chunk = fd.recv(size - len(data))
            if not chunk:
                raise RuntimeError(
                    "read_op_msg: connection closed while reading a message"
                )
            data.extend(chunk)
        return bytes(data)

    # -37 since we extract the two IDs (2 x 16) plus tag (1), TTL (2) and
    # topic len (2) right away
    msg_len = int.from_bytes(recv_exact(4), byteorder="big", signed=False) - 37
    (src_bytes, dst_bytes, tag_value, ttl, topic_len) = struct.unpack(
        "!16s16sBHH", recv_exact(37)
    )
    src = uuid.UUID(bytes=src_bytes)
    dst = uuid.UUID(bytes=dst_bytes)
    tag = MessageType(tag_value)
    if msg_len < topic_len:
        raise RuntimeError("read_op_msg: message length too small")
    topic = recv_exact(topic_len).decode()
    buf = recv_exact(msg_len - topic_len)
    tag_str = enum_to_str(tag)
    print(f"received a {tag_str} with a payload of {msg_len} bytes")
    unpack_tbl = {
        MessageType.PING: unpack_ping,
        MessageType.PONG: unpack_pong,
    }
    NodeMsg = namedtuple(
        "NodeMsg", ["sender_id", "receiver_id", "ttl", "topic", "payload"]
    )
    return (tag, NodeMsg(src, dst, ttl, topic, unpack_tbl[tag](buf)))
def write_op_msg(fd, src, dst, tag, topic, buf):
    """Write one operation-mode message (phase 3) to socket ``fd``.

    Frame layout (big-endian): 4-byte length, 16-byte sender UUID, 16-byte
    receiver UUID, 1-byte tag, 2-byte TTL (always 1), 2-byte topic length,
    the topic, then the payload ``buf``.
    """
    # Encode once and measure the bytes: the length fields must match what
    # goes on the wire, which differs from len(topic) for non-ASCII topics.
    topic_bytes = topic.encode()
    # 37 = two UUIDs (2 x 16) + tag (1) + TTL (2) + topic length (2)
    payload_len = len(buf) + 37 + len(topic_bytes)
    frame = bytearray(
        struct.pack(
            "!I16s16sBHH",
            payload_len,
            src.bytes,  # sender UUID
            dst.bytes,  # receiver UUID
            int(tag),  # msg type
            1,  # ttl
            len(topic_bytes),
        )
    )
    frame += topic_bytes
    frame += buf
    # send() may transmit only part of the data; sendall() keeps going
    # until the whole frame is written or an error occurs.
    fd.sendall(frame)
    tag_str = enum_to_str(tag)
    print(f"sent {tag_str} message with a payload of {payload_len} bytes")
# -- minimal testing DSL -------------------------------------------------------
def check_eq(got, want):
    """Check that ``got`` equals ``want``, reporting the caller's line.

    Prints a "check passed" line on success.

    Raises:
        RuntimeError: if the two values differ.
    """
    # Index 1 is the frame of whoever called check_eq.
    caller_frame = stack()[1].frame
    line = getframeinfo(caller_frame).lineno
    if not (got != want):
        print(f"line {line}: check passed")
        return
    raise RuntimeError(f"line {line}: check failed -> {got} != {want}")
def write_done():
    """Create or overwrite the file ``done`` in the current directory.

    The file receives the text "done".
    """
    with open("done", "w") as marker:
        marker.write("done")
def test_main(host, port, fn):
    """Connect to ``host:port`` and run ``fn`` on the connected socket.

    Tries to connect up to 30 times, sleeping one second after each failed
    attempt. Once connected, ``fn(fd)`` runs exactly once. Whether it
    returns or raises, the ``done`` marker file is written and the process
    exits via ``sys.exit()``; on failure the error and traceback are
    printed first. If no attempt connects, the function returns without
    writing the marker.

    Args:
        host: host name or address to connect to.
        port: TCP port to connect to.
        fn: callable that receives the connected socket.
    """
    connected = False
    for _ in range(30):
        try:
            with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as fd:
                # Use the caller's host rather than a hard-coded "localhost".
                fd.connect((host, port))
                connected = True
                fn(fd)
                write_done()
                # SystemExit is not an Exception, so it passes through the
                # handler below.
                sys.exit()
        except Exception as ex:
            if not connected:
                print(
                    f"failed to connect to {host}:{port}, try again",
                    file=sys.stderr,
                )
                time.sleep(1)
            else:
                print(str(ex))
                print(traceback.format_exc())
                write_done()
                sys.exit()