网络编程-包过滤防火墙简单实现

一、netfilter框架

这次实验使用netfilter框架,参考《网络编程》相关知识以及样例代码。

Netfilter是 Linux 内核中的一个框架,它为以定制处理器形式实施的各种网络相关操作提供了灵活性。Netfilter提供数据包过滤、网络地址翻译和端口翻译的各种选项。

检查点

在netfilter中,对于IPv4协议栈上传输过程,设置了5个检查点,PREROUTINGLOCAL-INLOCAL-OUTFORWARDPOSTROUTING

  • PREROUTING:二层收包结束后,根据注册协议和回调函数分发数据包,对于IPv4的数据包,会送入三层进行合法性检测处理,并经过PREROUTING检查点。该检查点的处理过程在报文做路由前执行。
  • LOCAL-IN:在经过PREROUTING检查点,会根据报文判断转发或输入本地,若输入本地,则需要对分片进行检查重组,之后经过LOCAL-IN检查点。该检查点的处理过程在流入本地的报文做路由前执行。
  • LOCAL-OUT:从本机发出的数据包,在查询路由成功后,经过LOCAL-OUT检查点,该检查点的处理过程在本地报文做流出路由前执行。
  • FORWARD:在经过PREROUTING检查点,会根据报文判断转发或输入本地,若需要转发,则经过FORWARD检查点。
  • POSTROUTING:转发的数据包或者是本地输出的数据包在最后输出之前,都需要经过POSTROUTING检查点。

本次实验主要只用LOCAL-INLOCAL-OUT检查点,实现一个简单的包过滤防火墙。

HOOK

包过滤防火墙实现对出入站的ip地址,端口以及ping进行检测过滤,实现方式是在对应检查点挂接回调函数,即在检查检点注册hook,函数为nf_register_hook()

二、设计思路

设计思路

包过滤防火墙主要针对入站以及出站流量,选择LOCAL-INLOCAL-OUT检查点对本机相关的流量进行检测过滤,设计出入站过滤规则开关,控制是否开启过滤规则,建立出站和入站的过滤表,对过滤ip以及端口进行管理。

数据结构

struct filter_status {
    /* 0:放行,1:过滤 */
    int ping_status;    //禁ping
    int ip_status;      //禁ip
    int port_status;    //禁port
    int debug_level;    //0不打印信息,1打印信息
    struct in_addr ban_ip_in[IPSIZE];   //禁用入站ip数组
    struct in_addr ban_ip_out[IPSIZE];   //禁用出站ip数组
    int ban_port_in[PORTSIZE]; //禁用入站端口
    int ban_port_out[PORTSIZE]; //禁用出站端口
};

filter_status过滤器,是防火墙的规则集,防火墙根据其中的状态信息以及数组元素执行规则。

  • ping_statusip_statusport_status,控制ping过滤,ip地址过滤以及端口过滤是否开启。
  • debug_level控制是否打印hook调用数量。
  • ban_ip_inban_ip_out分别用于记录入站和出站ip过滤的ip地址。
  • ban_port_inban_port_out分别用于记录入站和出站端口过滤的端口号。
struct val_status {
    unsigned int parameter;  //协议编号, ip地址, 端口号
    unsigned int rule;   //a|r   0放行 1过滤
    unsigned int ip_port_switch;  //用于ip和port是否同时修改出站和入栈规则   1:修改入站规则,2:修改出站规则,3:都修改,4:开启规则,5:关闭规则
};

val_status是用于用户态的控制程序向内核防火墙传输数据的数据结构

  • parameter指定用户操作的类型,如具体协议编号,ip地址或端口号。
  • rule规定防火墙规则,0放行,1丢弃。
  • ip_port_switch对ip或端口设置规则时,指定设置出站或入站规则。

常量定义

#define CMD_MIN		0x6000
#define CMD_DEBUG		0x6001
#define CMD_PROTOCOL	0x6002
#define CMD_IP			0x6003
#define CMD_PORT		0x6004
#define CMD_MAX		0x6100
//定义协议标识常量
#define DEBUG_SYN   99
#define ICMP_SYN    100
#define IP_SYN      300
#define PORT_SYN    301
//定义ip过滤数组以及端口过滤数组的容量
#define IPSIZE 100
#define PORTSIZE 65535

