318 lines
8.9 KiB
Python
318 lines
8.9 KiB
Python
import socket
|
|
import struct
|
|
import sys
|
|
import time
|
|
import traceback
|
|
import uuid
|
|
from collections import namedtuple
|
|
from enum import IntEnum
|
|
from inspect import getframeinfo, stack
|
|
|
|
|
|
class MessageType(IntEnum):
    """One-byte type tags that prefix Broker wire-format messages.

    write_hs_msg/write_op_msg put ``int(tag)`` on the wire as a single byte and
    read_hs_msg/read_op_msg map that byte back with ``MessageType(byte)``, so
    the numeric values below are part of the wire format.
    """

    # Tags 1-5; read_op_msg in this file dispatches PING and PONG.
    DATA = 0x01
    COMMAND = 0x02
    ROUTING_UPDATE = 0x03
    PING = 0x04
    PONG = 0x05
    # Handshake tags; read_hs_msg dispatches each of these.
    HELLO = 0x06
    PROBE = 0x07
    VERSION_SELECT = 0x08
    DROP_CONN = 0x09
    ORIGINATOR_SYN = 0x0A
    RESPONDER_SYN_ACK = 0x0B
    ORIGINATOR_ACK = 0x0C
|
|
|
|
|
|
def magic_number():
    """Return Broker's magic number, 0x5A45454B.

    The value is the ASCII text ``ZEEK`` read as a big-endian 32-bit integer.
    """
    return int.from_bytes(b"ZEEK", byteorder="big")
|
|
|
|
|
|
def behead(buf, size):
    """Split *buf* at offset *size*; return ``(head, tail)``.

    Plain slicing: a *size* past the end yields the whole buffer as head and
    an empty tail.
    """
    head = buf[:size]
    tail = buf[size:]
    return head, tail
|
|
|
|
|
|
def vb_encode(n):
    """Return the non-negative integer *n* as a varbyte-encoded ``bytearray``.

    The value is emitted in 7-bit groups, least significant group first. Every
    byte except the last has its high bit (0x80) set to mark that another
    byte follows.

    Raises:
        ValueError: if *n* is negative. The encoding is unsigned, and a
            negative value would otherwise be emitted as one wrong byte.
    """
    if n < 0:
        raise ValueError(f"vb_encode: cannot encode negative value {n}")
    buf = bytearray()
    x = n
    while x > 0x7F:
        buf.append((x & 0x7F) | 0x80)
        x >>= 7
    buf.append(x)  # 0 <= x <= 0x7F: final byte, continuation bit clear
    return buf
|
|
|
|
|
|
def vb_decode(buf):
    """Decode one varbyte-encoded integer from the front of *buf*.

    Each byte contributes its low 7 bits, least significant group first. A
    byte with the high bit (0x80) clear ends the number.

    Returns:
        ``(value, remainder)``, where *remainder* is the slice of *buf*
        after the last consumed byte.

    Raises:
        IndexError: if *buf* is empty, or ends before a terminating byte.
    """
    value = 0
    # Walk by index and slice once at the end, instead of re-slicing the
    # buffer for every consumed byte.
    for index, byte in enumerate(buf):
        value |= (byte & 0x7F) << (7 * index)
        if (byte & 0x80) == 0:
            return (value, buf[index + 1:])
    raise IndexError("vb_decode: truncated varbyte-encoded integer")
|
|
|
|
|
|
def pack_subscriptions(subs):
    """Pack a list of subscription topics into a ``bytearray``.

    Layout: varbyte element count, then for each topic a varbyte byte-length
    prefix followed by the topic's UTF-8 bytes.
    """
    buf = bytearray(vb_encode(len(subs)))
    for sub in subs:
        encoded = sub.encode()
        # The prefix must count encoded bytes, not characters: they differ for
        # non-ASCII topics, and the decoder reads exactly prefix-many bytes.
        buf.extend(vb_encode(len(encoded)))
        buf.extend(encoded)
    return buf
|
|
|
|
|
|
def pack_string(x):
    """Pack string *x* as a varbyte length prefix followed by its UTF-8 bytes.

    The prefix is the length of the encoded bytes rather than ``len(x)``. The
    two differ for non-ASCII text, and the decoder reads exactly prefix-many
    bytes.

    Returns a ``bytearray``.
    """
    data = x.encode()
    buf = vb_encode(len(data))
    buf.extend(data)
    return buf
