cosmopolitan/tool/net/echoserver.c

301 lines
10 KiB
C
Raw Normal View History

2020-06-15 14:18:57 +00:00
/*-*- mode:c;indent-tabs-mode:nil;c-basic-offset:2;tab-width:8;coding:utf-8 -*-│
vi: set net ft=c ts=2 sts=2 sw=2 fenc=utf-8 :vi
Copyright 2020 Justine Alexandra Roberts Tunney
This program is free software; you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation; version 2 of the License.
This program is distributed in the hope that it will be useful, but
WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program; if not, write to the Free Software
Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
02110-1301 USA
*/
#include "libc/alg/arraylist.internal.h"
#include "libc/bits/safemacros.h"
2020-06-15 14:18:57 +00:00
#include "libc/calls/calls.h"
#include "libc/calls/struct/iovec.h"
#include "libc/errno.h"
#include "libc/fmt/fmt.h"
#include "libc/log/check.h"
#include "libc/log/log.h"
#include "libc/runtime/gc.h"
#include "libc/runtime/interruptiblecall.h"
#include "libc/runtime/runtime.h"
#include "libc/sock/sock.h"
#include "libc/stdio/stdio.h"
#include "libc/str/str.h"
#include "libc/sysv/consts/af.h"
#include "libc/sysv/consts/exit.h"
#include "libc/sysv/consts/ipproto.h"
#include "libc/sysv/consts/msg.h"
#include "libc/sysv/consts/poll.h"
#include "libc/sysv/consts/so.h"
#include "libc/sysv/consts/sock.h"
#include "libc/sysv/consts/sol.h"
#include "libc/x/x.h"
#include "third_party/getopt/getopt.h"
/**
* @fileoverview Asynchronous TCP/UDP Echo Server.
*
* make -j8 o/default/tool/net/echoserver.com
* o/default/tool/net/echoserver.com udp:0.0.0.0:7 tcp:0.0.0.0:7
*/
/* Role of an entry in g_sockets. */
enum SocketKind {
  kSocketServer = 0, /* socket bound from a command-line address */
  kSocketClient = 1, /* connection obtained from accept4() */
};
/**
 * One received payload waiting to be echoed back.
 */
struct Message {
  struct iovec data;       /* heap buffer (xmalloc'd by ReceiveData) + length */
  struct sockaddr_in dest; /* reply address; recvfrom()/sendto() use it for UDP */
  uint32_t destsize;       /* in/out length of dest for recvfrom()/sendto() */
};
/* Growable array of Message, grown with the arraylist append() macro. */
struct Messages {
  size_t i;          /* number of elements in use */
  size_t n;          /* capacity, presumably maintained by append() */
  struct Message *p; /* element storage */
};
/**
 * A socket tracked by the server. Entry k of g_sockets is watched by
 * entry k of g_polls.
 */
struct Socket {
  int64_t fd;              /* descriptor; -1 for servers until opened */
  enum SocketKind kind;    /* bound server or accepted client */
  int type;                /* SOCK_STREAM or SOCK_DGRAM */
  int protocol;            /* IPPROTO_TCP or IPPROTO_UDP */
  struct sockaddr_in addr; /* servers: bound local address; clients: peer */
  struct Messages egress;  /* replies queued for SendData(); freed by
                              SendData() or RemoveSocket() */
};
/* Growable array of Socket, grown with the arraylist append() macro. */
struct Sockets {
  size_t i;         /* number of sockets in use */
  size_t n;         /* capacity, presumably maintained by append() */
  struct Socket *p; /* element storage */
};
/* Growable array of pollfd handed directly to poll(). */
struct Polls {
  size_t i;          /* number of entries in use (the nfds argument) */
  size_t n;          /* capacity, presumably maintained by append() */
  struct pollfd *p;  /* element storage */
};
/* Every tracked socket; kept index-aligned with g_polls. */
struct Sockets g_sockets;
/* poll() set; g_polls.p[k].fd is g_sockets.p[k].fd. */
struct Polls g_polls;
/**
 * Renders a socket address as "IP:PORT", e.g. "127.0.0.1:7".
 *
 * @param addr is the address to render (expected to be AF_INET)
 * @return newly allocated string owned by the caller (callers wrap it in gc())
 */