定义协议标识规范数据传输,100-299为协议编号,300为ip标识,301为端口标识。

三、源代码

//myfw.h

#define CMD_MIN		0x6000
#define CMD_DEBUG		0x6001
#define CMD_PROTOCOL	0x6002
#define CMD_IP			0x6003
#define CMD_PORT		0x6004
#define CMD_MAX		0x6100
//定义协议标识常量
#define DEBUG_SYN   99
#define ICMP_SYN    100
#define IP_SYN      300
#define PORT_SYN    301
#define IPSIZE 100
#define PORTSIZE 65535
struct val_status {
    unsigned int parameter;  //协议编号, ip地址, 端口号
    unsigned int rule;   //a|r   0通过 1过滤
    unsigned int ip_port_switch;  //用于ip和port是否同时修改出站和入栈规则   1:修改入站规则,2:修改出站规则,3:都修改,4:开启规则,5:关闭规则
};
struct filter_status {
    /* 0:通过,1:过滤 */
    int ping_status;    //禁ping
    int ip_status;      //禁ip
    int port_status;    //禁port
    int debug_level;    //0不打印信息,1打印信息
    struct in_addr ban_ip_in[IPSIZE];   //禁用入站ip数组
    struct in_addr ban_ip_out[IPSIZE];   //禁用出站ip数组
    int ban_port_in[PORTSIZE]; //禁用入站端口
    int ban_port_out[PORTSIZE]; //禁用出站端口
};
int find_ip(int ip, struct in_addr* ip_list)
{
    int i;
    for(i = 0; i < IPSIZE; i++) {
        if(ip == ip_list[i].s_addr) {
            return i;
        }
    }
    return -1;
}
int find_port(short port, int *port_list)
{
    int i;
    for(i = 0; i < PORTSIZE; i++) {
        if(port == port_list[i]) {
            return i;
        }
    }
    return -1;
}
//myfw.c

#include <linux/module.h>
#include <linux/kernel.h>
#include <linux/skbuff.h>
#include <net/tcp.h>
#include <linux/netdevice.h>
#include <linux/netfilter.h>
#include <linux/netfilter_ipv4.h>
#include <linux/string.h>

#include "myfw.h"

static struct nf_hook_ops nfhoLocalIn;
static struct nf_hook_ops nfhoLocalOut;
static struct nf_hook_ops nfhoPreRouting;
static struct nf_hook_ops nfhoForward;
static struct nf_hook_ops nfhoPostRouting;
static struct nf_sockopt_ops nfhoSockopt;
static struct filter_status filter = {0, 0, 0, 0, {0}, {0}, {0}, {0}};	//过滤器
static struct val_status val;	//信号接收结构体,用于接收控制程序发出的控制信号

static int nfcount = 0;

void debugInfo(char * msg)
{
	if (filter.debug_level) {
		nfcount++;
		printk("%s, nfcount: %d\n", msg, nfcount);
	}
}

unsigned int hookLocalIn(void *priv,
	struct sk_buff *skb,
	const struct nf_hook_state *state)
{
	unsigned rc = NF_ACCEPT;

	struct iphdr *iph = ip_hdr(skb);
	struct tcphdr *tcph = NULL;
	struct udphdr *ucph = NULL;
	int i;
	//icmp过滤
	if (iph->protocol == IPPROTO_ICMP && filter.ping_status)
		rc = NF_DROP; 
    //ip入站过滤
	if (filter.ip_status) {
		for(i = 0; i < IPSIZE; i++) {	//遍历入站ip过滤数组
			if(filter.ban_ip_in[i].s_addr == iph->saddr) {
				rc = NF_DROP;
				break;
			}
		}
	}
    //端口入站过滤
	if(filter.port_status) {
		if(iph->protocol == IPPROTO_TCP) {	//判定为tcp数据包
			tcph = tcp_hdr(skb);			//获取tcp头
			for(i = 0; i < PORTSIZE; i++) {	//遍历入站端口过滤数组
				if(filter.ban_port_in[i] == ntohs(tcph->dest)) {
					rc = NF_DROP;
					break;
				}
			}
		}
		if(iph->protocol == IPPROTO_UDP) {	//判定为udp数据包
			ucph = udp_hdr(skb);			//获取udp头
			for(i = 0; i < PORTSIZE; i++) {	//遍历入站端口过滤数组
				if(filter.ban_port_in[i] == ntohs(ucph->dest)) {
					rc = NF_DROP;
					break;
				}
			}
		}
	}
	debugInfo("hookLocalIn");

	return rc;
}