|
|
|
|
|
|
def unpack_string(buf):
    """Unpack a varbyte-length-prefixed UTF-8 string from the front of *buf*.

    Returns:
        ``(string, remainder)``, where *remainder* is everything after the
        string's bytes.

    Raises:
        RuntimeError: if *buf* holds fewer bytes than the length prefix
            announces (previously this silently produced a shorter string).
    """
    (str_size, tail) = vb_decode(buf)
    if len(tail) < str_size:
        raise RuntimeError(
            f"unpack_string: expected {str_size} bytes, got {len(tail)}"
        )
    (head, remainder) = behead(tail, str_size)
    return (head.decode(), remainder)
|
|
|
|
|
|
def unpack_subscriptions(buf):
    """Unpack a varbyte-counted list of length-prefixed strings from *buf*.

    Prints the decoded element count. Returns ``(subscriptions, remainder)``,
    where *remainder* is the unconsumed tail of *buf*.
    """
    topics = []
    size, rest = vb_decode(buf)
    print(f"unpack subscription of size {size}")
    for _ in range(size):
        topic, rest = unpack_string(rest)
        topics.append(topic)
    return topics, rest
|
|
|
|
|
|
# -- pack and unpack functions for handshake messages --------------------------
|
|
|
|
|
|
def unpack_hello(buf):
    """Decode a HELLO payload.

    Layout (network byte order, exactly 22 bytes): 4-byte magic, 16-byte
    sender UUID, 1-byte minimum version, 1-byte maximum version. Any other
    size makes ``struct.unpack`` raise ``struct.error``.

    Returns a ``HelloMsg`` namedtuple whose ``sender_id`` is a ``uuid.UUID``.
    """
    field_names = ("magic", "sender_id", "min_version", "max_version")
    HelloMsg = namedtuple("HelloMsg", field_names)
    magic, raw_id, lowest, highest = struct.unpack("!I16sBB", buf)
    return HelloMsg(
        magic=magic,
        sender_id=uuid.UUID(bytes=raw_id),
        min_version=lowest,
        max_version=highest,
    )
|
|
|
|
|
|
def pack_hello(sender_id, vmin, vmax):
    """Encode a HELLO payload.

    Layout: magic number, the 16 raw bytes of *sender_id* (a ``uuid.UUID``),
    then *vmin* and *vmax* as single bytes.
    """
    fields = (magic_number(), sender_id.bytes, vmin, vmax)
    return struct.pack("!I16sBB", *fields)
|
|
|
|
|
|
def unpack_probe(buf):
    """Decode a PROBE payload: exactly one 4-byte big-endian magic number.

    Returns a one-field ``ProbeMsg`` namedtuple.
    """
    ProbeMsg = namedtuple("ProbeMsg", ["magic"])
    (magic,) = struct.unpack("!I", buf)
    return ProbeMsg(magic=magic)
|
|
|
|
|
|
def pack_probe(msg):
    """Encode a PROBE payload: the magic number as 4 big-endian bytes.

    Args:
        msg: the magic number itself (an int), or a one-element tuple such as
            the ``ProbeMsg`` returned by ``unpack_probe``.
    """
    # Parentheses alone do not unpack: `(magic) = msg` is a plain assignment.
    # Unwrap one-tuples explicitly so a ProbeMsg round-trips, and keep
    # accepting a bare int as before.
    if isinstance(msg, tuple):
        (magic,) = msg
    else:
        magic = msg
    return struct.pack("!I", magic)
|
|
|
|
|
|
def pack_version_select(sender_id, version):
    """Encode a VERSION_SELECT payload.

    Layout: magic number, the 16 raw bytes of *sender_id*, then *version*
    as one byte.
    """
    fields = (magic_number(), sender_id.bytes, version)
    return struct.pack("!I16sB", *fields)
|
|
|
|
|
|
def unpack_version_select(buf):
    """Decode a VERSION_SELECT payload (exactly 21 bytes).

    Layout: 4-byte magic, 16-byte sender UUID, 1-byte version. Returns a
    ``VersionSelectMsg`` namedtuple whose ``sender_id`` is a ``uuid.UUID``.
    """
    VersionSelectMsg = namedtuple(
        "VersionSelectMsg", ["magic", "sender_id", "version"]
    )
    magic, raw_id, version = struct.unpack("!I16sB", buf)
    return VersionSelectMsg(
        magic=magic, sender_id=uuid.UUID(bytes=raw_id), version=version
    )
|
|
|
|
|
|
def pack_drop_conn(sender_id, code, description):
    """Encode a DROP_CONN payload as a ``bytearray``.

    Layout: magic number, the 16 raw bytes of *sender_id*, *code* as one byte,
    then *description* packed with ``pack_string``.
    """
    header = struct.pack("!I16sB", magic_number(), sender_id.bytes, code)
    return bytearray(header) + pack_string(description)
