Home 实现一个内核态的 web server
Post
Cancel

实现一个内核态的 web server

前言

首先强调,这个轮子没有任何现实意义。Linux 内核态实现 web 服务器早在 1999 年就有人尝试过(并且毫无疑问失败了)。这里只是简单对比用户态和内核态的性能差距,或者当作是咬打火机[1]也行。

本文的实现以及测试均基于 Linux 内核版本 6.4.8,还有小部分内容没完善,不妨先随便看看。

[1] 历史上的咬打火机技术选型:In-kernel web server – Wikipedia

进展

(已更新)目前只提供单 kthread 的 HTTP server,以 kernel module 的形式使用。

目前可提供多 kthread、内核态 epoll 实现的 HTTP server,以 kernel module 的形式使用。但是必须要修改内核,见下方描述。


我本意是想要用 epoll 做用户态与内核态跑分对比,但是适配过程发现 epoll 实现根本就没打算留给内核态自己用,目前需要做的工作是:

  • [DONE] export symbol。修改内核,把 do_ 前缀的 create / ctl / wait 函数导出即可。
  • [DONE] 为 struct socket 提供 file。内核默认并不关联文件,可用 sock_alloc_file() 解决。
  • [DONE] 为 struct socket 关联 struct fd fd / int fd。epoll 不仅接口不提供 file 支持,还要通过 fd 和 fdtable 绕一层(内核态里面用 fd 是否有点……),基本上就是只给用户态使用的意思。正在考虑是想办法开洞绕过 fd,还是在内核上做改动(进一步导出更多符号?)[2]
  • [DONE] 处理销毁问题。关闭文件 private 实例、文件和 fd 使用 close_fd() 是最简单的做法。
  • [DONE] 处理权限问题。epoll_wait 假定了 events 必处于用户空间,因此需要注释掉这两行[3]
  • [DONE] 多线程测试适配。直接使用 REUSEPORT 进行 kthread 隔离就好了。
  • [DONE] 用户态版本适配。完成相对容易,只需用常见的 C 库/系统调用替换既有实现即可。
  • [TODO] 性能对比。目前做了第三方库以及用户态版本的简单对比,但还不够,需要写一个脚本。
  • [TODO] profile。光靠对比其实没啥说服力,但是前面还存在问题待办。

[2] 虽然 sock_map_fd() 非公开函数,但可以尝试 fd = get_unused_fd_flags(...)、file = sock_alloc_file(...)、fd_install(fd, file) 这一套组合作为替代选项。
[3] 这种修改是对内核安全有破坏的(丢失了 EFAULT 判断),只适用于本次测试。另据网络流言(不保真),Linux 4.x 时代是允许 kthread 访问的,到了 5.1 虽被禁止但仍可配置 CONFIG 解决,再后续的版本就不再允许。

内核态实现

makefile 部分

# Kbuild: build server_kernel.c into the loadable module server_kernel.ko.
obj-m += server_kernel.o

# Build tree of the currently running kernel.
KERNEL_SOURCE := /lib/modules/$(shell uname -r)/build
PWD := $(shell pwd)

# Plain `make` builds only the kernel module; `make all` also builds the
# userspace server.
default: kernel

all: kernel user

# Delegate to kbuild, pointing it back at this directory (M=).
kernel:
	$(MAKE) -C $(KERNEL_SOURCE) M=$(PWD) modules

# Userspace server. The flags below follow the common kernel build flags so
# the userspace and kernel versions are compiled with comparable settings.
user:
	$(CC) server_user.c -O2 -pthread -std=gnu11 -o server_user \
	-fno-strict-aliasing -fno-common -fshort-wchar -funsigned-char \
	-Wundef -Werror=strict-prototypes -Wno-trigraphs \
	-Werror=implicit-function-declaration -Werror=implicit-int \
	-Werror=return-type -Wno-format-security

# Remove kbuild artifacts and the userspace binary.
clean:
	$(MAKE) -C $(KERNEL_SOURCE) M=$(PWD) clean
	rm -f server_user

不管是当前文件,还是此前的历史版本,都使用同一 Makefile 构建。

module 部分

kernel module 是 GNU C 实现,代码如下:

#include <linux/module.h>
#include <linux/kernel.h>
#include <linux/net.h>
#include <linux/in.h>
#include <linux/socket.h>
#include <net/sock.h>
#include <linux/kthread.h>
#include <linux/eventpoll.h>
#include <linux/fdtable.h>
#include <linux/slab.h>
// Modified epoll header
#include <linux/fs.h>

#define SERVER_PORT 8848
#define CONTEXT_MAX_THREAD 32
/*
 * NO_FAIL(reason, err_str, finally): if the caller's local `ret` is negative,
 * store `reason` into `err_str` and jump to the label `finally`.
 * Wrapped in do { } while (0) so that `NO_FAIL(...);` is exactly one
 * statement and stays correct inside an unbraced if/else.
 */
#define NO_FAIL(reason, err_str, finally) \
  do { \
    if(unlikely(ret < 0)) { (err_str) = (reason); goto finally; } \
  } while (0)

// One slot per worker kthread (CONTEXT_MAX_THREAD slots).
// Static storage starts zeroed; each_server_exit() treats a NULL `thread`
// as "no worker in this slot".
static struct thread_context {
    // This worker's listening socket (bound with SO_REUSEPORT).
    struct socket *server_socket;
    // The worker kthread running server_thread().
    struct task_struct *thread;
} thread_contexts[CONTEXT_MAX_THREAD];

// Per-connection state, indexed by the connection's fd (sockets[fd]).
// Note that socket_context is allocated by each server instance.
// Thus it has no false-sharing problem.
struct socket_context {
    // Connection socket; NULL marks an unused slot.
    struct socket *socket;
    // Running count of '\r' and '\n' bytes seen for the current request.
    // `wrk` must send three \r\n (six such bytes) per request by default.
    size_t _r_n;
    // Pending responses, consumed on EPOLLOUT.
    size_t responses;
};

// Number of worker kthreads, chosen at insmod time (e.g. num_threads=8).
// Read-only once loaded (0444): module init and exit both read it, so a
// runtime change through sysfs would make them disagree about how many
// workers exist.
static int num_threads = 1;
module_param(num_threads, int, 0444);
MODULE_PARM_DESC(num_threads, "Number of threads");

/*
 * Create one worker's listening socket: TCP/IPv4 with SO_REUSEADDR and
 * SO_REUSEPORT (so every worker can bind its own socket to SERVER_PORT),
 * bound to INADDR_ANY:SERVER_PORT, backlog 1024.
 *
 * Returns the listening socket, or NULL on failure (the failing step and its
 * error code are logged).
 */
static struct socket* create_server_socket(void) {
    int ret;
    const char *err_msg;
    struct sockaddr_in server_addr;
    int optval = 1;
    sockptr_t koptval = KERNEL_SOCKPTR(&optval);
    // Must start out NULL: the error path tests it, and nothing guarantees it
    // is written when sock_create() itself fails.
    struct socket *server_socket = NULL;

    ret = sock_create(PF_INET, SOCK_STREAM, IPPROTO_TCP, &server_socket);
    NO_FAIL("Failed to create socket", err_msg, done);

    memset(&server_addr, 0, sizeof(server_addr));
    server_addr.sin_family = AF_INET;
    server_addr.sin_addr.s_addr = htonl(INADDR_ANY);
    server_addr.sin_port = htons(SERVER_PORT);

    ret = sock_setsockopt(server_socket, SOL_SOCKET, SO_REUSEADDR, koptval, sizeof(optval));
    NO_FAIL("sock_setsockopt(SO_REUSEADDR)", err_msg, done);
    ret = sock_setsockopt(server_socket, SOL_SOCKET, SO_REUSEPORT, koptval, sizeof(optval));
    NO_FAIL("sock_setsockopt(SO_REUSEPORT)", err_msg, done);

    ret = kernel_bind(server_socket, (struct sockaddr *)&server_addr, sizeof(server_addr));
    NO_FAIL("kernel_bind", err_msg, done);

    ret = kernel_listen(server_socket, 1024);
    NO_FAIL("kernel_listen", err_msg, done);

    return server_socket;

done:
    pr_err("%s: %d\n", err_msg, ret);
    if(server_socket) sock_release(server_socket);
    return NULL;
}


/*
 * Wrap `sock` in a struct file and install it under a fresh fd in the
 * current fd table, so it can be used with the fd-based epoll API.
 *
 * Returns the new fd (>= 0) or a negative errno.
 * NOTE(review): sock_alloc_file() appears to release `sock` itself on
 * failure, and fput() of the new file releases the socket as well, so after
 * any error `sock` should be treated as gone — confirm for your kernel.
 */
static int make_fd_and_file(struct socket *sock) {
    struct file *file;
    int fd;

    file = sock_alloc_file(sock, 0 /* NONBLOCK? */, NULL);
    if(unlikely(IS_ERR(file))) return (int)PTR_ERR(file);
    fd = get_unused_fd_flags(0);
    if(unlikely(fd < 0)) {
        // Drop the only reference to the new file.
        fput(file);
        return fd;
    }
    fd_install(fd, file);
    return fd;
}


/*
 * Add, modify or delete an epoll registration on `epfd`.
 *
 * EPOLL_CTL_ADD: `fd` is ignored; a new fd is created for `sock` with
 *   make_fd_and_file() and, once registered, recorded in context[fd].
 * EPOLL_CTL_MOD / EPOLL_CTL_DEL: `fd` must already be registered.
 * The epoll user data is the fd itself; `ep_event_flag` is the interest mask.
 *
 * Returns the fd on success or a negative error. If an ADD fails after the
 * fd was created, that fd is closed again instead of being leaked.
 * NOTE(review): context[] is indexed by fd without a bounds check; the
 * caller (server_thread) allocates 1024 entries.
 */