unsigned int hookLocalOut(void *priv,
	struct sk_buff *skb,
	const struct nf_hook_state *state)
{
	unsigned rc = NF_ACCEPT;

	struct iphdr *iph = ip_hdr(skb);
	struct tcphdr *tcph = NULL;
	struct udphdr *ucph = NULL;
	int i;

	if(filter.ip_status) {
		for(i = 0; i < IPSIZE; i++) {	//遍历出站ip过滤数组
			if(filter.ban_ip_out[i].s_addr == iph->daddr) {
				rc = NF_DROP;
				break;
			}
		}
	}
	if(filter.port_status) {
		if(iph->protocol == IPPROTO_TCP) {	//判定为tcp数据包
			tcph = tcp_hdr(skb);			//获取tcp头
			for(i = 0; i < PORTSIZE; i++) {	//遍历入站端口过滤数组
				if(filter.ban_port_out[i] == ntohs(tcph->source)) {
					rc = NF_DROP;
					break;
				}
			}
		}
		if(iph->protocol == IPPROTO_UDP) {	//判定为udp数据包
			ucph = udp_hdr(skb);			//获取udp头
			for(i = 0; i < PORTSIZE; i++) {	//遍历入站端口过滤数组
				if(filter.ban_port_out[i] == ntohs(ucph->source)) {
					rc = NF_DROP;
					break;
				}
			}
		}
	}
	debugInfo("hookLocalOut");

	return rc;
}

unsigned int hookPreRouting(void *priv,
	struct sk_buff *skb,
	const struct nf_hook_state *state)
{
	debugInfo("hookPreRouting");

	return NF_ACCEPT;
}

unsigned int hookPostRouting(void *priv,
	struct sk_buff *skb,
	const struct nf_hook_state *state)
{
	debugInfo("hookPostRouting");

	return NF_ACCEPT;
}

unsigned int hookForward(void *priv,
	struct sk_buff *skb,
	const struct nf_hook_state *state)
{
	debugInfo("hookForwarding");

	return NF_ACCEPT;
}