|
|
|
|
|
|
def unpack_drop_conn(buf):
    """Decode a DROP_CONN payload.

    Layout: 4-byte magic, 16-byte sender UUID, 1-byte code, then a
    varbyte-length-prefixed UTF-8 description.

    Returns a ``DropConnMsg`` namedtuple; ``sender_id`` is a ``uuid.UUID`` and
    ``description`` is the decoded string.

    Raises:
        RuntimeError: if bytes remain after the description (the same check
            the subscription-based unpackers apply).
    """
    DropConnMsg = namedtuple(
        "DropConnMsg", ["magic", "sender_id", "code", "description"]
    )
    header_fmt = "!I16sB"
    (head, tail) = behead(buf, struct.calcsize(header_fmt))  # 21 bytes
    (magic, sender_id, code) = struct.unpack(header_fmt, head)
    # unpack_string returns (string, remainder); keep only the string.
    (description, remainder) = unpack_string(tail)
    if len(remainder) > 0:
        raise RuntimeError("unpack_drop_conn: trailing bytes")
    return DropConnMsg(magic, uuid.UUID(bytes=sender_id), code, description)
|
|
|
|
|
|
def pack_originator_ack():
    """Encode an ORIGINATOR_ACK payload, which is empty.

    Returns a new empty ``bytearray`` on every call.
    """
    return bytearray(0)
|
|
|
|
|
|
def unpack_originator_ack(buf):
    """Decode an ORIGINATOR_ACK payload. There are no fields, so *buf* is
    ignored and an empty tuple is returned."""
    return tuple()
|
|
|
|
|
|
def pack_originator_syn(subs):
    """Encode an ORIGINATOR_SYN payload: the originator's subscription list,
    serialized by ``pack_subscriptions``."""
    payload = pack_subscriptions(subs)
    return payload
|
|
|
|
|
|
def unpack_originator_syn(buf):
    """Decode an ORIGINATOR_SYN payload into ``OriginatorSynMsg(subscriptions)``.

    Raises:
        RuntimeError: if any bytes follow the subscription list.
    """
    subscriptions, leftover = unpack_subscriptions(buf)
    if leftover:
        raise RuntimeError("unpack_originator_syn: trailing bytes")
    OriginatorSynMsg = namedtuple("OriginatorSynMsg", ["subscriptions"])
    return OriginatorSynMsg(subscriptions=subscriptions)
|
|
|
|
|
|
def pack_responder_syn_ack(subs):
    """Encode a RESPONDER_SYN_ACK payload: the responder's subscription list,
    serialized by ``pack_subscriptions``."""
    payload = pack_subscriptions(subs)
    return payload
|
|
|
|
|
|
def unpack_responder_syn_ack(buf):
    """Decode a RESPONDER_SYN_ACK payload into ``ResponderSynAck(subscriptions)``.

    Raises:
        RuntimeError: if any bytes follow the subscription list.
    """
    subscriptions, leftover = unpack_subscriptions(buf)
    if leftover:
        raise RuntimeError("unpack_responder_syn_ack: trailing bytes")
    ResponderSynAck = namedtuple("ResponderSynAck", ["subscriptions"])
    return ResponderSynAck(subscriptions=subscriptions)
|
|
|
|
|
|
# -- utility functions for socket I/O on handshake messages --------------------
|
|
|
|
|
|
def enum_to_str(e):
    """Render enum member *e* as ``"<EnumClass>.<MEMBER>"``.

    This is built explicitly instead of using ``str(e)``, because Python 3.11
    changed ``str()`` for some enum types (IntEnum among them). Building it
    by hand keeps the log output identical across versions.
    """
    return f"{type(e).__name__}.{e.name}"