static int update_event(int epfd, int ep_ctl_flag, __poll_t ep_event_flag, /// Epoll.
                        struct socket_context context[], int fd, struct socket *sock) { /// Sockets.
    int ret;
    struct epoll_event event;
    bool fd_is_ready = (ep_ctl_flag != EPOLL_CTL_ADD);

    if(!fd_is_ready) fd = make_fd_and_file(sock);
    if(unlikely(fd < 0)) {
        pr_warn("fd cannot allocate: %d\n", fd);
        return fd;
    }
    event.data = fd;
    event.events = ep_event_flag;
    ret = do_epoll_ctl(epfd, ep_ctl_flag, fd, &event, false /* true for io_uring only */);
    if(unlikely(ret < 0)) {
        pr_warn("do_epoll_ctl: %d\n", ret);
        // A freshly created fd is not recorded in context[] yet, so nothing
        // else would ever close it (or the socket behind it).
        if(!fd_is_ready) close_fd(fd);
        return ret;
    }
    if(!fd_is_ready) context[fd].socket = sock;
    return fd;
}


static void dump_event(struct epoll_event *e) {
    bool epollin  = e->events & EPOLLIN;
    bool epollout = e->events & EPOLLOUT;
    bool epollhup = e->events & EPOLLHUP;
    bool epollerr = e->events & EPOLLERR;
    __u64 data = e->data;
    pr_info("dump: %d%d%d%d %llu\n", epollin, epollout, epollhup, epollerr, data);
}


/*
 * Count the complete wrk requests contained in `buffer[0..nread)`.
 *
 * A request is recognised purely by its line terminators: every six '\r' or
 * '\n' bytes (three "\r\n" pairs, wrk's default) close one request. The
 * partial count is carried across calls in context->_r_n. Each completed
 * request queues one response in context->responses.
 *
 * Returns the number of requests completed by this buffer.
 */
static int wrk_parse(struct socket_context *context, const char *buffer, int nread) {
    int terminators = context->_r_n;
    int completed = 0;

    for(int i = 0; i < nread; i++) {
        const char ch = buffer[i];
        if(ch != '\r' && ch != '\n')
            continue;
        if(++terminators == 6) {
            completed++;
            terminators = 0;
        }
    }
    context->_r_n = terminators;
    context->responses += completed;
    return completed;
}


/*
 * Handle one batch of `nevents` events returned by do_epoll_wait().
 *
 * - Event on `server_fd`: accept one connection and register it for
 *   EPOLLIN | EPOLLHUP (accept failures are logged and skipped).
 * - EPOLLIN on a client: read into `request_vec` (READ_BUFFER bytes at
 *   `read_buffer`) and count complete wrk requests with wrk_parse(). If any
 *   request completed, stop watching EPOLLIN; if responses are pending,
 *   watch EPOLLOUT.
 * - EPOLLOUT on a client: send at most MAX_RESPONSES - 1 canned responses
 *   from `response_vec` (each `content_len` bytes), then re-enable EPOLLIN
 *   and drop EPOLLOUT once nothing is pending.
 * - A 0-byte read with only EPOLLIN reported, or an I/O error other than
 *   -EINTR, marks the event EPOLLHUP; the fd is then unregistered and closed.
 * Interest changes are applied with one EPOLL_CTL_MOD, and only when the
 * mask actually changed. `msg` is reused for every recv/send call.
 */
static void event_loop(int epfd, struct epoll_event *events, const int nevents,
                       int server_fd, struct socket *server_socket, struct socket_context sockets[],
                       char *read_buffer, const size_t READ_BUFFER, struct kvec *request_vec,
                       const int content_len, struct kvec response_vec[], const int MAX_RESPONSES,
                       struct msghdr *msg) {
    int ret;

    __poll_t next_event;
    __poll_t current_event;
    int client_fd;
    struct socket_context *client_context;
    struct socket *client_socket;

    int requests;
    int responses;

    for(struct epoll_event *e = &events[0]; e != &events[nevents]; e++) {
        // dump_event(e);
        if(e->data == server_fd) {
            // On failure client_socket is not a usable socket; registering it
            // would hand an invalid pointer to sock_alloc_file().
            ret = kernel_accept(server_socket, &client_socket, 0);
            if(unlikely(ret < 0)) {
                pr_warn("kernel_accept: %d\n", ret);
                continue;
            }
            update_event(epfd, EPOLL_CTL_ADD, EPOLLIN | EPOLLHUP, sockets, -1, client_socket);
        } else {
            current_event = e->events;
            next_event = e->events;
            client_fd = e->data;
            client_context = &sockets[client_fd];
            client_socket = client_context->socket;
            if(e->events & EPOLLIN) {
                ret = kernel_recvmsg(client_socket, msg, request_vec, 1, READ_BUFFER, 0);
                // Fast check: Maybe a FIN packet and nothing is buffered (!EPOLLOUT).
                if(ret == 0 && e->events == EPOLLIN) {
                    e->events = EPOLLHUP;
                // May be an RST packet.
                } else if(unlikely(ret < 0)) {
                    if(ret != -EINTR) e->events = EPOLLHUP;
                // Slower path, may call (do_)epoll_ctl().
                } else {
                    requests = wrk_parse(client_context, read_buffer, ret);
                    // Keep reading if there is no complete request.
                    // Otherwise disable EPOLLIN.
                    // FIXME. always enable? Cost more "syscall"s?
                    if(requests) next_event &= ~EPOLLIN;
                    // There are some pending responses to be send.
                    if(client_context->responses) next_event |= EPOLLOUT;
                }
            }
            if(e->events & EPOLLOUT) {
                BUG_ON(client_context->responses == 0);
                responses = client_context->responses;
                if(responses >= MAX_RESPONSES) {
                    responses = MAX_RESPONSES - 1;
                }
                // >= 0
                client_context->responses -= responses;

                // <del>Short write?</del>
                // No short write in blocking mode. See UNP book section 3.9 for more details.
                ret = kernel_sendmsg(client_socket, msg, &response_vec[0],
                        responses, content_len * responses);
                if(ret < 0) {
                    pr_warn("kernel_sendmsg: %d\n", ret);
                    if(ret != -EINTR) e->events = EPOLLHUP;
                } else {
                    if(!client_context->responses) next_event &= ~EPOLLOUT;
                    next_event |= EPOLLIN;
                }
            }
            if((e->events & EPOLLHUP) && !(e->events & EPOLLIN)) {
                ret = update_event(epfd, EPOLL_CTL_DEL, 0, sockets, client_fd, client_socket);
                if(unlikely(ret < 0)) pr_warn("update_event[HUP]: %d\n", ret);
                close_fd(client_fd);
                memset(client_context, 0, sizeof (struct socket_context));
            }
            // Not necessary to compare the current event,
            // but avoid duplicate "syscall".
            if(e->events != EPOLLHUP && current_event != next_event) {
                ret = update_event(epfd, EPOLL_CTL_MOD, next_event,
                                    sockets, client_fd, client_socket);
                if(unlikely(ret < 0)) pr_warn("update_event[~HUP]: %d\n", ret);
            }
        }
    }
}

static int server_thread(void *data) {
    /// Control flows.

    int ret;
    const char *err_msg;
    struct thread_context *context = data;

    /// Sockets.

    int server_fd;
    struct socket *server_socket = context->server_socket;
    // Limited by fd size. 1024 is enough for test.
    // Usage: sockets[fd].socket = socket_ptr.
    const size_t SOCKETS = 1024;
    struct socket_context *sockets = NULL;

    /// Buffers.

    const size_t READ_BUFFER = 4096;
    char *read_buffer = NULL;
    char *response_content =
        "HTTP/1.1 200 OK\r\nContent-Length: 13\r\n\r\nHello, world!";
    const int response_content_len = strlen(response_content);
    const int MAX_RESPONSES = 32;
    struct kvec request_vec;
    struct kvec response_vec[32] = {
        [0 ... 31] = {
            .iov_base = response_content,
            .iov_len = response_content_len,
        }
    };
    struct msghdr msg;


    /// Epoll.

    int epfd = -1;
    const size_t EVENTS = 1024;
    int nevents;
    struct epoll_event *events = NULL;

    memset(&msg, 0, sizeof msg);
    sockets = kmalloc_array(SOCKETS, sizeof(struct socket_context), GFP_KERNEL | __GFP_ZERO);
    events = kmalloc_array(EVENTS, sizeof(struct epoll_event), GFP_KERNEL);
    read_buffer = kmalloc(READ_BUFFER, GFP_KERNEL);
    ret = (sockets && events && read_buffer) ? 0 : -ENOMEM;
    NO_FAIL("kmalloc[s|e|d]", err_msg, done);
    request_vec.iov_base = read_buffer;
    request_vec.iov_len = READ_BUFFER;

    /////////////////////////////////////////////////////////////////////////////////

    // Debug only.
    (void)dump_event;

    allow_signal(SIGKILL);
    allow_signal(SIGTERM);

    ret = do_epoll_create(0);
    NO_FAIL("do_epoll_create", err_msg, done);
    epfd = ret;

    ret = update_event(epfd, EPOLL_CTL_ADD, EPOLLIN, sockets, -1, server_socket);
    NO_FAIL("update_event", err_msg, done);
    server_fd = ret;

    while(!kthread_should_stop()) {
        ret = do_epoll_wait(epfd, events, EVENTS, NULL /* INF ms */);
        NO_FAIL("do_epoll_wait", err_msg, done);
        nevents = ret;
        event_loop(epfd, events, nevents, // Epoll
                   server_fd, server_socket, sockets, // Socket
                   read_buffer, READ_BUFFER, &request_vec, // READ
                   response_content_len, response_vec, MAX_RESPONSES, // WRITE
                   &msg); // Iterator
    }

done:
    if(ret < 0) pr_err("%s: %d\n", err_msg, ret);
    if(~epfd) close_fd(epfd);
    if(events) kfree(events);
    if(read_buffer) kfree(read_buffer);
    // Server is included.
    if(sockets) {
        for(int i = 0; i < SOCKETS; i++) {
            if(sockets[i].socket) close_fd(i);
        }
        kfree(sockets);
    }
    context->thread = NULL;
    return ret;
}