int hookSockoptSet(struct sock *sock,
	int cmd,
	sockptr_t user,
	unsigned int len)
{
	int ret;
	int parameter;
	char function[10] = {0};
	char rule[2] = {0};
	char *p;
	debugInfo("hookSockoptSet");
	
	switch (cmd) {
	case CMD_DEBUG:	//debug开启或关闭
		ret = copy_from_user(&val, user.user, sizeof(val));
		filter.debug_level = val.rule;
		printk("set debug level to %d", val.rule);
		break;
	case CMD_PROTOCOL:	//协议规则,目前只有icmp ping
		ret = copy_from_user(&val, user.user, sizeof(val));
		filter.ping_status = val.rule;
		printk("ICMP:%d rule:%d", val.parameter, val.rule);
		break;
	case CMD_IP:	//ip规则,出入规则
		ret = copy_from_user(&val, user.user, sizeof(val));
		if(val.ip_port_switch != 4 && val.ip_port_switch != 5 && (val.parameter < 1 || val.parameter > 4294967295)) break; // 合法ip范围 0.0.0.1 ~ 255.255.255.255
		if(val.ip_port_switch == 1) {		//入站规则
			if(val.rule == 1) {				//添加ip
				int flag = find_ip(val.parameter, filter.ban_ip_in);
				if(flag == -1) {
					int i;
					for(i = 0; i < IPSIZE; i++) {
						if(filter.ban_ip_in[i].s_addr == 0) {
							printk("add input ip %u\n", val.parameter);
							filter.ban_ip_in[i].s_addr = val.parameter;
							break;
						}
					}
				}
			}
			else if(val.rule == 0) {		//删除ip
				int flag = find_ip(val.parameter, filter.ban_ip_in);
				printk("delete input ip %u\n", val.parameter);
				if(flag != -1) {
					filter.ban_ip_in[flag].s_addr = 0;
				}
			}
		}
		else if(val.ip_port_switch == 2) {	//出站规则
			if(val.rule == 1) {				//添加ip
				int flag = find_ip(val.parameter, filter.ban_ip_out);
				if(flag == -1) {
					int i;
					for(i = 0; i < IPSIZE; i++) {
						if(filter.ban_ip_out[i].s_addr == 0) {
							printk("add output ip %u\n", &val.parameter);
							filter.ban_ip_out[i].s_addr = val.parameter;
							break;
						}
					}
				}
			}
			else if(val.rule == 0) {		//删除ip
				int flag = find_ip(val.parameter, filter.ban_ip_out);
				printk("delete output ip %u\n", val.parameter);
				if(flag != -1) {
					filter.ban_ip_out[flag].s_addr = 0;
				}
			}
		}
		else if(val.ip_port_switch == 3) {	//出入规则
			if(val.rule == 1) {				//添加ip
				int flag = find_ip(val.parameter, filter.ban_ip_in);
				if(flag == -1) {
					int i;
					for(i = 0; i < IPSIZE; i++) {
						if(filter.ban_ip_in[i].s_addr == 0) {
							printk("add input ip %u\n", val.parameter);
							filter.ban_ip_in[i].s_addr = val.parameter;
							break;
						}
					}
				}
				flag = find_ip(val.parameter, filter.ban_ip_out);
				if(flag == -1) {
					int i;
					for(i = 0; i < IPSIZE; i++) {
						if(filter.ban_ip_out[i].s_addr == 0) {
							printk("add output ip %u\n", val.parameter);
							filter.ban_ip_out[i].s_addr = val.parameter;
							break;
						}
					}
				}
			}
			else if(val.rule == 0) {		//删除ip
				int flag = find_ip(val.parameter, filter.ban_ip_in);
					printk("delete input ip %u\n", val.parameter);
					if(flag != -1) {
						filter.ban_ip_in[flag].s_addr = 0;
					}
				flag = find_ip(val.parameter, filter.ban_ip_out);
				printk("delete output ip %u\n", val.parameter);
				if(flag != -1) {
					filter.ban_ip_out[flag].s_addr = 0;
				}
			}
		}
		else if(val.ip_port_switch == 4) {	//开启ip过滤
			filter.ip_status = 1;
		}
		else if(val.ip_port_switch == 5) {	//关闭ip过滤
			filter.ip_status = 0;
		}
		break;
	case CMD_PORT:
		ret = copy_from_user(&val, user.user, sizeof(val));
		if(val.ip_port_switch != 4 && val.ip_port_switch != 5 && (val.parameter < 1 || val.parameter > 65535)) break; // 合法端口范围 1 ~ 65535
		if(val.ip_port_switch == 1) {	//入站规则
			if(val.rule == 1) {			//添加port
				int flag = find_port(val.parameter, filter.ban_port_in);
				if(flag == -1) {
					int i;
					for(i = 0; i < PORTSIZE; i++) {
						if(filter.ban_port_in[i] == -1) {
							printk("add input port %d\n", val.parameter);
							filter.ban_port_in[i] = val.parameter;
							break;
						}
					}
				}
			}
			else if(val.rule == 0) {	//删除port
				int flag = find_port(val.parameter, filter.ban_port_in);
				if(flag != -1) {
					filter.ban_port_in[flag] = -1;
				}
			}
		}
		if(val.ip_port_switch == 2) {	//出站规则
			if(val.rule == 1) {			//添加port
				int flag = find_port(val.parameter, filter.ban_port_out);
				if(flag == -1) {
					int i;
					for(i = 0; i < PORTSIZE; i++) {
						if(filter.ban_port_out[i] == -1) {
							filter.ban_port_out[i] = val.parameter;
							break;
						}
					}
				}
			}
			else if(val.rule == 0) {	//删除port
				int flag = find_port(val.parameter, filter.ban_port_out);
				if(flag != -1) {
					filter.ban_port_out[flag] = -1;
				}
			}
		}
		if(val.ip_port_switch == 3) {	//出入规则
			if(val.rule == 1) {			//添加port
				int flag = find_port(val.parameter, filter.ban_port_in);
				if(flag == -1) {
					int i;
					for(i = 0; i < PORTSIZE; i++) {
						if(filter.ban_port_in[i] == -1) {
							filter.ban_port_in[i] = val.parameter;
							break;
						}
					}
				}
				flag = find_port(val.parameter, filter.ban_port_out);
				if(flag == -1) {
					int i;
					for(i = 0; i < PORTSIZE; i++) {
						if(filter.ban_port_out[i] == -1) {
							filter.ban_port_out[i] = val.parameter;
							break;
						}
					}
				}
			}
			else if(val.rule == 0) {	//删除port
				int flag = find_port(val.parameter, filter.ban_port_in);
				if(flag != -1) {
					filter.ban_port_in[flag] = -1;
				}
				flag = find_port(val.parameter, filter.ban_port_out);
				if(flag != -1) {
					filter.ban_port_out[flag] = -1;
				}
			}
		}
		if(val.ip_port_switch == 4) {	//开启port过滤
			filter.port_status = 1;
		}
		if(val.ip_port_switch == 5) {	//关闭port过滤
			filter.port_status = 0;
		}
	}
	if (ret != 0) {
		printk("copy_from_user error");
		ret = -EINVAL;
	}
	return ret;
}