nodiscard char *DescribeAddress(struct sockaddr_in *addr) {
  char text[16];
  const char *ip =
      inet_ntop(addr->sin_family, &addr->sin_addr.s_addr, text, sizeof(text));
  uint16_t port = ntohs(addr->sin_port);
  return xasprintf("%s:%hu", ip, port);
}
/**
 * Renders a tracked socket as "SCHEME:IP:PORT", e.g. "udp:0.0.0.0:7".
 *
 * @return newly allocated string owned by the caller (callers wrap it in gc())
 */
nodiscard char *DescribeSocket(struct Socket *s) {
  const char *scheme = "tcp";
  if (s->protocol == IPPROTO_UDP) {
    scheme = "udp";
  }
  return xasprintf("%s:%s", scheme, gc(DescribeAddress(&s->addr)));
}
/**
 * Prints a one-line usage summary and terminates the process.
 *
 * @param iserror when true, writes to stderr and exits with EXIT_FAILURE;
 *     otherwise writes to stdout and exits with EXIT_SUCCESS
 */
wontreturn void ShowUsageAndExit(bool iserror) {
  FILE *f = iserror ? stderr : stdout;
  int rc = iserror ? EXIT_FAILURE : EXIT_SUCCESS;
  fprintf(f, "%s: %s %s\n", "Usage", g_argv[0], "PROTOCOL:ADDR:PORT...");
  exit(rc);
}
/**
 * Parses command-line options.
 *
 * -h prints usage and exits successfully. Any other option, or the absence
 * of positional PROTOCOL:ADDR:PORT operands, prints usage and exits with
 * failure. On return, optind indexes the first operand.
 */
void GetFlags(int argc, char *argv[]) {
  int flag;
  for (;;) {
    flag = getopt(argc, argv, "h");
    if (flag == -1) break;
    /* -h is a request for help; anything else is a usage error. */
    ShowUsageAndExit(flag != 'h');
  }
  if (optind == argc) ShowUsageAndExit(true);
}
/**
 * Registers a socket: copies it into g_sockets and appends a matching
 * POLLIN watch to g_polls, keeping the two arrays index-aligned.
 * Fails a CHECK if either append fails.
 */
void AddSocket(const struct Socket *s) {
  struct pollfd pfd = {.fd = s->fd, .events = POLLIN, .revents = 0};
  CHECK_NE(-1L, append(&g_sockets, s));
  CHECK_NE(-1L, append(&g_polls, &pfd));
}
void RemoveSocket(size_t i) {
DCHECK_LT(i, g_sockets.i);
LOGF("removing: %s", gc(DescribeSocket(&g_sockets.p[i])));
CHECK_NE(-1, close(g_sockets.p[i].fd));
while (g_sockets.p[i].egress.i) {
free(g_sockets.p[i].egress.p[g_sockets.p[i].egress.i - 1].data.iov_base);
}
memcpy(&g_sockets.p[i], &g_sockets.p[i + 1],
(intptr_t)&g_sockets.p[g_sockets.i] - (intptr_t)&g_sockets.p[i + 1]);
memcpy(&g_polls.p[i], &g_polls.p[i + 1],
(intptr_t)&g_polls.p[g_polls.i] - (intptr_t)&g_polls.p[i + 1]);
g_sockets.i--;
g_polls.i--;
}
/**
 * Registers one unopened server socket per positional operand.
 *
 * Each operand must look like "tcp:A.B.C.D:PORT" or "udp:A.B.C.D:PORT".
 * The scheme is compared case-insensitively. Malformed operands print an
 * error plus usage and exit.
 */