static int each_server_init(struct thread_context *context) {
    context->server_socket = create_server_socket();
    if(!context->server_socket) {
        return -1;
    }

    context->thread = kthread_run(server_thread, context, "in_kernel_web_server");

    if(IS_ERR(context->thread)) {
        pr_err("Failed to create thread\n");
        return PTR_ERR(context->thread);
    }

    pr_info("worker thread id: %d\n", context->thread->pid);
    return 0;
}

/*
 * Stop the worker in `context`, if one was started (non-NULL thread).
 * SIGTERM is sent first, intended to interrupt its blocking epoll wait;
 * kthread_stop() then waits for the thread to finish.
 */
static void each_server_exit(struct thread_context *context) {
    struct task_struct *worker = context->thread;

    if(!worker)
        return;
    send_sig(SIGTERM, worker, 1);
    kthread_stop(worker);
}


/*
 * Module load: start `num_threads` workers (1 .. CONTEXT_MAX_THREAD - 1).
 *
 * Returns 0 on success, -EINVAL for an out-of-range num_threads, or the error
 * of the first worker that failed to start. On failure, every worker already
 * started is stopped again, newest first.
 */
static int __init simple_web_server_init(void) {
    int threads = num_threads;
    int ret;

    if(threads >= CONTEXT_MAX_THREAD || threads < 1) {
        pr_err("num_threads < (CONTEXT_MAX_THREAD=32)\n");
        return -EINVAL;
    }
    for(int i = 0; i < threads; ++i) {
        ret = each_server_init(&thread_contexts[i]);
        if(ret) {
            pr_err("Boot failed\n");
            // Roll back slots i-1 .. 0.
            while(i-- > 0) {
                each_server_exit(&thread_contexts[i]);
            }
            return ret;
        }
    }
    pr_info("Simple Web Server Initialized\n");
    return 0;
}


/*
 * Module unload: stop every worker that is still running.
 *
 * All CONTEXT_MAX_THREAD slots are visited instead of the first num_threads.
 * num_threads is a module parameter and may no longer hold the value used at
 * load time. Slots that were never started have a NULL thread and are skipped
 * by each_server_exit().
 */
static void __exit simple_web_server_exit(void) {
    for(int i = 0; i < CONTEXT_MAX_THREAD; i++) {
        each_server_exit(&thread_contexts[i]);
    }
    pr_info("Simple Web Server Exited\n");
}


// Entry points run on module load (insmod) and unload (rmmod).
module_init(simple_web_server_init);
module_exit(simple_web_server_exit);

// Module metadata (as reported by modinfo).
MODULE_LICENSE("GPL");
MODULE_AUTHOR("Caturra");
MODULE_DESCRIPTION("Simple In-Kernel Web Server");

不看代码点这里

做点简单的说明。这份内核态 HTTP server 实现基于历史版本 3 进行改进。除了前面 kthread 和 epoll 都配备以外,还提供了多 CPU 支持,现在可以在 insmod 时指定 num_threads;以及增加一个 wrk 解析器,现在已经可以正确地进行压测。出于内核动态内存管理的复杂性,客户端加服务端 fd 总计不得超过 1024;同样的理由,线程数必须小于 32(代码检查的是 num_threads < CONTEXT_MAX_THREAD)。(当然这两个想改多大都行)

这里仍然是个简单的静态 HTTP server,浏览器输入 127.0.0.1:8848,就能返回一串 hello world。

对比数据 1

由于用户态版本还没写好,这里挑选了一个流行的 C++ 网络库 libhv 来对比。测试代码如下:

NOTE: 用户态版本已经完成,见下方用户态适配

#include "HttpServer.h"
using namespace hv;

// libhv baseline for the first comparison: answers GET / with
// "Hello, world!" on port 8080 using 8 worker threads.
int main() {
    HttpService router;
    router.GET("/", [](HttpRequest* req, HttpResponse* resp) {
        return resp->String("Hello, world!");
    });

    HttpServer server(&router);
    server.setPort(8080);
    server.setThreadNum(8);
    server.run();
    return 0;
}

测试结果均基于 gcc-13,-O2 优化级别,使用 wrk 做负载:

// libhv
caturra@LAPTOP-RU7SB7FE:~$ wrk -t12 -c400 -d10s --latency "http://127.0.0.1:8080/"
Running 10s test @ http://127.0.0.1:8080/
  12 threads and 400 connections
  Thread Stats   Avg      Stdev     Max   +/- Stdev
    Latency     2.90ms  723.13us  12.06ms   63.18%
    Req/Sec    11.45k   746.28    15.77k    74.79%
  Latency Distribution
     50%    3.02ms
     75%    3.42ms
     90%    3.74ms
     99%    4.44ms
  1378550 requests in 10.10s, 210.35MB read
Requests/sec: 136485.25
Transfer/sec:     20.83MB

// in-kernel-web-server
// insmod server_kernel.ko num_threads=8
caturra@LAPTOP-RU7SB7FE:~/in_kernel_web_server$ wrk -t12 -c400 -d10s --latency "http://127.0.0.1:8848/"
Running 10s test @ http://127.0.0.1:8848/
  12 threads and 400 connections
  Thread Stats   Avg      Stdev     Max   +/- Stdev
    Latency   583.60us    2.09ms  57.67ms   95.59%
    Req/Sec   151.83k    41.89k  214.34k    54.02%
  Latency Distribution
     50%  131.00us
     75%  257.00us
     90%    0.85ms
     99%    9.94ms
  18259600 requests in 10.10s, 0.88GB read
Requests/sec: 1807901.24
Transfer/sec:     89.66MB

内核态实现能快 10 倍以上,差距似乎有点夸张。像这种简单事务,单机 QPS 的下限应该是 100 万左右,但是这里的第三方库只有 13 万,不知道对方的实现是否有问题。不过反过来看,内核态实现的吞吐厉害了,(最大)延迟也就有点萎了。

先别急着 profile,等我把自己的代码搬回用户态实现再公平对比吧,这样可以避免其他干扰项。

用户态适配

花了点时间将用户态版本也做出来了(不看代码就跳到下一章吧):

#include <unistd.h>
#include <sys/socket.h>
#include <sys/uio.h>
#include <sys/epoll.h>
#include <sys/types.h>
#include <sys/errno.h>
#include <netinet/in.h>
#include <pthread.h>
#include <stdbool.h>
#include <stdio.h>
#include <stdlib.h>
#include <memory.h>
#include <assert.h>

// https://elixir.bootlin.com/linux/v6.4.8/source/include/linux/compiler.h#L76
#define likely(x) __builtin_expect(!!(x), 1)
#define unlikely(x) __builtin_expect(!!(x), 0)

// Kernel logging helpers mapped onto stderr so the shared code compiles as-is.
#define pr_info(...) pr_warn(__VA_ARGS__)
#define pr_warn(...) pr_err(__VA_ARGS__)
#define pr_err(...) fprintf(stderr, __VA_ARGS__)

#define SERVER_PORT 8848
#define CONTEXT_MAX_THREAD 32
/*
 * NO_FAIL(reason, err_str, finally): if the caller's local `ret` is negative,
 * store `reason` into `err_str` and jump to the label `finally`.
 * do { } while (0) makes `NO_FAIL(...);` a single statement, safe inside an
 * unbraced if/else.
 */
#define NO_FAIL(reason, err_str, finally) \
  do { \
    if(unlikely(ret < 0)) { (err_str) = (reason); goto finally; } \
  } while (0)
#define USERSPACE_UNSED

// Per-worker state: the worker's own listening fd (SO_REUSEPORT) and its thread.
static struct thread_context {
    int server_fd;
    pthread_t thread;
} thread_contexts[CONTEXT_MAX_THREAD];

// Per-connection state, indexed by fd; same layout as the kernel version.
// `_placeholder_` stands in for the socket pointer and is only used as an
// "fd registered" marker (set to non-NULL by update_event()).
struct socket_context {
    void *_placeholder_;
    // Running count of '\r' and '\n' bytes seen for the current wrk request.
    size_t _r_n;
    // Parsed requests still waiting for a response.
    size_t responses;
};

// 1st argument: number of worker threads (validated in main()).
static int num_threads = 1;
// 2nd argument. Set any non-zero number to enable zerocopy feature;
// main() then stores MSG_ZEROCOPY here and it is passed to sendmsg().
// https://www.kernel.org/doc/html/v6.4/networking/msg_zerocopy.html
static int zerocopy_flag = 0;

/*
 * Open one worker's listening socket: TCP/IPv4 with SO_REUSEADDR and
 * SO_REUSEPORT, bound to INADDR_ANY:SERVER_PORT, backlog 1024.
 *
 * Returns the listening fd, or the failing call's negative result after
 * logging the step, the result and errno.
 */