int hookSockoptGet(struct sock *sock,
	int cmd,
	void __user *user,
	int * len)
{
	int ret;
	debugInfo("hookSockoptGet");
	int i;
	switch (cmd) {
	case CMD_DEBUG:
		ret = copy_to_user(user, &filter, sizeof(filter));
		break;
	case CMD_PROTOCOL:
		ret = copy_to_user(user, &filter, sizeof(filter));
		break;
	case CMD_IP:
		printk("input ip list:\n");
		for(i = 0; i < IPSIZE; i++) {
			printk("%u\n", filter.ban_ip_in[i].s_addr);
		}
		printk("output ip list:\n");
		for(i = 0; i < IPSIZE; i++) {
			printk("%u\n", filter.ban_ip_out[i].s_addr);
		}
		ret = copy_to_user(user, &filter, sizeof(filter));
		break;
	case CMD_PORT:
		printk("input port list:\n");
		for(i = 0; i < PORTSIZE; i++) {
			if(filter.ban_port_in[i] != -1)
				printk("%d\n", filter.ban_port_in[i]);
		}
		printk("output port list:\n");
		for(i = 0; i < PORTSIZE; i++) {
			if(filter.ban_port_out[i] != -1)
				printk("%d\n", filter.ban_port_out[i]);
		}
		ret = copy_to_user(user, &filter, sizeof(filter));
		break;
	}
	if (ret != 0) {
		ret = -EINVAL;
		debugInfo("copy_to_user error");
	}
	return ret;
}