void GetListeningAddressesFromCommandLine(int argc, char *argv[]) {
  int arg;
  for (arg = optind; arg < argc; ++arg) {
    struct Socket server;
    char scheme[4];
    uint16_t port;
    unsigned char *octet;
    memset(&server, 0, sizeof(server));
    /* Octets are stored in textual order, i.e. network byte order. */
    octet = (unsigned char *)&server.addr.sin_addr.s_addr;
    if (sscanf(argv[arg], "%3s:%hhu.%hhu.%hhu.%hhu:%hu", scheme, octet + 0,
               octet + 1, octet + 2, octet + 3, &port) != 6) {
      fprintf(stderr, "error: bad ip4 uri\n");
      ShowUsageAndExit(true);
    }
    if (!strcasecmp(scheme, "tcp")) {
      server.type = SOCK_STREAM;
      server.protocol = IPPROTO_TCP;
    } else if (!strcasecmp(scheme, "udp")) {
      server.type = SOCK_DGRAM;
      server.protocol = IPPROTO_UDP;
    } else {
      fprintf(stderr, "%s: %s\n", "error", "bad scheme (should be tcp or udp)");
      ShowUsageAndExit(true);
    }
    server.fd = -1; /* opened later by BeginListeningForIncomingTraffic() */
    server.kind = kSocketServer;
    server.addr.sin_family = AF_INET;
    server.addr.sin_port = htons(port);
    AddSocket(&server);
  }
}
/**
 * Opens every registered server socket.
 *
 * Each socket is created non-blocking, bound, and for TCP put into the
 * listening state. Afterwards its addr is refreshed with getsockname() so
 * the log reports the address actually bound. Any failure fails a CHECK.
 */
void BeginListeningForIncomingTraffic(void) {
  size_t i = 0;
  while (i < g_sockets.i) {
    struct Socket *s = &g_sockets.p[i];
    uint32_t addrsize;
    int yes = 1;
    /* Create the socket and mirror its descriptor into the parallel poll set. */
    CHECK_NE(-1L,
             (g_polls.p[i].fd = s->fd = socket(
                  s->addr.sin_family, s->type | SOCK_NONBLOCK, s->protocol)));
    CHECK_NE(-1L,
             setsockopt(s->fd, SOL_SOCKET, SO_REUSEPORT, &yes, sizeof(yes)));
    CHECK_NE(-1, bind(s->fd, &s->addr, sizeof(s->addr)));
    if (s->protocol == IPPROTO_TCP) {
      CHECK_NE(-1, listen(s->fd, 1));
    }
    /* Re-read the bound address, e.g. to learn the real port when 0 was given. */
    addrsize = sizeof(s->addr);
    CHECK_NE(-1, getsockname(s->fd, &s->addr, &addrsize));
    LOGF("listening on %s", gc(DescribeSocket(s)));
    ++i;
  }
}
/**
 * Accepts one pending connection on the listening TCP socket at index i and
 * registers it as a non-blocking client.
 *
 * A POLLIN wakeup does not guarantee that accept4() succeeds:
 *   - the peer may have reset the connection meanwhile (ECONNABORTED);
 *   - the wakeup may be spurious (EAGAIN);
 *   - a signal may interrupt the call (EINTR).
 * Those cases are ignored rather than crashing the server. Any other
 * failure still fails a CHECK.
 */