static int create_server_socket(void) {
    const char *step;
    int rc;
    int sock = -1;
    int enable = 1;
    struct sockaddr_in addr;

    rc = socket(PF_INET, SOCK_STREAM, IPPROTO_TCP);
    if(unlikely(rc < 0)) { step = "socket"; goto fail; }
    sock = rc;

    memset(&addr, 0, sizeof addr);
    addr.sin_family = AF_INET;
    addr.sin_port = htons(SERVER_PORT);
    addr.sin_addr.s_addr = htonl(INADDR_ANY);

    rc = setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, &enable, sizeof enable);
    if(unlikely(rc < 0)) { step = "setsockopt(SO_REUSEADDR)"; goto fail; }
    rc = setsockopt(sock, SOL_SOCKET, SO_REUSEPORT, &enable, sizeof enable);
    if(unlikely(rc < 0)) { step = "setsockopt(SO_REUSEPORT)"; goto fail; }

    rc = bind(sock, (struct sockaddr *)&addr, sizeof addr);
    if(unlikely(rc < 0)) { step = "bind"; goto fail; }

    rc = listen(sock, 1024);
    if(unlikely(rc < 0)) { step = "listen"; goto fail; }

    return sock;

fail:
    pr_err("%s: %d, %d\n", step, rc, errno);
    if(sock != -1) close(sock);
    return rc;
}


// Userspace stand-in for the kernel helper of the same name: a socket from
// socket()/accept() already is an fd, so it is passed through unchanged
// (negative values too, so callers still see accept() failures).
static int make_fd_and_file(int fd) { return fd; }


/*
 * Add, modify or delete an epoll registration of `fd` on `epfd`.
 *
 * EPOLL_CTL_ADD: `fd` is taken over; once registered, context[fd] is marked
 *   as in use. If the registration fails, `fd` is closed (it would otherwise
 *   leak, since nothing records it).
 * EPOLL_CTL_MOD / EPOLL_CTL_DEL: `fd` must already be registered.
 * The epoll user data is the fd itself; `ep_event_flag` is the interest mask.
 *
 * Returns the fd on success, or a negative value on failure with errno
 * preserved from the failing call.
 * NOTE(review): context[] is indexed by fd without a bounds check; the
 * caller (server_thread) allocates 1024 entries.
 */
static int update_event(int epfd, int ep_ctl_flag, uint32_t ep_event_flag, /// Epoll.
                        struct socket_context context[], int fd) { /// Sockets.
    int ret;
    struct epoll_event event;
    bool fd_is_ready = (ep_ctl_flag != EPOLL_CTL_ADD);

    if(!fd_is_ready) fd = make_fd_and_file(fd);
    if(unlikely(fd < 0)) {
        pr_warn("fd cannot allocate: %d\n", fd);
        return fd;
    }
    event.data.fd = fd;
    event.events = ep_event_flag;
    ret = epoll_ctl(epfd, ep_ctl_flag, fd, &event);
    if(unlikely(ret < 0)) {
        int saved_errno = errno;
        pr_warn("epoll_ctl: %d, %d\n", ret, saved_errno);
        if(!fd_is_ready) close(fd);
        // Callers log errno after a failure; keep epoll_ctl()'s value.
        errno = saved_errno;
        return ret;
    }
    if(!fd_is_ready) context[fd]._placeholder_ = (void*)1;
    return fd;
}


/*
 * Debug helper: print one epoll event as four 0/1 digits for
 * IN, OUT, HUP, ERR, followed by the fd stored in its user data.
 */
static void dump_event(struct epoll_event *e) {
    const uint32_t mask = e->events;

    pr_info("dump: %d%d%d%d %d\n",
            !!(mask & EPOLLIN), !!(mask & EPOLLOUT),
            !!(mask & EPOLLHUP), !!(mask & EPOLLERR),
            e->data.fd);
}


/*
 * Count the complete wrk requests in `buffer[0..nread)`.
 *
 * Requests are recognised only by their terminators: every six '\r' or '\n'
 * bytes (three "\r\n" pairs, wrk's default) close one request. The partial
 * count persists across calls in context->_r_n, and each completed request
 * queues one response in context->responses.
 *
 * Returns the number of requests completed by this buffer.
 */
static int wrk_parse(struct socket_context *context, const char *buffer, int nread) {
    int terminators = context->_r_n;
    int completed = 0;

    for(int i = 0; i < nread; i++) {
        const char ch = buffer[i];
        if(ch != '\r' && ch != '\n')
            continue;
        if(++terminators == 6) {
            completed++;
            terminators = 0;
        }
    }
    context->_r_n = terminators;
    context->responses += completed;
    return completed;
}


/*
 * Handle one batch of `nevents` events returned by epoll_wait().
 *
 * - Event on `server_fd`: accept() one connection and register it for
 *   EPOLLIN | EPOLLHUP (an accept() failure reaches update_event() as -1 and
 *   is only logged there).
 * - EPOLLIN on a client: recvmsg() into `read_msg` and count complete wrk
 *   requests with wrk_parse(). If any request completed, stop watching
 *   EPOLLIN; if responses are pending, watch EPOLLOUT.
 * - EPOLLOUT on a client: sendmsg() at most MAX_RESPONSES - 1 canned
 *   responses through `write_msg` (its msg_iovlen is set here), then
 *   re-enable EPOLLIN and drop EPOLLOUT once nothing is pending.
 * - A 0-byte read with only EPOLLIN reported, or an error other than EINTR,
 *   marks the event EPOLLHUP; the fd is then unregistered and closed.
 * Interest changes are applied with one EPOLL_CTL_MOD, only when the mask
 * actually changed.
 *
 * NOTE(review): if EPOLLIN is reported together with EPOLLHUP (or EPOLLERR)
 * and recvmsg() returns 0, no branch closes the fd and the level-triggered
 * event keeps firing — confirm this cannot spin.
 * NOTE(review): sockets[] is indexed by fd without a bounds check.
 */
static void event_loop(int epfd, struct epoll_event *events, const int nevents,
                       int server_fd, struct socket_context sockets[],
                       struct msghdr *read_msg, struct msghdr * write_msg,
                       const int MAX_RESPONSES) {
    int ret;

    uint32_t next_event;
    uint32_t current_event;
    int client_fd;
    struct socket_context *client_context;
    const char *read_buffer;

    int requests;
    int responses;

    for(struct epoll_event *e = &events[0]; e != &events[nevents]; e++) {
        if(e->data.fd == server_fd) {
            client_fd = accept(server_fd, NULL, NULL);
            update_event(epfd, EPOLL_CTL_ADD, EPOLLIN | EPOLLHUP, sockets, client_fd);
        } else {
            current_event = e->events;
            next_event = e->events;
            client_fd = e->data.fd;
            client_context = &sockets[client_fd];
            if(e->events & EPOLLIN) {
                ret = recvmsg(client_fd, read_msg, 0);
                // Fast check: Maybe a FIN packet and nothing is buffered (!EPOLLOUT).
                if(ret == 0 && e->events == EPOLLIN) {
                    e->events = EPOLLHUP;
                // May be an RST packet.
                } else if(unlikely(ret < 0)) {
                    if(errno != EINTR) e->events = EPOLLHUP;
                // Slower path, may call (do_)epoll_ctl().
                } else {
                    read_buffer = read_msg->msg_iov->iov_base;
                    requests = wrk_parse(client_context, read_buffer, ret);
                    // Keep reading if there is no complete request.
                    // Otherwise disable EPOLLIN.
                    // FIXME. always enable? Cost more "syscall"s?
                    if(requests) next_event &= ~EPOLLIN;
                    // There are some pending responses to be send.
                    if(client_context->responses) next_event |= EPOLLOUT;
                }
            }
            if(e->events & EPOLLOUT) {
                assert(client_context->responses != 0);
                responses = client_context->responses;
                // Batch limit: at most MAX_RESPONSES - 1 iovecs per sendmsg().
                if(responses >= MAX_RESPONSES) {
                    responses = MAX_RESPONSES - 1;
                }
                // >= 0
                client_context->responses -= responses;
                // All iovecs point at the same response, so sending N
                // responses just means using the first N entries.
                write_msg->msg_iovlen = responses;

                // zerocopy_flag is 0 or MSG_ZEROCOPY (set in main()).
                ret = sendmsg(client_fd, write_msg, zerocopy_flag);
                if(ret < 0) {
                    // NOTE(review): the message says kernel_sendmsg but this is sendmsg().
                    pr_warn("kernel_sendmsg: %d, %d\n", ret, errno);
                    if(errno != EINTR) e->events = EPOLLHUP;
                } else {
                    if(!client_context->responses) next_event &= ~EPOLLOUT;
                    next_event |= EPOLLIN;
                }
            }
            // Close only when hang-up is the outcome and EPOLLIN is not also
            // reported for this event.
            if((e->events & EPOLLHUP) && !(e->events & EPOLLIN)) {
                ret = update_event(epfd, EPOLL_CTL_DEL, 0, sockets, client_fd);
                if(unlikely(ret < 0)) pr_warn("update_event[HUP]: %d, %d\n", ret, errno);
                close(client_fd);
                memset(client_context, 0, sizeof (struct socket_context));
            }
            // Not necessary to compare the current event,
            // but avoid duplicate syscall.
            if(e->events != EPOLLHUP && current_event != next_event) {
                ret = update_event(epfd, EPOLL_CTL_MOD, next_event,
                                    sockets, client_fd);
                if(unlikely(ret < 0)) pr_warn("update_event[~HUP]: %d, %d\n", ret, errno);
            }
        }
    }
}