int init_module()
{
	int i;
	//初始化过滤器
	for(i = 0; i < IPSIZE; i++) {
		filter.ban_ip_in[i].s_addr = 0;
		filter.ban_ip_out[i].s_addr = 0;
	}
	for(i = 0; i < PORTSIZE; i++) {
		filter.ban_port_in[i] = -1;
		filter.ban_port_out[i] = -1;
	}
	//初始化hookLocalIn
	nfhoLocalIn.hook = hookLocalIn;
	nfhoLocalIn.hooknum = NF_INET_LOCAL_IN;
	nfhoLocalIn.pf = PF_INET;
	nfhoLocalIn.priority = NF_IP_PRI_FIRST;
	nf_register_net_hook(&init_net, &nfhoLocalIn);
	//初始化hookLocalOut
	nfhoLocalOut.hook = hookLocalOut;
	nfhoLocalOut.hooknum = NF_INET_LOCAL_OUT; 
	nfhoLocalOut.pf = PF_INET;
	nfhoLocalOut.priority = NF_IP_PRI_FIRST;
	nf_register_net_hook(&init_net, &nfhoLocalOut);
	//初始化hookPreRouting
	nfhoPreRouting.hook = hookPreRouting;
	nfhoPreRouting.hooknum = NF_INET_PRE_ROUTING; 
	nfhoPreRouting.pf = PF_INET;
	nfhoPreRouting.priority = NF_IP_PRI_FIRST;
	nf_register_net_hook(&init_net, &nfhoPreRouting);
	//初始化hookForward
	nfhoForward.hook = hookForward;
	nfhoForward.hooknum = NF_INET_FORWARD; 
	nfhoForward.pf = PF_INET;
	nfhoForward.priority = NF_IP_PRI_FIRST;
	nf_register_net_hook(&init_net, &nfhoForward);
	//初始化hookPostRouting
	nfhoPostRouting.hook = hookPostRouting;
	nfhoPostRouting.hooknum = NF_INET_POST_ROUTING; 
	nfhoPostRouting.pf = PF_INET;
	nfhoPostRouting.priority = NF_IP_PRI_FIRST;
	nf_register_net_hook(&init_net, &nfhoPostRouting);
	//初始化hookSockoptSet和//初始化hookSockoptGet
	nfhoSockopt.pf = PF_INET;
	nfhoSockopt.set_optmin = CMD_MIN;
	nfhoSockopt.set_optmax = CMD_MAX;
	nfhoSockopt.set = hookSockoptSet;
	nfhoSockopt.get_optmin = CMD_MIN;
	nfhoSockopt.get_optmax = CMD_MAX;
	nfhoSockopt.get = hookSockoptGet;
	nf_register_sockopt(&nfhoSockopt);

	printk("myfw started\n");
	return 0;
}
void cleanup_module()
{
	nf_unregister_net_hook(&init_net, &nfhoLocalIn);
	nf_unregister_net_hook(&init_net, &nfhoLocalOut);
	nf_unregister_net_hook(&init_net, &nfhoPreRouting);
	nf_unregister_net_hook(&init_net, &nfhoForward);
	nf_unregister_net_hook(&init_net, &nfhoPostRouting);
	nf_unregister_sockopt(&nfhoSockopt);
	printk("myfw stopped\n");
}
MODULE_LICENSE("GPL");

//myfwctl.c

#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <sys/types.h>
#include <sys/socket.h>
#include <arpa/inet.h>
#include <unistd.h>
#include <pthread.h>
#include <signal.h>
#include <errno.h>  
#include "myfw.h"

static struct val_status val;
static struct filter_status filter = {0, 0, 0, {0}, {0}, {0}, {0}};

void printError(char * msg)
{
	printf("%s error %d: %s\n", msg, errno, strerror(errno));
}

void printSuccess(char * msg) 
{
	printf("%s success\n",msg);
}

void usage(char * program)
{
	printf("%s <set> <debug debug_level>|<ip input|output|all|open|close>[ip_addr a|r]|<port input|output|all|open|close>[port_no a|r]|<protocol protocol_name a|r>\n", program);
	printf("%s <get> <debug>|<ip>|<port>|<protocol>\n", program);
}