|
|
|
|
|
|
def read_hs_msg(fd):
    """Read one handshake message (phase 1 and phase 2) from socket *fd*.

    Frame layout: a 4-byte big-endian unsigned length that covers the tag byte
    plus the payload, a 1-byte ``MessageType`` tag, then the payload.

    Returns:
        ``(tag, msg)``, where *msg* is what the tag's ``unpack_*`` function
        returns for the payload.

    Raises:
        ConnectionError: the peer closed the connection in mid-frame.
        RuntimeError: the length field is 0 (it must at least cover the tag).
        ValueError: the tag byte is not a ``MessageType`` value.
        KeyError: the tag is not one of the handshake message types.
    """

    def recv_exact(n):
        # socket.recv may return fewer bytes than requested (and b"" on EOF),
        # so keep reading until exactly n bytes have arrived.
        data = bytearray()
        while len(data) < n:
            chunk = fd.recv(n - len(data))
            if not chunk:
                raise ConnectionError(
                    f"read_hs_msg: connection closed after {len(data)} of {n} bytes"
                )
            data.extend(chunk)
        return bytes(data)

    # -1 since we extract the tag right away
    msg_len = int.from_bytes(recv_exact(4), byteorder="big", signed=False) - 1
    if msg_len < 0:
        raise RuntimeError("read_hs_msg: length field does not cover the tag")
    tag = MessageType(recv_exact(1)[0])
    tag_str = enum_to_str(tag)
    print(f"received a {tag_str} message with {msg_len} bytes")
    unpack_tbl = {
        MessageType.HELLO: unpack_hello,
        MessageType.PROBE: unpack_probe,
        MessageType.VERSION_SELECT: unpack_version_select,
        MessageType.DROP_CONN: unpack_drop_conn,
        MessageType.ORIGINATOR_SYN: unpack_originator_syn,
        MessageType.RESPONDER_SYN_ACK: unpack_responder_syn_ack,
        MessageType.ORIGINATOR_ACK: unpack_originator_ack,
    }
    # Look up the decoder before reading the payload; an unexpected tag fails
    # without consuming more of the stream.
    unpack = unpack_tbl[tag]
    return (tag, unpack(recv_exact(msg_len)))
|
|
|
|
|
|
def write_hs_msg(fd, tag, buf):
    """Write one handshake message (phase 1 and phase 2) to socket *fd*.

    Frame layout: a 4-byte big-endian unsigned length (payload plus one tag
    byte), the 1-byte tag, then *buf*. Prints a one-line log of what was sent.
    """
    payload_len = len(buf) + 1
    frame = bytearray(payload_len.to_bytes(4, byteorder="big", signed=False))
    frame.extend(int(tag).to_bytes(1, byteorder="big", signed=False))
    frame.extend(buf)
    # socket.send may write only part of its argument; sendall retries until
    # the whole frame is out.
    fd.sendall(frame)
    tag_str = enum_to_str(tag)
    print(f"sent {tag_str} message with {payload_len} bytes")
|
|
|
|
|
|
# -- pack and unpack functions for phase 3 messages ----------------------------
|
|
|
|
|
|
def pack_ping(buf):
    """Encode a PING payload. The payload goes on the wire verbatim, so the
    same *buf* object is returned."""
    payload = buf
    return payload
|
|
|
|
|
|
def unpack_ping(buf):
    """Decode a PING payload. It carries raw bytes, so *buf* is returned
    as-is."""
    payload = buf
    return payload
|
|
|
|
|
|
def pack_pong(buf):
    """Encode a PONG payload. The payload goes on the wire verbatim, so the
    same *buf* object is returned."""
    payload = buf
    return payload
|
|
|
|
|
|
def unpack_pong(buf):
    """Decode a PONG payload. It carries raw bytes, so *buf* is returned
    as-is."""
    payload = buf
    return payload
|
|
|
|
|
|
# -- utility functions for socket I/O on phase 3 messages ----------------------
|
|
|
|
|
|
def read_op_msg(fd):
    """Read one operation-mode message (phase 3) from socket *fd*.

    Frame layout, all integers big-endian unsigned:
      4-byte length of everything that follows,
      16-byte sender UUID, 16-byte receiver UUID, 1-byte MessageType tag,
      2-byte TTL, 2-byte topic length, the UTF-8 topic, then the payload.

    Returns:
        ``(tag, NodeMsg(sender_id, receiver_id, ttl, topic, payload))``, where
        payload is passed through the tag's ``unpack_*`` function.

    Raises:
        ConnectionError: the peer closed the connection in mid-frame.
        RuntimeError: the length fields are inconsistent.
        ValueError: the tag byte is not a ``MessageType`` value.
        KeyError: the tag is not PING or PONG.
    """

    def recv_exact(n):
        # socket.recv may return fewer bytes than requested (and b"" on EOF),
        # so keep reading until exactly n bytes have arrived.
        data = bytearray()
        while len(data) < n:
            chunk = fd.recv(n - len(data))
            if not chunk:
                raise ConnectionError(
                    f"read_op_msg: connection closed after {len(data)} of {n} bytes"
                )
            data.extend(chunk)
        return bytes(data)

    # -37 since we extract the two 16-byte IDs plus the 1-byte tag, 2-byte TTL
    # and 2-byte topic length right away; what remains is topic + payload.
    msg_len = int.from_bytes(recv_exact(4), byteorder="big", signed=False) - 37
    if msg_len < 0:
        raise RuntimeError("read_op_msg: length field shorter than the header")
    src = uuid.UUID(bytes=recv_exact(16))
    dst = uuid.UUID(bytes=recv_exact(16))
    tag = MessageType(recv_exact(1)[0])
    ttl = int.from_bytes(recv_exact(2), byteorder="big", signed=False)
    topic_len = int.from_bytes(recv_exact(2), byteorder="big", signed=False)
    if topic_len > msg_len:
        raise RuntimeError("read_op_msg: topic length exceeds message length")
    topic = recv_exact(topic_len).decode()
    buf = recv_exact(msg_len - topic_len)
    tag_str = enum_to_str(tag)
    print(f"received a {tag_str} with a payload of {msg_len} bytes")
    unpack_tbl = {
        MessageType.PING: unpack_ping,
        MessageType.PONG: unpack_pong,
    }
    NodeMsg = namedtuple(
        "NodeMsg", ["sender_id", "receiver_id", "ttl", "topic", "payload"]
    )
    return (tag, NodeMsg(src, dst, ttl, topic, unpack_tbl[tag](buf)))