/*
 * Worker thread body (pthread entry point). `data` is this worker's
 * struct thread_context, whose server_fd is already listening.
 *
 * Registers the listening fd with a private epoll instance and hands every
 * epoll_wait() batch to event_loop(). The loop ends only on an error;
 * EINTR is retried. Normal shutdown is pthread_cancel() from
 * each_server_exit(), in which case the cleanup below does not run.
 *
 * On the error path, every fd recorded in `sockets` is closed, including the
 * listening fd once it was registered. Always returns NULL.
 */
static void* server_thread(void *data) {
    /// Control flows.

    int ret;
    const char *err_msg;
    struct thread_context *context = data;

    /// Sockets.

    int server_fd = context->server_fd;
    // Limited by fd size. 1024 is enough for test.
    // Usage: sockets[fd]._placeholder_ marks a registered fd.
    const size_t SOCKETS = 1024;
    struct socket_context *sockets = NULL;

    /// Buffers.

    const size_t READ_BUFFER = 4096;
    char *read_buffer = NULL;
    char *response_content =
        "HTTP/1.1 200 OK\r\nContent-Length: 13\r\n\r\nHello, world!";
    const int response_content_len = strlen(response_content);
    const int MAX_RESPONSES = 32;
    struct iovec request_vec;
    struct iovec response_vec[32] = {
        [0 ... 31] = {
            .iov_base = response_content,
            .iov_len = response_content_len,
        }
    };
    struct msghdr read_msg = {
        .msg_iov = &request_vec,
        .msg_iovlen = 1,
    };
    struct msghdr write_msg = {
        .msg_iov = response_vec,
        // .msg_iovlen =  /* Modified in event_loop(). */
    };

    /// Epoll.

    int epfd = -1;
    const size_t EVENTS = 1024;
    int nevents;
    struct epoll_event *events = NULL;

    // calloc() zeroes the table and returns NULL on failure, which is checked
    // below before anything touches the memory.
    sockets = calloc(SOCKETS, sizeof(struct socket_context));
    events = malloc(EVENTS * sizeof(struct epoll_event));
    read_buffer = malloc(READ_BUFFER);
    ret = (sockets && events && read_buffer) ? 0 : -ENOMEM;
    NO_FAIL("kmalloc[s|e|d]", err_msg, done);
    request_vec.iov_base = read_buffer;
    request_vec.iov_len = READ_BUFFER;

    /////////////////////////////////////////////////////////////////////////////////

    // Debug only.
    (void)dump_event;

    ret = epoll_create(1);
    NO_FAIL("epoll_create", err_msg, done);
    epfd = ret;

    ret = update_event(epfd, EPOLL_CTL_ADD, EPOLLIN, sockets, server_fd);
    NO_FAIL("update_event", err_msg, done);
    server_fd = ret;

    // FIXME. NO check flag.
    while(true) {
        ret = epoll_wait(epfd, &events[0], EVENTS, -1);
        // Interrupted waits (signal handlers, or a SIGSTOP/SIGCONT cycle) are
        // not a reason to stop serving.
        if(ret < 0 && errno == EINTR) continue;
        NO_FAIL("epoll_wait", err_msg, done);
        nevents = ret;
        event_loop(epfd, events, nevents, // Epoll
                   server_fd, sockets, // Socket
                   &read_msg, &write_msg, // Iterators
                   MAX_RESPONSES);
    }

done:
    if(ret < 0) pr_err("%s: %d, %d\n", err_msg, ret, errno);
    if(epfd >= 0) close(epfd);
    free(events);
    free(read_buffer);
    // Server is included (once it was registered).
    if(sockets) {
        for(size_t i = 0; i < SOCKETS; i++) {
            if(sockets[i]._placeholder_) close(i);
        }
        free(sockets);
    }
    return NULL;
}

/*
 * Start one worker: open its listening fd, then its thread running
 * server_thread().
 *
 * Returns 0 on success and non-zero on failure: -1 if the listening socket
 * could not be created, or the negated pthread_create() error. On thread
 * creation failure the listening fd is closed and server_fd set to -1.
 */
static int each_server_init(struct thread_context *context) {
    int ret;
    context->server_fd = create_server_socket();
    // create_server_socket() reports failure with a negative value; 0 is a
    // valid descriptor.
    if(context->server_fd < 0) {
        return -1;
    }

    ret = pthread_create(&context->thread, NULL, server_thread, context);

    // pthread_create() returns a positive error number, never a negative one.
    if(ret != 0) {
        pr_err("Failed to create thread\n");
        close(context->server_fd);
        context->server_fd = -1;
        return -ret;
    }

    pr_info("worker pthread id: %lu\n", (unsigned long)context->thread);
    return 0;
}

// Stop one worker: request cancellation of its thread, then wait for it.
static void each_server_exit(struct thread_context *context) {
    const pthread_t worker = context->thread;

    pthread_cancel(worker);
    pthread_join(worker, NULL);
}


/*
 * Start `num_threads` workers (1 .. CONTEXT_MAX_THREAD - 1).
 * Returns 0 on success, -1 on a bad thread count or if any worker fails to
 * start; in the latter case the workers already started are stopped again,
 * newest first.
 */
static int simple_web_server_init(void) {
    const int threads = num_threads;

    if(threads < 1 || threads >= CONTEXT_MAX_THREAD) {
        pr_err("num_threads < (CONTEXT_MAX_THREAD=32)\n");
        return -1;
    }
    for(int started = 0; started < threads; started++) {
        if(!each_server_init(&thread_contexts[started]))
            continue;
        pr_err("Boot failed\n");
        while(started > 0)
            each_server_exit(&thread_contexts[--started]);
        return -1;
    }
    pr_info("Simple Web Server Initialized\n");
    return 0;
}


static void simple_web_server_exit(void) {
    struct thread_context *context;
    int threads = num_threads;
    for(context = &thread_contexts[0]; threads--; context++) {
        each_server_exit(context);
    }
    pr_info("Simple Web Server Exited\n");
}


/*
 * Usage: server_user [num_threads [zerocopy]]
 *   num_threads  worker count, 1 .. CONTEXT_MAX_THREAD - 1 (default 1)
 *   zerocopy     non-zero passes MSG_ZEROCOPY to sendmsg() (default 0)
 *
 * Starts the workers, serves until getchar() returns (e.g. Enter or EOF on
 * stdin), then stops them. Exits with 1 on bad arguments or failed startup.
 *
 * NOTE(review): per the kernel's MSG_ZEROCOPY documentation the flag is only
 * honoured on sockets that enabled SO_ZEROCOPY, and nothing in this program
 * sets SO_ZEROCOPY — confirm before drawing conclusions from zerocopy runs.
 */
int main(int argc, char *argv[]) {
    num_threads = argc > 1 ? atoi(argv[1]) : 1;
    zerocopy_flag = argc > 2 ? atoi(argv[2]) : 0;
    zerocopy_flag = zerocopy_flag ? MSG_ZEROCOPY : 0;
    if(num_threads < 1 || num_threads >= CONTEXT_MAX_THREAD) {
        pr_err("num_threads < (CONTEXT_MAX_THREAD=32)\n");
        return 1;
    }
    if(zerocopy_flag == MSG_ZEROCOPY) {
        pr_info("Enable MSG_ZEROCOPY.\n");
    }
    // On failure, init has already stopped the workers it started and the
    // remaining thread handles were never created, so exit must not run.
    if(simple_web_server_init() != 0) {
        return 1;
    }
    // Press any key...
    getchar();
    simple_web_server_exit();
    return 0;
}

由于代码和编译配置基本等价于内核态版本,因此我们可以进行相对公平的性能测试!

对比数据 2

// in-kernel-web-server
// insmod server_kernel.ko num_threads=8
caturra@LAPTOP-RU7SB7FE:~/in_kernel_web_server$ wrk -t12 -c400 -d10s --latency "http://127.0.0.1:8848/"
Running 10s test @ http://127.0.0.1:8848/
  12 threads and 400 connections
  Thread Stats   Avg      Stdev     Max   +/- Stdev
    Latency   583.60us    2.09ms  57.67ms   95.59%
    Req/Sec   151.83k    41.89k  214.34k    54.02%
  Latency Distribution
     50%  131.00us
     75%  257.00us
     90%    0.85ms
     99%    9.94ms
  18259600 requests in 10.10s, 0.88GB read
Requests/sec: 1807901.24
Transfer/sec:     89.66MB

// user-space-web-server
// ./server_user 8
caturra@LAPTOP-RU7SB7FE:~/in_kernel_web_server$ wrk -t12 -c400 -d10s --latency "http://127.0.0.1:8848/"
Running 10s test @ http://127.0.0.1:8848/
  12 threads and 400 connections
  Thread Stats   Avg      Stdev     Max   +/- Stdev
    Latency   567.92us    2.46ms 110.91ms   97.24%
    Req/Sec   120.37k    31.57k  207.68k    66.75%
  Latency Distribution
     50%  211.00us
     75%  337.00us
     90%    0.90ms
     99%    8.08ms
  14282576 requests in 10.05s, 708.29MB read
Requests/sec: 1421451.12
Transfer/sec:     70.49MB

目前测试有点简单,但可以看出点东西。吞吐量仅论小报文 QPS 的话,在 CPU 可能打满[5]且单事务(因为足够简单从而使得)syscall 占比较高的情况下,内核态确实有优势;但是一个问题是报文体量太小了,这里看不出字节意义上的吞吐量;另外一个问题是最大延迟的抖动较大,这里数据单一看不出来,多组数据有±50% 的差距。还需要做个脚本来动态调整负载(用图说话),以及进一步 profile[6]

这种场景的初步结论是:内核态的吞吐性能约为用户态的 1.27 倍[7]