int set(int argc, char * argv[], int sockfd) {
	int ret = -1;
	if (argc > 3) {
		char * obj = argv[2];
		int val_len = sizeof(val);
		int cmd = 0;
		int syn;
		if (!strcmp(obj, "debug")) {
			cmd = CMD_DEBUG;
			val.parameter = DEBUG_SYN;
			val.rule = atoi(argv[3]);
		}
		else if (!strcmp(obj, "protocol")) {
			cmd = CMD_PROTOCOL;
			if (!strcmp(argv[3], "icmp")) {	//set protocol icmp a
				val.parameter = ICMP_SYN;
				if (!strcmp(argv[4], "a")) {
					val.rule = 0;
				}
				else if (!strcmp(argv[4], "r")) {
					val.rule = 1;
				}
			}
		}
		else if (!strcmp(obj, "ip")) {
			cmd = CMD_IP;
			struct in_addr ip;
			if (!strcmp(argv[3], "input")) {		//设定入站ip规则
				val.ip_port_switch = 1;
				inet_aton(argv[4], &ip.s_addr);
				val.parameter = ip.s_addr;	//网络序传输
				if (!strcmp(argv[5], "a")) {		//放行
					val.rule = 0;
				}
				else if (!strcmp(argv[5], "r")) {	//过滤
					val.rule = 1;
				}
			}
			else if (!strcmp(argv[3], "output")) {	//设定出站ip规则
				val.ip_port_switch = 2;
				inet_aton(argv[4], &ip.s_addr);
				val.parameter = ip.s_addr;	//网络序传输
				if (!strcmp(argv[5], "a")) {		//放行
					val.rule = 0;
				}
				else if (!strcmp(argv[5], "r")) {	//过滤
					val.rule = 1;
				}
			}
			else if (!strcmp(argv[3], "all")) {		//设定出入ip规则
				val.ip_port_switch = 3;				
				inet_aton(argv[4], &ip.s_addr);
				val.parameter = ip.s_addr;	//网络序传输
				if (!strcmp(argv[5], "a")) {		//放行
					val.rule = 0;
				}
				else if (!strcmp(argv[5], "r")) {	//过滤
					val.rule = 1;
				}
			}
			else if(!strcmp(argv[3], "open")) {		//开启ip过滤
				val.ip_port_switch = 4;
			}
			else if(!strcmp(argv[3], "close")) {	//关闭ip过滤
				val.ip_port_switch = 5;
			}
		}
		else if(!strcmp(obj, "port")) {		
			cmd = CMD_PORT;
			if (!strcmp(argv[3], "input")) {		//设定入站端口规则
				val.ip_port_switch = 1;
				val.parameter = atoi(argv[4]);
				if (!strcmp(argv[5], "a")) {		//放行
					val.rule = 0;
				}
				else if (!strcmp(argv[5], "r")) {	//过滤
					val.rule = 1;
				}
			}
			else if (!strcmp(argv[3], "output")) {	//设定出站端口规则
				val.ip_port_switch = 2;
				val.parameter = atoi(argv[4]);
				if (!strcmp(argv[5], "a")) {		//放行
					val.rule = 0;
				}
				else if (!strcmp(argv[5], "r")) {	//过滤
					val.rule = 1;
				}
			}
			else if (!strcmp(argv[3], "all")) {		//设定出入端口规则
				val.ip_port_switch = 3;				
				val.parameter = atoi(argv[4]);
				if (!strcmp(argv[5], "a")) {		//放行
					val.rule = 0;
				}
				else if (!strcmp(argv[5], "r")) {	//过滤
					val.rule = 1;
				}
			}
			else if(!strcmp(argv[3], "open")) {		//开启端口过滤
				val.ip_port_switch = 4;
			}
			else if(!strcmp(argv[3], "close")) {	//关闭端口过滤
				val.ip_port_switch = 5;
			}
		}
		if (cmd) {
			if (setsockopt(sockfd, IPPROTO_IP, cmd, &val, val_len)) {
				printError("setsockopt()");
			}
			else {
				// printSuccess("setsockopt()");
				ret = 0;
			}
		}
	}
	else {
		usage(argv[0]);
	}
	return ret;
}