void AcceptConnection(size_t i) {
  struct Socket *server = &g_sockets.p[i];
  struct Socket client;
  uint32_t addrsize;
  memset(&client, 0, sizeof(client));
  client.kind = kSocketClient;
  client.type = server->type;
  client.protocol = server->protocol;
  addrsize = sizeof(client.addr);
  client.fd = accept4(server->fd, &client.addr, &addrsize, SOCK_NONBLOCK);
  if (client.fd == -1) {
    CHECK(errno == EAGAIN || errno == EINTR || errno == ECONNABORTED);
    return;
  }
  LOGF("%s accepted %s", gc(DescribeSocket(server)),
       gc(DescribeSocket(&client)));
  /* AddSocket() may reallocate g_sockets, so `server` is not used after it. */
  AddSocket(&client);
}
bool ReceiveData(size_t i) {
ssize_t got;
struct Message msg;
bool isudp = g_sockets.p[i].protocol == IPPROTO_UDP;
memset(&msg, 0, sizeof(msg));
msg.destsize = sizeof(msg.dest);
msg.data.iov_len = PAGESIZE;
msg.data.iov_base = xmalloc(msg.data.iov_len);
CHECK_NE(-1L, (got = recvfrom(g_sockets.p[i].fd, msg.data.iov_base,
msg.data.iov_len, 0, isudp ? &msg.dest : NULL,
isudp ? &msg.destsize : NULL)));
if (0 < got && got <= msg.data.iov_len) {
LOGF("%s received %lu bytes from %s", gc(DescribeSocket(&g_sockets.p[i])),
got, gc(DescribeAddress(&msg.dest)));
msg.data.iov_base = xrealloc(msg.data.iov_base, (msg.data.iov_len = got));
append(&g_sockets.p[i].egress, &msg);
g_polls.p[i].events |= POLLOUT;
return true;
} else {
RemoveSocket(i);
free_s(&msg.data.iov_base);
return false;
}
}
void SendData(size_t i) {
ssize_t sent;
struct Socket *s = &g_sockets.p[i];
struct Message *msg = &s->egress.p[s->egress.i - 1];
bool isudp = s->protocol == IPPROTO_UDP;
DCHECK(s->egress.i);
CHECK_NE(-1L, (sent = sendto(s->fd, msg->data.iov_base, msg->data.iov_len, 0,
isudp ? &msg->dest : NULL,
isudp ? msg->destsize : 0)));
LOGF("%s sent %lu bytes to %s", gc(DescribeSocket(s)), msg->data.iov_len,
gc(DescribeAddress(&msg->dest)));
if (!(msg->data.iov_len -= min((size_t)sent, (size_t)msg->data.iov_len))) {
free_s(&msg->data.iov_base);
if (!--s->egress.i) {
g_polls.p[i].events &= ~POLLOUT;
}
}
}
/**
 * Blocks in poll() until some socket is ready, then services each ready
 * socket once. Depending on its state, a socket may:
 *   - accept a TCP connection;
 *   - receive data to echo;
 *   - flush a queued reply;
 *   - be removed, for a client that reported an error or hang-up.
 *
 * RemoveSocket() shifts later entries (including their revents) into the
 * removed slot. The index therefore only advances when the current socket
 * survived; otherwise the socket that moved into slot i would be skipped.
 * A poll() interrupted by a signal simply returns to the caller's loop.
 */
void HandleSomeNetworkTraffic(void) {
  size_t i, before;
  int eventcount;
  int revents;
  eventcount = poll(g_polls.p, g_polls.i, -1);
  if (eventcount == -1 && errno == EINTR) return;
  CHECK_GE(eventcount, 0);
  for (i = 0; eventcount > 0 && i < g_sockets.i;) {
    revents = g_polls.p[i].revents;
    if (!revents) {
      ++i;
      continue;
    }
    --eventcount;
    before = g_sockets.i;
    if (revents & (POLLERR | POLLHUP | POLLNVAL)) {
      CHECK_EQ(kSocketClient, g_sockets.p[i].kind);
      RemoveSocket(i);
    } else {
      if (revents & POLLIN) {
        if (g_sockets.p[i].kind == kSocketServer &&
            g_sockets.p[i].protocol == IPPROTO_TCP) {
          AcceptConnection(i);
        } else if (!ReceiveData(i)) {
          continue; /* removed: slot i now holds the next socket */
        }
      }
      if (revents & POLLOUT) {
        SendData(i);
      }
    }
    /* At most one removal can happen per iteration; accepting only grows. */
    if (g_sockets.i >= before) ++i;
  }
}
/**
 * Services network traffic forever. main() runs this under
 * interruptiblecall() so the loop can be interrupted for shutdown.
 */
void EchoServer(void) {
  while (true) {
    HandleSomeNetworkTraffic();
  }
}
/**
 * Entry point. Parses arguments, binds every requested address, and serves
 * until the interruptible call returns. It then closes every socket and
 * exits with status 0.
 */
int main(int argc, char *argv[]) {
  STATIC_YOINK("isatty");
  struct InterruptibleCall icall;
  GetFlags(argc, argv);
  GetListeningAddressesFromCommandLine(argc, argv);
  BeginListeningForIncomingTraffic();
  memset(&icall, 0, sizeof(icall));
  interruptiblecall(&icall, (void *)EchoServer, 0, 0, 0, 0);
  /* Presumably returns the cursor to column 0 so the shutdown message
     overwrites a terminal-echoed ^C. */
  fputc('\r', stderr);
  LOGF("%s", "shutting down...");
  /* Always remove the last entry so no elements need shifting. */
  while (g_sockets.i) {
    RemoveSocket(g_sockets.i - 1);
  }
  return 0;
}