[5] 是否达到 CPU 瓶颈要用 profile 确认,但是……
[6] 但是目前环境不能定位到 kthread(为什么 (#゚д゚)?),因此 perf 抓不到,待解决。
[7] 用户态可设置非零整数以使用零拷贝特性做对比,但是在这种高频小报文测试中没有收益。

Copy avoidance is not a free lunch. As implemented, with page pinning, it replaces per byte copy cost with page accounting and completion notification overhead. As a result, MSG_ZEROCOPY is generally only effective at writes over around 10 KB.

MSG_ZEROCOPY – Kernel subsystem documentation (Networking)

简单剖析

in-kernel-web-server-user-flamegraph.svg 用户态 server 的火焰图

先看用户态的剖析凑合一下(唉内核态 kthread 那边没想到好办法处理)。上面是 svg 格式的火焰图,右键打开大图可以进行交互(放大、搜索、查看百分比耗时)。此前可以通过 top 判定 CPU 早就拉满了,因此直接 on-CPU 分析:

  1. 耗时分布。epoll_ctl 占用 12.76% 总时间,epoll_wait 占用 5.1%,recvmsg 占用 12.33%,sendmsg 占用 67%。前面这些加起来约为 97%,其它的就是额外开销。
  2. sys 瓶颈还是 user 瓶颈?很显然这里 server_thread 已经被系统调用挤满了,是 sys 瓶颈。
  3. sendmsg 的耗时大头就是最上方的平顶峰 _raw_spin_unlock_irqrestore,来自 34.7% 的软中断,占了 10% 的总耗时,但这些是用户态和内核态共有的 tcp_sendmsg 流程(内核态 server 那边是 kernel_sendmsg->sock_sendmsg->(sock.ops.sendmsg=tcp_sendmsg)),不能说明问题。
  4. 陷入内核态的开销。这一块看着就头大。 epoll_ctl 的开销是 12.76% - 7.75% = 5.01%(__GI_epoll_ctl - do_epoll_ctl),recvmsg 的开销是 12.33% - 7.22% = 5.11%(__libc_recvmsg - tcp_recvmsg),sendmsg 的开销是 67.04% - 61.16% = 5.88%(__libc_sendmsg - tcp_sendmsg),其它小头没看,前面几个总计开销 16%。
  5. 用户态的 copy 开销。在火焰图右上角搜索 copy 的字眼:epoll_ctl 占用了 1.16% 的 event 拷贝耗时,recvmsg 和 sendmsg 各占用了 1.4% 的 msghdr 拷贝耗时,合计约占 4% 的总时间。

基本上开销的差距是稳定在 16% 以上(copy 开销是包含在陷入开销范围内),有时间的话找找剩下的差距在哪里。


UPDATE. 过了几天,突然想到虽然 perf 离奇的抓不到 kthread,但是可以考虑 eBPF 啊。面对一个 tid 失踪的强劲问题,bpftrace 搭配 filter 能否搞定?绝对可以。轻易可以。吔![8]

// 这里 tid 要看 dmesg 输出的日志手动修改,不然真找不到了。怀疑不仅是我的环境 CONFIG 有问题,perf 构建也可能有问题
$ bpftrace -e 'profile:hz:9999 / tid>=287 && tid<=294 / { @[kstack] = count(); }' > trace.data

[8] 话虽如此,bpftrace 采集的栈信息是携带偏移量(且可视化脚本不移除)的,这块需要自己处理。

in-kernel-web-server-kernel-flamegraph.svg 内核态 server 的火焰图

粗看其实和用户态差不多,山峰被压扁就是因为去掉了系统调用的陷入开销;由于内核态 epoll 实现上保留了多余的 fd 中间层,至少还有 2.2% 的不必要开销(搜索 fget 字眼)。

其它的细节就不继续分析了,剩下的吞吐量测试脚本也没啥安排,先放着。火焰图贴出来供感兴趣的看官们参考。(可交互大图链接)


历史版本

这里已经不是正文了,用于记录基本的迭代过程。历史版本都写得挺粗糙的(当前版本也是),主要是提供一个改进的思路,也可以作为 review 练手,数下我写了多少个 bug (´_ゝ`)

版本 1:single kthread

太短不看
#include <linux/module.h>
#include <linux/kernel.h>
#include <linux/net.h>
#include <linux/in.h>
#include <linux/socket.h>
#include <net/sock.h>
#include <linux/kthread.h>

#define SERVER_PORT 8848
/*
 * NO_FAIL(reason, err_str, finally): if the local variable `ret` (which must
 * be in scope at the expansion site) is negative, store @reason into
 * @err_str and jump to the label @finally.
 * Wrapped in do { } while(0) so that `NO_FAIL(...);` is exactly one
 * statement and stays correct inside an unbraced if/else.
 */
#define NO_FAIL(reason, err_str, finally) \
    do { if(ret < 0) { (err_str) = (reason); goto finally; } } while(0)

/* Listening socket shared by init, the worker thread and exit. */
static struct socket *server_socket = NULL;
/* Worker kthread started by simple_web_server_init(). */
static struct task_struct *thread_st;

static int create_server_socket(void) {
    struct sockaddr_in server_addr;
    int ret;
    int optval = 1;
    sockptr_t koptval = KERNEL_SOCKPTR(&optval);
    char *err_msg;

    ret = sock_create(PF_INET, SOCK_STREAM, IPPROTO_TCP, &server_socket);
    NO_FAIL("Failed to create socket\n", err_msg, err_routine);

    memset(&server_addr, 0, sizeof(server_addr));
    server_addr.sin_family = AF_INET;
    server_addr.sin_addr.s_addr = htonl(INADDR_ANY);
    server_addr.sin_port = htons(SERVER_PORT);

    ret = sock_setsockopt(server_socket, SOL_SOCKET, SO_REUSEADDR, koptval, sizeof(optval));
    NO_FAIL("Failed to set SO_REUSEADDR\n", err_msg, err_routine);

    ret = kernel_bind(server_socket, (struct sockaddr *)&server_addr, sizeof(server_addr));
    NO_FAIL("Failed to bind socket\n", err_msg, err_routine);

    ret = kernel_listen(server_socket, 5);
    NO_FAIL("Failed to listen on socket\n", err_msg, err_routine);

    return 0;

err_routine:
    printk(KERN_ERR "%s", err_msg);
    if(server_socket) sock_release(server_socket);
    return ret;
}

static int server_thread(void *data) {
    int ret;
    struct socket *conn_socket = NULL;
    char *response =
        "HTTP/1.1 200 OK\r\nContent-Length: 13\r\n\r\nHello, world!";
    struct kvec vec;
    struct msghdr msg;
    char *err_msg;

    allow_signal(SIGKILL);

    while(!kthread_should_stop()) {
        ret = kernel_accept(server_socket, &conn_socket, 0);
        NO_FAIL("Failed to accept connection\n", err_msg, err_routine);

        memset(&msg, 0, sizeof(msg));
        vec.iov_base = response;
        vec.iov_len = strlen(response);

        ret = kernel_sendmsg(conn_socket, &msg, &vec, 1, vec.iov_len);
        NO_FAIL("Failed to send response\n", err_msg, err_routine);

        sock_release(conn_socket);
    }

    return 0;

err_routine:
    printk(KERN_ERR "%s", err_msg);
    if(conn_socket) sock_release(conn_socket);
    return ret;
}

/*
 * Module entry point: create the listening socket, then start the worker
 * kthread. Returns 0 on success or a negative errno.
 * module_exit is not called when init fails, so every failure path here
 * releases what it already allocated.
 */
static int __init simple_web_server_init(void) {
    int ret;
    struct task_struct *task;

    ret = create_server_socket();
    if(ret < 0) {
        return ret;
    }

    // kthread_create() returns an ERR_PTR on failure; check before any
    // dereference. The new thread is not running yet, so ->pid is safe to read.
    task = kthread_create(server_thread, NULL, "in_kernel_web_server");
    if(IS_ERR(task)) {
        printk(KERN_ERR "Failed to create thread\n");
        sock_release(server_socket);
        server_socket = NULL;
        return PTR_ERR(task);
    }
    printk(KERN_INFO "worker thread id: %d\n", task->pid);
    thread_st = task;
    wake_up_process(task);

    printk(KERN_INFO "Simple Web Server Initialized\n");

    return 0;
}

/*
 * Module teardown: interrupt the worker with SIGKILL (it may be blocked in
 * kernel_accept()), wait for it via kthread_stop(), then release the
 * listening socket if it is still set.
 */
static void __exit simple_web_server_exit(void) {
    if(thread_st != NULL) {
        send_sig(SIGKILL, thread_st, 1);
        kthread_stop(thread_st);
    }

    if(server_socket != NULL)
        sock_release(server_socket);

    printk(KERN_INFO "Simple Web Server Exited\n");
}

/* Register the load/unload entry points of this module. */
module_init(simple_web_server_init);
module_exit(simple_web_server_exit);

/* Module metadata. */
MODULE_LICENSE("GPL");
MODULE_AUTHOR("Caturra");
MODULE_DESCRIPTION("Simple In-Kernel Web Server");


该版本使用单个 kthread 处理所有连接,是最简单的能跑的内核服务器。

很显然这是一个经典的同步服务器,浏览器访问 127.0.0.1:8848 就可以看到响应 Hello, world!。

版本 2:it-just-works epoll

太短不看
#include <linux/module.h>
#include <linux/kernel.h>
#include <linux/net.h>
#include <linux/in.h>
#include <linux/socket.h>
#include <net/sock.h>
#include <linux/kthread.h>
#include <linux/eventpoll.h>
#include <linux/fdtable.h>
#include <linux/slab.h>
// Modified epoll header
#include <linux/fs.h>

#define SERVER_PORT 8848
/*
 * NO_FAIL(reason, err_str, finally): if the local variable `ret` (which must
 * be in scope at the expansion site) is negative, store @reason into
 * @err_str and jump to the label @finally.
 * Wrapped in do { } while(0) so that `NO_FAIL(...);` is exactly one
 * statement and stays correct inside an unbraced if/else.
 */
#define NO_FAIL(reason, err_str, finally) \
    do { if(unlikely(ret < 0)) { (err_str) = (reason); goto finally; } } while(0)

/* Listening socket shared by init, the worker thread and exit. */
static struct socket *server_socket = NULL;
/* Worker kthread started by simple_web_server_init(). */
static struct task_struct *thread_st;

static int create_server_socket(void) {
    struct sockaddr_in server_addr;
    int ret;
    int optval = 1;
    sockptr_t koptval = KERNEL_SOCKPTR(&optval);
    char *err_msg;

    ret = sock_create(PF_INET, SOCK_STREAM, IPPROTO_TCP, &server_socket);
    NO_FAIL("Failed to create socket", err_msg, done);

    memset(&server_addr, 0, sizeof(server_addr));
    server_addr.sin_family = AF_INET;
    server_addr.sin_addr.s_addr = htonl(INADDR_ANY);
    server_addr.sin_port = htons(SERVER_PORT);

    ret = sock_setsockopt(server_socket, SOL_SOCKET, SO_REUSEADDR, koptval, sizeof(optval));
    NO_FAIL("Failed to set SO_REUSEADDR", err_msg, done);

    ret = kernel_bind(server_socket, (struct sockaddr *)&server_addr, sizeof(server_addr));
    NO_FAIL("Failed to bind socket", err_msg, done);

    ret = kernel_listen(server_socket, 1024);
    NO_FAIL("Failed to listen on socket", err_msg, done);

    return 0;

done:
    printk(KERN_ERR "%s\n", err_msg);
    if(server_socket) sock_release(server_socket);
    return ret;
}

static int server_thread(void *data) {
    int ret;
    struct socket *conn_socket = NULL;
    char *response =
        "HTTP/1.1 200 OK\r\nContent-Length: 13\r\n\r\nHello, world!";
    struct kvec vec;
    struct msghdr msg;
    char *err_msg;
    int epfd = -1;
    const int EVENTS = 1024;
    int nevents;
    struct epoll_event *events;
    struct epoll_event accept_event;
    struct file *file;
    int fd;

    events = kmalloc(EVENTS * sizeof(struct epoll_event), GFP_KERNEL);

    ret = do_epoll_create(0);
    NO_FAIL("Failed to create epoll instance", err_msg, done);
    epfd = ret;

    file = sock_alloc_file(server_socket, 0 /* NONBLOCK? */, NULL);
    if(IS_ERR(file)) {
        ret = -1;
        NO_FAIL("Failed to allocate file", err_msg, done);
    }
    ret = get_unused_fd_flags(0);
    NO_FAIL("Failed to get fd", err_msg, done);
    fd = ret;
    fd_install(fd, file);

    memset(&accept_event, 0, sizeof(struct epoll_event));
    memcpy(&accept_event.data, &server_socket, sizeof(struct socket *));
    accept_event.events = EPOLLIN;
    ret = do_epoll_ctl(epfd, EPOLL_CTL_ADD, fd, &accept_event, false /* true for io_uring only */);
    NO_FAIL("Failed to register acceptor event", err_msg, done);

    allow_signal(SIGKILL);

    while(!kthread_should_stop()) {
        ret = do_epoll_wait(epfd, &events[0], EVENTS, NULL /* wait_ms == -1 */);
        NO_FAIL("Failed to wait epoll events", err_msg, done);
        nevents = ret;
        if(nevents > 0) {
            ret = kernel_accept(server_socket, &conn_socket, 0);
            NO_FAIL("Failed to accept connection", err_msg, done);

            memset(&msg, 0, sizeof(msg));
            vec.iov_base = response;
            vec.iov_len = strlen(response);

            ret = kernel_sendmsg(conn_socket, &msg, &vec, 1, vec.iov_len);
            NO_FAIL("Failed to send response", err_msg, done);

            sock_release(conn_socket);
            conn_socket = NULL;
        }
    }

done:
    if(ret < 0) printk(KERN_ERR "%s: %d\n", err_msg, ret);
    if(~epfd) close_fd(epfd);
    if(events) kfree(events);
    if(conn_socket) sock_release(conn_socket);
    thread_st = NULL;
    return ret;
}

/*
 * Module entry point: create the listening socket, then start the worker
 * kthread. Returns 0 on success or a negative errno.
 * module_exit is not called when init fails, so every failure path here
 * releases what it already allocated.
 */
static int __init simple_web_server_init(void) {
    int ret;
    struct task_struct *task;

    ret = create_server_socket();
    if(ret < 0) {
        return ret;
    }

    // kthread_create() returns an ERR_PTR on failure; check before any
    // dereference. The new thread is not running yet, so ->pid is safe to read.
    task = kthread_create(server_thread, NULL, "in_kernel_web_server");
    if(IS_ERR(task)) {
        printk(KERN_ERR "Failed to create thread\n");
        sock_release(server_socket);
        server_socket = NULL;
        return PTR_ERR(task);
    }
    printk(KERN_INFO "worker thread id: %d\n", task->pid);
    thread_st = task;
    wake_up_process(task);

    printk(KERN_INFO "Simple Web Server Initialized\n");

    return 0;
}

/*
 * Module teardown: interrupt the worker with SIGKILL (it may be blocked in
 * do_epoll_wait()), wait for it via kthread_stop(), then release the
 * listening socket if it is still set.
 */
static void __exit simple_web_server_exit(void) {
    if(thread_st != NULL) {
        send_sig(SIGKILL, thread_st, 1);
        kthread_stop(thread_st);
    }

    if(server_socket != NULL)
        sock_release(server_socket);

    printk(KERN_INFO "Simple Web Server Exited\n");
}

/* Register the load/unload entry points of this module. */
module_init(simple_web_server_init);
module_exit(simple_web_server_exit);

/* Module metadata. */
MODULE_LICENSE("GPL");
MODULE_AUTHOR("Caturra");
MODULE_DESCRIPTION("Simple In-Kernel Web Server");


这个版本只能说是可以使用内核态的 epoll,并且只关注了 accept 事件作为验证。

有些涉及到内核的改动请看进展一章。

版本 3:state machine

太短不看
#include <linux/module.h>
#include <linux/kernel.h>
#include <linux/net.h>
#include <linux/in.h>
#include <linux/socket.h>
#include <net/sock.h>
#include <linux/kthread.h>
#include <linux/eventpoll.h>
#include <linux/fdtable.h>
#include <linux/slab.h>
// Modified epoll header
#include <linux/fs.h>

#define SERVER_PORT 8848
/*
 * NO_FAIL(reason, err_str, finally): if the local variable `ret` (which must
 * be in scope at the expansion site) is negative, store @reason into
 * @err_str and jump to the label @finally.
 * Wrapped in do { } while(0) so that `NO_FAIL(...);` is exactly one
 * statement and stays correct inside an unbraced if/else.
 */
#define NO_FAIL(reason, err_str, finally) \
    do { if(unlikely(ret < 0)) { (err_str) = (reason); goto finally; } } while(0)

/* Listening socket shared by init, the worker thread and exit. */
static struct socket *server_socket = NULL;
/* Worker kthread started by simple_web_server_init(). */
static struct task_struct *thread_st;

static int create_server_socket(void) {
    int ret;
    const char *err_msg;
    struct sockaddr_in server_addr;
    int optval = 1;
    sockptr_t koptval = KERNEL_SOCKPTR(&optval);

    ret = sock_create(PF_INET, SOCK_STREAM, IPPROTO_TCP, &server_socket);
    NO_FAIL("Failed to create socket", err_msg, done);

    memset(&server_addr, 0, sizeof(server_addr));
    server_addr.sin_family = AF_INET;
    server_addr.sin_addr.s_addr = htonl(INADDR_ANY);
    server_addr.sin_port = htons(SERVER_PORT);

    ret = sock_setsockopt(server_socket, SOL_SOCKET, SO_REUSEADDR, koptval, sizeof(optval));
    NO_FAIL("sock_setsockopt(SO_REUSEADDR)", err_msg, done);

    ret = kernel_bind(server_socket, (struct sockaddr *)&server_addr, sizeof(server_addr));
    NO_FAIL("kernel_bind", err_msg, done);

    ret = kernel_listen(server_socket, 1024);
    NO_FAIL("kernel_listen", err_msg, done);

    return 0;

done:
    pr_err("%s", err_msg);
    if(server_socket) sock_release(server_socket);
    return ret;
}


/*
 * Wrap @sock in a struct file and install it in a newly allocated fd of the
 * current task, so that fd-based APIs such as do_epoll_ctl() can use it.
 * Same sequence as the kernel's (static) sock_map_fd(): reserve the fd
 * first, so that no failure path leaves a file behind.
 *
 * Returns the fd on success. On failure returns a negative errno, and @sock
 * has been released in every case (sock_alloc_file() releases it itself on
 * error), so the caller must not release it again.
 */
static int make_fd_and_file(struct socket *sock) {
    struct file *file;
    int fd;

    fd = get_unused_fd_flags(0);
    if(unlikely(fd < 0)) {
        sock_release(sock);
        return fd;
    }
    file = sock_alloc_file(sock, 0 /* NONBLOCK? */, NULL);
    if(unlikely(IS_ERR(file))) {
        put_unused_fd(fd);
        // Propagate the real errno instead of a bare -1 (-EPERM).
        return PTR_ERR(file);
    }
    fd_install(fd, file);
    return fd;
}


/*
 * Apply one epoll_ctl operation for a socket.
 *
 * EPOLL_CTL_ADD: @fd is ignored; a new fd/file is created for @sock via
 *   make_fd_and_file(), registered with interest @ep_event_flag and recorded
 *   in @sockets[fd].
 * EPOLL_CTL_MOD / EPOLL_CTL_DEL: act on the caller's existing @fd.
 * The epoll user data is the fd itself.
 *
 * Returns the fd on success or a negative errno.
 * NOTE(review): sockets[] is indexed by fd without a bounds check; the caller
 * allocates 1024 slots — confirm fds can never reach that limit.
 */
static int update_event(int epfd, int ep_ctl_flag, __poll_t ep_event_flag, /// Epoll.
                        struct socket *sockets[], int fd, struct socket *sock) { /// Sockets.
    const bool is_add = (ep_ctl_flag == EPOLL_CTL_ADD);
    struct epoll_event event;
    int err;

    // Only ADD needs a fresh fd; MOD/DEL reuse the one passed in.
    if(is_add)
        fd = make_fd_and_file(sock);
    if(unlikely(fd < 0)) {
        pr_warn("fd cannot allocate: %d\n", fd);
        return fd;
    }

    event.events = ep_event_flag;
    event.data = fd;
    err = do_epoll_ctl(epfd, ep_ctl_flag, fd, &event, false /* true for io_uring only */);
    if(unlikely(err < 0)) {
        pr_warn("do_epoll_ctl: %d\n", err);
        return err;
    }

    // Remember which socket sits behind the new fd so events can be mapped back.
    if(is_add)
        sockets[fd] = sock;
    return fd;
}


static void dump_event(struct epoll_event *e) {
    bool has_epollin  = e->events & EPOLLIN;
    bool has_epollout = e->events & EPOLLOUT;
    bool has_epollhup = e->events & EPOLLHUP;
    bool has_epollerr = e->events & EPOLLERR;
    __u64 data = e->data;
    pr_info("dump: %d%d%d%d %llu\n", has_epollin, has_epollout, has_epollhup, has_epollerr, data);
}


static void event_loop(int epfd, struct epoll_event *events, const int nevents,
                       int server_fd, struct socket *sockets[],
                       char *drop_buffer, const size_t DROP_BUFFER) {
    int ret;

    __poll_t next_event;
    int client_fd;
    struct socket *client_socket;

    char *response =
        "HTTP/1.1 200 OK\r\nContent-Length: 13\r\n\r\nHello, world!";
    struct kvec vec;
    struct msghdr msg;

    for(struct epoll_event *e = &events[0]; e != &events[nevents]; e++) {
        // dump_event(e);
        if(e->data == server_fd) {
            kernel_accept(server_socket, &client_socket, 0);
            update_event(epfd, EPOLL_CTL_ADD, EPOLLIN | EPOLLHUP, sockets, -1, client_socket);
        } else {
            next_event = e->events;
            client_fd = e->data;
            client_socket = sockets[client_fd];
            if(e->events & EPOLLIN) {
                memset(&msg, 0, sizeof(msg));
                vec.iov_base = drop_buffer;
                vec.iov_len = DROP_BUFFER;
                ret = kernel_recvmsg(client_socket, &msg, &vec, 1, vec.iov_len, 0);
                // Fast check: Maybe a FIN packet and nothing is buffered (!EPOLLOUT).
                if(ret == 0 && e->events == EPOLLIN) {
                    e->events = EPOLLHUP;
                // May be an RST packet.
                } else if(unlikely(ret < 0)) {
                    pr_warn("kernel_recvmsg: %d\n", ret);
                    if(ret != -EINTR) e->events = EPOLLHUP;
                // Slower path, call (do_)epoll_ctl().
                } else {
                    next_event &= ~EPOLLIN;
                    next_event |= EPOLLOUT;
                }
            }
            if(e->events & EPOLLOUT) {
                memset(&msg, 0, sizeof(msg));
                vec.iov_base = response;
                vec.iov_len = strlen(response);

                ret = kernel_sendmsg(client_socket, &msg, &vec, 1, vec.iov_len);
                if(ret < 0) {
                    pr_warn("kernel_sendmsg: %d\n", ret);
                    if(ret != -EINTR) e->events = EPOLLHUP;
                } else {
                    next_event &= ~EPOLLOUT;
                    next_event |= EPOLLIN;
                }
            }
            if(e->events & EPOLLHUP && !(e->events & EPOLLIN)) {
                next_event = EPOLLHUP;
                ret = update_event(epfd, EPOLL_CTL_DEL, 0, sockets, client_fd, client_socket);
                if(unlikely(ret < 0)) pr_warn("update_event[HUP]: %d\n", ret);
                sock_release(client_socket);
            }
            if(likely(e->events != EPOLLHUP)) {
                ret = update_event(epfd, EPOLL_CTL_MOD, next_event,
                                    sockets, client_fd, client_socket);
                if(unlikely(ret < 0)) pr_warn("update_event[~HUP]: %d\n", ret);
            }
        }
    }
}

static int server_thread(void *data) {
    /// Control flows.

    int ret;
    const char *err_msg;

    /// Sockets.

    int server_fd;
    // Limited by fd size. 1024 is enough for test.
    // Usage: sockets[fd] = socket_ptr.
    const size_t SOCKETS = 1024;
    struct socket **sockets;

    /// Buffers.

    const size_t DROP_BUFFER = 1024;
    char *drop_buffer;

    /// Epoll.

    int epfd = -1;
    const size_t EVENTS = 1024;
    int nevents;
    struct epoll_event *events;

    sockets = kmalloc_array(SOCKETS, sizeof(struct socket*), GFP_KERNEL);
    events = kmalloc_array(EVENTS, sizeof(struct epoll_event), GFP_KERNEL);
    drop_buffer = kmalloc(DROP_BUFFER, GFP_KERNEL);

    // Debug only.
    (void)dump_event;

    allow_signal(SIGKILL);

    ret = do_epoll_create(0);
    NO_FAIL("do_epoll_create", err_msg, done);
    epfd = ret;

    ret = update_event(epfd, EPOLL_CTL_ADD, EPOLLIN, sockets, -1, server_socket);
    NO_FAIL("update_event", err_msg, done);
    server_fd = ret;

    while(!kthread_should_stop()) {
        ret = do_epoll_wait(epfd, &events[0], EVENTS, NULL /* wait_ms == -1 */);
        NO_FAIL("do_epoll_wait", err_msg, done);
        nevents = ret;
        event_loop(epfd, events, nevents,
                   server_fd, sockets,
                   drop_buffer, DROP_BUFFER);
    }

done:
    if(ret < 0) pr_err("%s: %d\n", err_msg, ret);
    if(~epfd) close_fd(epfd);
    if(events) kfree(events);
    if(drop_buffer) kfree(drop_buffer);
    if(sockets) kfree(sockets);
    thread_st = NULL;
    // TODO record max_fd and destroy.
    return ret;
}


/*
 * Module entry point: create the listening socket, then start the worker
 * kthread. Returns 0 on success or a negative errno.
 * module_exit is not called when init fails, so every failure path here
 * releases what it already allocated.
 */
static int __init simple_web_server_init(void) {
    int ret;
    struct task_struct *task;

    ret = create_server_socket();
    if(ret < 0) {
        return ret;
    }

    // kthread_create() returns an ERR_PTR on failure; check before any
    // dereference. The new thread is not running yet, so ->pid is safe to read.
    task = kthread_create(server_thread, NULL, "in_kernel_web_server");
    if(IS_ERR(task)) {
        pr_err("Failed to create thread\n");
        sock_release(server_socket);
        server_socket = NULL;
        return PTR_ERR(task);
    }
    pr_info("worker thread id: %d\n", task->pid);
    thread_st = task;
    wake_up_process(task);

    pr_info("Simple Web Server Initialized\n");

    return 0;
}


/*
 * Module teardown: interrupt the worker with SIGKILL (it may be blocked in
 * do_epoll_wait()), wait for it via kthread_stop(), then release the
 * listening socket if it is still set.
 */
static void __exit simple_web_server_exit(void) {
    if(thread_st != NULL) {
        send_sig(SIGKILL, thread_st, 1);
        kthread_stop(thread_st);
    }

    if(server_socket != NULL)
        sock_release(server_socket);

    pr_info("Simple Web Server Exited\n");
}


/* Register the load/unload entry points of this module. */
module_init(simple_web_server_init);
module_exit(simple_web_server_exit);

/* Module metadata. */
MODULE_LICENSE("GPL");
MODULE_AUTHOR("Caturra");
MODULE_DESCRIPTION("Simple In-Kernel Web Server");


新版本添加了长连接使用到的状态机,状态转移的意思很简单:先关注读,读到所需的内容后[4]关注写;状态转移分为就地转移和下一个状态转移,主要是处理中断、异常报文和避免不必要的 ctl。可以看出这种代码很容易适配出用户态对应的版本(就是C的表达能力确实逼人写啰嗦的代码)。在迁移和对比性能之前,剩余的工作还有多线程,打算直接用 SO_REUSEPORT 拆出来就好了。

[4] 这里的读操作还没有 parser 支持,只要读到一次(不管实际内容是否完整或者多次)都直接吞掉,并认为确实读了一次完整的内容。

References

Read/write files within a Linux kernel module – Stack Overflow
Getting file descriptors and details within kernel space without open() – Stack Overflow

This post is licensed under CC BY 4.0 by the author.
Contents