// libipc/src/switch.zig
//
// 333 lines
// 12 KiB
// Zig

const std = @import("std");
const testing = std.testing;
const fmt = std.fmt;
const net = std.net;
const ipc = @import("./main.zig");
const Message = ipc.Message;
const CBEventType = ipc.CBEvent.Type;
const Allocator = std.mem.Allocator;
const util = @import("./util.zig");
const print_eq = util.print_eq;
const log = std.log.scoped(.libipc_switch);
const Event = ipc.Event;
/// SwitchDB: store relations between clients and services.
///
/// A protocol service, such as TCPd, can handle "external" communications (TCP in this case),
/// meaning that a client can connect to this service through a channel that isn't a simple
/// libipc UNIX socket, and this client is then connected to a local service.
/// Conversely, a local client can ask TCPd to establish a connection to a remote service.
/// In both cases, at least one of the connections isn't libipc-based and should be
/// handled in a specific way that only TCPd (or another protocol service) can.
///
/// TCPd marks both file descriptors as "related" (add_switch) so libipc can automatically
/// handle messages between the client and the service. Any input from one end will be sent
/// to the other.
/// TCPd registers functions to handle specific input and output operations from and to the
/// remote connection (set_callbacks).
///
/// At any point, TCPd can safely close a connection and remove it from the SwitchDB (nuke),
/// resulting in the removal of both the connection's FD and its related FD (both the client
/// and the service connections are removed). nuke itself only drops the DB entries; it does
/// not close the file descriptors.
///
/// Currently, libipc automatically closes both the client and its service when an error occurs.
pub const SwitchDB = struct {
    const Self = @This();

    /// Size of the stack buffers used to receive or serialize a single message.
    const buffer_size = 100000; // TODO: buffer size

    /// fd -> related fd and I/O callbacks. Each switch is stored twice, once per end.
    db: std.AutoArrayHashMap(i32, ManagedConnection),

    pub fn init(allocator: Allocator) Self {
        return Self{
            .db = std.AutoArrayHashMap(i32, ManagedConnection).init(allocator),
        };
    }

    pub fn deinit(self: *Self) void {
        self.db.deinit();
    }

    /// Print every entry as "(fd,dest)", in the map's iteration order.
    pub fn format(self: Self, comptime _: []const u8, _: fmt.FormatOptions, out_stream: anytype) !void {
        for (self.db.keys(), self.db.values()) |fd, connection| {
            try fmt.format(out_stream, "({},{})", .{ fd, connection.dest });
        }
    }

    /// Relate `fd1` and `fd2`: input read from one end is forwarded to the other.
    /// Both entries get the default callbacks until set_callbacks is called.
    pub fn add_switch(self: *Self, fd1: i32, fd2: i32) !void {
        // Reserve room for both entries first, so an allocation failure cannot
        // leave only one half of the switch registered.
        try self.db.ensureUnusedCapacity(2);
        self.db.putAssumeCapacity(fd1, ManagedConnection{ .dest = fd2 });
        self.db.putAssumeCapacity(fd2, ManagedConnection{ .dest = fd1 });
    }

    /// Replace the "in" and "out" callbacks of the connection registered under `fd`.
    /// Returns error.unregisteredFD if `fd` is not in the DB.
    pub fn set_callbacks(self: *Self, fd: i32, in: *const fn (origin: i32, mcontent: [*]u8, mlen: *u32) CBEventType, out: *const fn (origin: i32, mcontent: [*]const u8, mlen: u32) CBEventType) !void {
        // getPtr, not get: `get` returns a copy of the value, and updating that
        // copy would leave the stored entry untouched.
        const managedconnection = self.db.getPtr(fd) orelse return error.unregisteredFD;
        managedconnection.in = in;
        managedconnection.out = out;
    }

    /// Dig the "db" hashmap, perform the "in" fn, and maybe provide a message.
    ///
    /// Returns null when the callback asks to ignore the input. Otherwise the parsed
    /// message is addressed to the related fd (`dest`). It is allocated with
    /// std.heap.c_allocator and must be deinit'ed by the caller.
    /// Errors from the "in" fn are reported as Zig errors:
    /// - unregisteredFD: `fd` is not in the DB;
    /// - closeFD: the callback reported FD_CLOSING;
    /// - generic: the callback reported ERROR, announced a size larger than the
    ///   buffer, or produced bytes that could not be parsed as a message.
    pub fn read(self: *Self, fd: i32) !?Message {
        const managedconnection = self.db.get(fd) orelse return error.unregisteredFD;
        var buffer = [_]u8{0} ** buffer_size;
        // In: capacity of `buffer`; out: number of bytes produced by the callback.
        var message_size: u32 = @truncate(u32, buffer.len);
        switch (managedconnection.in(fd, &buffer, &message_size)) {
            // The message should be ignored (protocol specific).
            CBEventType.IGNORE => return null,
            CBEventType.NO_ERROR => {
                // A misbehaving callback must not make us slice past the buffer.
                if (message_size > buffer.len) return error.generic;
                // TODO: better allocator?
                // TODO: better errors?
                return Message.read(managedconnection.dest, buffer[0..message_size], std.heap.c_allocator) catch |err| {
                    log.debug("cannot read message from fd {}: {s}", .{ fd, @errorName(err) });
                    return error.generic;
                };
            },
            CBEventType.FD_CLOSING => return error.closeFD,
            // Generic error, or the message was read but with errors.
            CBEventType.ERROR => return error.generic,
        }
    }

    /// Dig the "db" hashmap, serialize `message`, and perform the "out" fn of `message.fd`.
    ///
    /// The message is only read, never freed: the caller keeps ownership.
    /// Errors from the "out" fn are reported as Zig errors:
    /// - unregisteredFD: `message.fd` is not in the DB;
    /// - closeFD: the callback reported FD_CLOSING;
    /// - generic: serialization failed, or the callback reported ERROR or IGNORE.
    pub fn write(self: *Self, message: Message) !void {
        const managedconnection = self.db.get(message.fd) orelse return error.unregisteredFD;
        var buffer = [_]u8{0} ** buffer_size;
        var fbs = std.io.fixedBufferStream(&buffer);
        const writer = fbs.writer();
        // Returning basic errors, no details.
        _ = message.write(writer) catch return error.generic;
        const written = fbs.getWritten();
        switch (managedconnection.out(message.fd, written.ptr, @truncate(u32, written.len))) {
            // No error: the callback handled the message.
            CBEventType.NO_ERROR => return,
            CBEventType.FD_CLOSING => return error.closeFD,
            // Generic error. IGNORE is currently treated the same way for outgoing messages.
            CBEventType.IGNORE, CBEventType.ERROR => return error.generic,
        }
    }

    /// From a message to read on a socket to an Event: SWITCH_RX (carrying the
    /// message, or null when it was ignored), DISCONNECTION, or ERROR.
    pub fn handle_event_read(self: *Self, index: usize, fd: i32) Event {
        const message = self.read(fd) catch |err| switch (err) {
            error.closeFD => return Event.init(Event.Type.DISCONNECTION, index, fd, null),
            error.unregisteredFD, error.generic => return Event.init(Event.Type.ERROR, index, fd, null),
        };
        return Event.init(Event.Type.SWITCH_RX, index, fd, message);
    }

    /// From a message to write on a socket to an Event: SWITCH_TX, DISCONNECTION,
    /// or ERROR. The returned event never carries a message.
    pub fn handle_event_write(self: *Self, index: usize, message: Message) Event {
        const fd = message.fd;
        self.write(message) catch |err| switch (err) {
            error.closeFD => return Event.init(Event.Type.DISCONNECTION, index, fd, null),
            error.unregisteredFD, error.generic => return Event.init(Event.Type.ERROR, index, fd, null),
        };
        return Event.init(Event.Type.SWITCH_TX, index, fd, null);
    }

    /// Return the fd related to `fd`, or error.unregisteredFD if `fd` is unknown.
    pub fn getDest(self: *Self, fd: i32) !i32 {
        const connection = self.db.get(fd) orelse return error.unregisteredFD;
        return connection.dest;
    }

    /// Remove both entries (client and service) from the DB.
    /// Unknown fds are ignored. The fds themselves are not closed here.
    pub fn nuke(self: *Self, fd: i32) void {
        if (self.db.fetchSwapRemove(fd)) |kv| {
            _ = self.db.swapRemove(kv.value.dest);
        }
    }
};
/// One end of a switch, stored in SwitchDB.db under its own fd.
const ManagedConnection = struct {
    /// Related fd: messages read from this connection are addressed to it.
    dest: i32,

    /// Produces input from the connection; defaults to a plain read on the fd (default_in).
    in: *const fn (origin: i32, mcontent: [*]u8, mlen: *u32) CBEventType = default_in,

    /// Sends serialized output to the connection; defaults to a plain write on the fd (default_out).
    out: *const fn (origin: i32, mcontent: [*]const u8, mlen: u32) CBEventType = default_out,
};
test "creation and display" {
    var gpa = std.heap.GeneralPurposeAllocator(.{ .safety = true }){};
    defer _ = gpa.deinit();

    var switchdb = SwitchDB.init(gpa.allocator());
    defer switchdb.deinit();

    // Relate fds 5 and 6 by hand, one entry per end.
    try switchdb.db.put(5, .{ .dest = 6 });
    try switchdb.db.put(6, .{ .dest = 5 });

    // Entries are displayed as "(fd,dest)".
    try print_eq("{ (5,6)(6,5) }", .{switchdb});
}
/// Test "in" callback: serialize a hardcoded message (fd 8, payload "coucou")
/// into `mcontent`, store its size in `mlen`, and report success.
fn successful_in(_: i32, mcontent: [*]u8, mlen: *u32) CBEventType {
    var hardcoded = Message.init(8, std.heap.c_allocator, "coucou") catch unreachable;
    defer hardcoded.deinit();

    var stream = std.io.fixedBufferStream(mcontent[0..mlen.*]);
    const out = stream.writer();
    const size = hardcoded.write(out) catch unreachable;

    mlen.* = @truncate(u32, size);
    return CBEventType.NO_ERROR;
}
/// Test "out" callback: discard the message and always report success.
fn successful_out(_: i32, _: [*]const u8, _: u32) CBEventType {
    return .NO_ERROR;
}
test "successful exchanges" {
    // std.testing.allocator fails the test on leaks; a GPA whose deinit() result
    // is discarded would let them pass silently.
    const allocator = testing.allocator;
    var switchdb = SwitchDB.init(allocator);
    defer switchdb.deinit();
    try switchdb.db.put(5, ManagedConnection{ .dest = 6, .in = successful_in, .out = successful_out });
    try switchdb.db.put(6, ManagedConnection{ .dest = 5, .in = successful_in, .out = successful_out });

    // Reading from either end should return a new message, parsed from the bytes
    // successful_in serialized (hardcoded: fd 8, payload "coucou").
    const event_1: Event = switchdb.handle_event_read(1, 5);
    if (event_1.m) |m| {
        m.deinit();
    } else {
        return error.NoMessage;
    }
    const event_2: Event = switchdb.handle_event_read(1, 6);
    if (event_2.m) |m| {
        m.deinit();
    } else {
        return error.NoMessage;
    }

    // Writing only serializes the message: the test still owns it and must free it.
    var message = try Message.init(6, allocator, "coucou");
    defer message.deinit();
    const event_3 = switchdb.handle_event_write(5, message);
    if (event_3.m) |_| {
        return error.ShouldNotCarryMessage;
    }
}
/// Test "in" callback: always report a generic error without producing input.
/// Its parameter types match the `ManagedConnection.in` callback type
/// (mutable content buffer), like successful_in and default_in.
fn unsuccessful_in(_: i32, _: [*]u8, _: *u32) CBEventType {
    return CBEventType.ERROR;
}
/// Test "out" callback: ignore the message and always report a generic error.
fn unsuccessful_out(_: i32, _: [*]const u8, _: u32) CBEventType {
    return .ERROR;
}
test "unsuccessful exchanges" {
    // std.testing.allocator fails the test on leaks; a GPA whose deinit() result
    // is discarded would let them pass silently.
    const allocator = testing.allocator;
    var switchdb = SwitchDB.init(allocator);
    defer switchdb.deinit();
    try switchdb.db.put(5, ManagedConnection{ .dest = 6, .in = unsuccessful_in, .out = unsuccessful_out });
    try switchdb.db.put(6, ManagedConnection{ .dest = 5, .in = unsuccessful_in, .out = unsuccessful_out });

    // The "in" callback fails: reading from either end must NOT return a message.
    const event_1: Event = switchdb.handle_event_read(1, 5);
    if (event_1.m) |_| {
        return error.ShouldNotCarryMessage;
    }
    const event_2: Event = switchdb.handle_event_read(1, 6);
    if (event_2.m) |_| {
        return error.ShouldNotCarryMessage;
    }

    // The "out" callback fails too. The message is never taken over by the
    // switch, so the test must free it.
    var message = try Message.init(6, allocator, "coucou");
    defer message.deinit();
    const event_3 = switchdb.handle_event_write(5, message);
    if (event_3.m) |_| {
        return error.ShouldNotCarryMessage;
    }
}
test "nuke 'em" {
    var gpa = std.heap.GeneralPurposeAllocator(.{ .safety = true }){};
    defer _ = gpa.deinit();

    var switchdb = SwitchDB.init(gpa.allocator());
    defer switchdb.deinit();

    // Register both ends of the 5 <-> 6 switch.
    for ([_]i32{ 5, 6 }, [_]i32{ 6, 5 }) |fd, dest| {
        try switchdb.db.put(fd, ManagedConnection{ .dest = dest, .in = unsuccessful_in, .out = unsuccessful_out });
    }
    try testing.expectEqual(@as(usize, 2), switchdb.db.count());

    // Nuking one end must also drop its related end.
    switchdb.nuke(5);
    try testing.expectEqual(@as(usize, 0), switchdb.db.count());
}
/// Default "in" callback: read at most `mlen.*` bytes from `origin` into `mcontent`.
///
/// On success, `mlen.*` is set to the number of bytes read. A read error yields
/// ERROR. A read shorter than 4 bytes (including 0 bytes) yields FD_CLOSING.
fn default_in(origin: i32, mcontent: [*]u8, mlen: *u32) CBEventType {
    // Wrap the raw fd in a net.Stream to reuse its read function.
    const stream = net.Stream{ .handle = origin };
    const received: usize = stream.read(mcontent[0..mlen.*]) catch return CBEventType.ERROR;

    if (received >= 4) {
        mlen.* = @truncate(u32, received);
        return CBEventType.NO_ERROR;
    }

    // Too short to hold a message: handle it as a disconnection.
    log.debug("message is less than 4 bytes ({} bytes)", .{received});
    return CBEventType.FD_CLOSING;
}
/// Default "out" callback: write all `mlen` bytes of `mcontent` to `fd`.
///
/// Returns ERROR if writing fails, NO_ERROR once every byte has been written.
fn default_out(fd: i32, mcontent: [*]const u8, mlen: u32) CBEventType {
    // The fd is all we need: no need to search for the right structure,
    // just wrap it in a Stream.
    const stream = net.Stream{ .handle = fd };
    // `write` may perform a short write and silently drop the rest of the
    // message. `writeAll` keeps writing until every byte has been sent.
    stream.writeAll(mcontent[0..mlen]) catch return CBEventType.ERROR;
    return CBEventType.NO_ERROR;
}