int get(int argc, char * argv[], int sockfd) {
	int ret = -1;
	if (argc > 2) {
		int cmd = 0;
		socklen_t filter_len = sizeof(filter);
		char * obj = argv[2];
		if (!strcmp(obj, "debug")) {
			cmd = CMD_DEBUG;
		}
		else if (!strcmp(obj, "protocol")) {
			cmd = CMD_PROTOCOL;
		}
		else if (!strcmp(obj, "ip")) {
			cmd = CMD_IP;
		}
		else if (!strcmp(obj, "port")) {
			cmd = CMD_PORT;
		}
		if (cmd) {
			if (getsockopt(sockfd, IPPROTO_IP, cmd, &filter, &filter_len)) {
				printError("getsockopt");
			}
			else {
				switch (cmd) {
				int i;
				case CMD_DEBUG:
					printf("debug level=%d\n", filter.debug_level);
					break;
				case CMD_PROTOCOL:
					if(filter.ping_status == 0)
						printf("allow ping\n");
					if(filter.ping_status == 1)
						printf("ban ping\n");
					break;
				case CMD_IP:
					printf("input ip list:\n");
					for(i = 0; i < IPSIZE; i++) {
						if(filter.ban_ip_in[i].s_addr != 0) {
							unsigned char bytes[4];
							bytes[0] = (filter.ban_ip_in[i].s_addr >> 24) & 0xFF;
							bytes[1] = (filter.ban_ip_in[i].s_addr >> 16) & 0xFF;
							bytes[2] = (filter.ban_ip_in[i].s_addr >> 8) & 0xFF;
							bytes[3] = (filter.ban_ip_in[i].s_addr >> 0) & 0xFF;
							printf("%d.%d.%d.%d\n", bytes[3], bytes[2], bytes[1], bytes[0]);	//输出点分十进制ip地址
						}
					}
					printf("\noutput ip list:\n");
					for(i = 0; i < IPSIZE; i++) {
						if(filter.ban_ip_out[i].s_addr != 0) {
							unsigned char bytes[4];
							bytes[0] = (filter.ban_ip_out[i].s_addr >> 24) & 0xFF;
							bytes[1] = (filter.ban_ip_out[i].s_addr >> 16) & 0xFF;
							bytes[2] = (filter.ban_ip_out[i].s_addr >> 8) & 0xFF;
							bytes[3] = (filter.ban_ip_out[i].s_addr >> 0) & 0xFF;
							printf("%d.%d.%d.%d\n", bytes[3], bytes[2], bytes[1], bytes[0]);	//输出点分十进制ip地址
						}
					}
					if(filter.ip_status == 0)
						printf("is open ip filter:	NO\n");
					if(filter.ip_status == 1)
						printf("is open ip filter:	YES\n");
					break;
				case CMD_PORT:
					printf("input port list\n");
					for(i = 0; i < PORTSIZE; i++) {
						if(filter.ban_port_in[i] != -1) {
							printf("%d\n", filter.ban_port_in[i]);
						}
					}
					printf("\noutput port list:\n");
					for(i = 0; i < PORTSIZE; i++) {
						if(filter.ban_port_out[i] != -1) {
							printf("%d\n", filter.ban_port_out[i]);
						}
					}
					if(filter.port_status == 0)
						printf("is open port filter:	NO\n");
					if(filter.port_status == 1)
						printf("is open port filter:	YES\n");
				}
			}
		}
	}
	return ret;
}

int main(int argc, char * argv[])
{
	int ret = -1;
	if (argc < 3) {
		usage(argv[0]);
	}
	else {
		int sockfd;
		if ((sockfd = socket(AF_INET, SOCK_RAW, IPPROTO_RAW)) == -1) {
			printError("socket()");
		}
		else {
			char * cmd = argv[1];
			if (!strcmp(cmd, "set")) {
				ret = set(argc, argv,sockfd);
			}
			else if (!strcmp(cmd, "get")) {
				ret = get(argc, argv,sockfd);
			}
			else {
				usage(argv[0]);
			}
			close(sockfd);
		}
	}
	return ret;
}

四、指令手册

//设置指令
<set> <debug debug_level>|<ip input|output|all|open|close>[ip_addr a|r]|<port input|output|all|open|close>[port_no a|r]|<protocol protocol_name a|r>

//回显指令
<get> <debug>|<ip>|<port>|<protocol_name>

五、程序测试

测试环境

防火墙部署主机:ubuntu20.04

防火墙编译以及启动

使用make编译防火墙myfw.c

#Makefile

# Makefile 4.0
obj-m := myfw.o
CURRENT_PATH := $(shell pwd)
LINUX_KERNEL := $(shell uname -r)
LINUX_KERNEL_PATH := /usr/src/linux-headers-$(LINUX_KERNEL)

all:
	make -C $(LINUX_KERNEL_PATH) M=$(CURRENT_PATH) modules
clean:
	make -C $(LINUX_KERNEL_PATH) M=$(CURRENT_PATH) clean

编译防火墙控制程序

gcc myfwctl.c -o myfwctl

启动防火墙

sudo insmod myfw.ko

对于防火墙功能的测试验证,可以使用wireshark等监听工具实时监听通信流量。

posted @ 2023-04-15 14:06  PIAOMIAO1  阅读(375)  评论(0编辑  收藏  举报