|
|
|
|
|
|
def write_op_msg(fd, src, dst, tag, topic, buf, ttl=1):
    """Write one operation-mode message (phase 3) to socket *fd*.

    Frame layout, all integers big-endian unsigned:
      4-byte length (37 header bytes + topic bytes + payload bytes),
      16-byte sender UUID *src*, 16-byte receiver UUID *dst*, 1-byte *tag*,
      2-byte *ttl*, 2-byte topic length, the UTF-8 *topic*, then *buf*.

    Args:
        ttl: hop limit written into the header; defaults to 1, the value this
            function always sent before the parameter existed.
    """
    # Both length fields must count encoded bytes, not characters; the two
    # differ for non-ASCII topics.
    topic_bytes = topic.encode()
    payload_len = len(buf) + 37 + len(topic_bytes)
    frame = bytearray()
    frame += payload_len.to_bytes(4, byteorder="big", signed=False)
    frame += src.bytes  # sender UUID
    frame += dst.bytes  # receiver UUID
    frame += int(tag).to_bytes(1, byteorder="big", signed=False)  # msg type
    frame += ttl.to_bytes(2, byteorder="big", signed=False)  # ttl
    frame += len(topic_bytes).to_bytes(2, byteorder="big", signed=False)
    frame += topic_bytes
    frame += buf
    # socket.send may write only part of its argument; sendall does not.
    fd.sendall(frame)
    tag_str = enum_to_str(tag)
    print(f"sent {tag_str} message with a payload of {payload_len} bytes")
|
|
|
|
|
|
# -- minimal testing DSL -------------------------------------------------------
|
|
|
|
|
|
def check_eq(got, want):
    """Compare *got* to *want* and report the result with the caller's line.

    Prints ``line N: check passed`` on success. Otherwise raises RuntimeError
    with ``line N: check failed -> got != want``.
    """
    caller_line = stack()[1].lineno
    if got != want:
        raise RuntimeError(f"line {caller_line}: check failed -> {got} != {want}")
    print(f"line {caller_line}: check passed")
|
|
|
|
|
|
def write_done():
    """Create or overwrite the file ``done`` in the current working directory
    with the contents ``done``."""
    marker = open("done", "w")
    with marker:
        marker.write("done")
|
|
|
|
|
|
def test_main(host, port, fn):
    """Connect to *host*:*port*, run ``fn(fd)``, then write ``done`` and exit.

    Connection attempts are retried up to 30 times, one second apart. Once
    connected, *fn* is called with the socket. If it raises, the exception
    text and traceback are printed. In both cases ``write_done()`` runs and
    the process exits through ``sys.exit()``.

    If every attempt fails, an error is printed to stderr and the function
    returns without writing the ``done`` marker.
    """
    attempts = 30
    for _attempt in range(attempts):
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as fd:
            try:
                fd.connect((host, port))
            except OSError:
                pass  # fall through: close this socket, wait, retry
            else:
                # Only errors from the test itself reach this handler;
                # connection failures are handled separately above.
                try:
                    fn(fd)
                except Exception as ex:
                    print(str(ex))
                    print(traceback.format_exc())
                write_done()
                sys.exit()
        print(f"failed to connect to {host}:{port}, try again", file=sys.stderr)
        time.sleep(1)
    print(
        f"giving up on {host}:{port} after {attempts} attempts", file=sys.stderr
    )
|