1 #include <linux/module.h> 2 #include <linux/init.h> 3 #include <linux/types.h> 4 #include <asm/uaccess.h> 5 #include <asm/cacheflush.h> 6 #include <linux/unistd.h> 7 #include <linux/syscalls.h> 8 #include <asm/syscall.h> 9 #include <linux/delay.h> /* loops_per_jiffy */ 10 #include <linux/net.h> /* kernel_sendmsg */ 11 #include <linux/string.h> /* memset */ 12 #include <linux/file.h> /* fget */ 13 #include <linux/splice.h> 14 #include <linux/pagemap.h> 15 #include <linux/fs.h> 16 #include <linux/version.h> 17 #include <linux/kprobes.h> 18 #include <net/tcp.h> 19 #include <net/sock.h> 20 #include <net/inet_common.h> 21 #include <linux/skbuff.h> 22 #include <net/protocol.h> 23 #include <net/udp.h> 24 #include <net/udplite.h> 25 #include <net/xfrm.h> 26 #include <linux/igmp.h> 27 #include <net/icmp.h> 28 29 #include "udp_lib_mcast.h" 30 31 #define MODULE_NAME "[udp_mcast]" 32 #define PRINT(level, fmt, ...) printk(level MODULE_NAME fmt "\n", ## __VA_ARGS__) 33 34 35 struct net_protocol *ori_udp_protocol = NULL; 36 typedef int (* udp_rcv_t)(struct sk_buff *skb); 37 udp_rcv_t ori_udp_rcv = NULL; 38 39 typedef int (* udp_queue_rcv_skb_t)(struct sock *sk, struct sk_buff *skb); 40 udp_queue_rcv_skb_t ori_udp_queue_rcv_skb = NULL; 41 42 typedef int (* ip_mc_sf_allow_t)(struct sock *sk, __be32 local, __be32 rmt, int dif); 43 ip_mc_sf_allow_t ori_ip_mc_sf_allow = NULL; 44 45 46 static void flush_stack(struct sock **stack, unsigned int count, 47 struct sk_buff *skb, unsigned int final) 48 { 49 unsigned int i; 50 struct sk_buff *skb1 = NULL; 51 struct sock *sk; 52 53 for (i = 0; i < count; i++) { 54 sk = stack[i]; 55 if (likely(skb1 == NULL)) 56 skb1 = (i == final) ? 
skb : skb_clone(skb, GFP_ATOMIC); 57 58 if (!skb1) { 59 atomic_inc(&sk->sk_drops); 60 UDP_INC_STATS_BH(sock_net(sk), UDP_MIB_RCVBUFERRORS, 61 IS_UDPLITE(sk)); 62 UDP_INC_STATS_BH(sock_net(sk), UDP_MIB_INERRORS, 63 IS_UDPLITE(sk)); 64 } 65 66 if (skb1 && ori_udp_queue_rcv_skb(sk, skb1) <= 0) 67 skb1 = NULL; 68 } 69 if (unlikely(skb1)) 70 kfree_skb(skb1); 71 } 72 73 /* Initialize UDP checksum. If exited with zero value (success), 74 * CHECKSUM_UNNECESSARY means, that no more checks are required. 75 * Otherwise, csum completion requires chacksumming packet body, 76 * including udp header and folding it to skb->csum. 77 */ 78 static inline int udp4_csum_init(struct sk_buff *skb, struct udphdr *uh, int proto) 79 { 80 int err; 81 82 UDP_SKB_CB(skb)->partial_cov = 0; 83 UDP_SKB_CB(skb)->cscov = skb->len; 84 85 if (proto == IPPROTO_UDPLITE) { 86 err = udplite_checksum_init(skb, uh); 87 if (err) 88 return err; 89 } 90 91 return skb_checksum_init_zero_check(skb, proto, uh->check, 92 inet_compute_pseudo); 93 } 94 95 static inline struct sock *__udp4_lib_lookup_skb(struct sk_buff *skb, 96 __be16 sport, __be16 dport, 97 struct udp_table *udptable) 98 { 99 struct sock *sk; 100 const struct iphdr *iph = ip_hdr(skb); 101 102 if (unlikely(sk = skb_steal_sock(skb))) 103 return sk; 104 else 105 return __udp4_lib_lookup(dev_net(skb_dst(skb)->dev), iph->saddr, sport, 106 iph->daddr, dport, inet_iif(skb), 107 udptable); 108 } 109 110 static inline struct sock *udp_v4_mcast_next(struct net *net, struct sock *sk, 111 __be16 loc_port, __be32 loc_addr, 112 __be16 rmt_port, __be32 rmt_addr, 113 int dif) 114 { 115 struct hlist_nulls_node *node; 116 struct sock *s = sk; 117 unsigned short hnum = ntohs(loc_port); 118 119 sk_nulls_for_each_from(s, node) { 120 struct inet_sock *inet = inet_sk(s); 121 122 if (!net_eq(sock_net(s), net) || 123 udp_sk(s)->udp_port_hash != hnum || 124 (inet->inet_daddr && inet->inet_daddr != rmt_addr) || 125 (inet->inet_dport != rmt_port && inet->inet_dport) || 126 
(inet->inet_rcv_saddr && 127 inet->inet_rcv_saddr != loc_addr) || 128 ipv6_only_sock(s) || 129 (s->sk_bound_dev_if && s->sk_bound_dev_if != dif)) 130 continue; 131 if (!ori_ip_mc_sf_allow(s, loc_addr, rmt_addr, dif)) 132 continue; 133 goto found; 134 } 135 s = NULL; 136 found: 137 return s; 138 } 139 140 static unsigned int udp4_portaddr_hash(struct net *net, __be32 saddr, unsigned int port) 141 { 142 return jhash_1word((__force u32)saddr, net_hash_mix(net)) ^ port; 143 } 144 145 #define sk_nulls_for_each_entry_offset(tpos, pos, head, offset) \ 146 for (pos = (head)->first; \ 147 (!is_a_nulls(pos)) && \ 148 ({ tpos = (typeof(*tpos) *)((void *)pos - offset); 1;}); \ 149 pos = pos->next) 150 151 static inline bool __udp_is_mcast_sock(struct net *net, struct sock *sk, 152 __be16 loc_port, __be32 loc_addr, 153 __be16 rmt_port, __be32 rmt_addr, 154 int dif, unsigned short hnum) 155 { 156 struct inet_sock *inet = inet_sk(sk); 157 158 if (!net_eq(sock_net(sk), net) || 159 udp_sk(sk)->udp_port_hash != hnum || 160 (inet->inet_daddr && inet->inet_daddr != rmt_addr) || 161 (inet->inet_dport != rmt_port && inet->inet_dport) || 162 (inet->inet_rcv_saddr && inet->inet_rcv_saddr != loc_addr) || 163 ipv6_only_sock(sk) || 164 (sk->sk_bound_dev_if && sk->sk_bound_dev_if != dif)) 165 return false; 166 if (!ori_ip_mc_sf_allow(sk, loc_addr, rmt_addr, dif)) 167 return false; 168 return true; 169 } 170 171 static int my__udp4_lib_mcast_deliver(struct net *net, struct sk_buff *skb, 172 struct udphdr *uh, 173 __be32 saddr, __be32 daddr, 174 struct udp_table *udptable) 175 { 176 struct sock *sk, *stack[256 / sizeof(struct sock *)]; 177 struct hlist_nulls_node *node; 178 unsigned short hnum = ntohs(uh->dest); 179 struct udp_hslot *hslot = udp_hashslot(udptable, net, hnum); 180 int dif = skb->dev->ifindex; 181 unsigned int count = 0, offset = offsetof(typeof(*sk), sk_nulls_node); 182 unsigned int hash2 = 0, hash2_any = 0, use_hash2 = (hslot->count > 10); 183 184 if (use_hash2) { 185 hash2_any 
= udp4_portaddr_hash(net, htonl(INADDR_ANY), hnum) & 186 udp_table.mask; 187 hash2 = udp4_portaddr_hash(net, daddr, hnum) & udp_table.mask; 188 start_lookup: 189 hslot = &udp_table.hash2[hash2]; 190 offset = offsetof(typeof(*sk), __sk_common.skc_portaddr_node); 191 } 192 193 spin_lock(&hslot->lock); 194 sk_nulls_for_each_entry_offset(sk, node, &hslot->head, offset) { 195 if (__udp_is_mcast_sock(net, sk, 196 uh->dest, daddr, 197 uh->source, saddr, 198 dif, hnum)) { 199 if (unlikely(count == ARRAY_SIZE(stack))) { 200 flush_stack(stack, count, skb, ~0); 201 count = 0; 202 } 203 stack[count++] = sk; 204 sock_hold(sk); 205 } 206 } 207 208 spin_unlock(&hslot->lock); 209 210 /* Also lookup *:port if we are using hash2 and haven't done so yet. */ 211 if (use_hash2 && hash2 != hash2_any) { 212 hash2 = hash2_any; 213 goto start_lookup; 214 } 215 216 /* 217 * do the slow work with no lock held 218 */ 219 if (count) { 220 flush_stack(stack, count, skb, count - 1); 221 } else { 222 kfree_skb(skb); 223 } 224 return 0; 225 } 226 227 228 int ori__udp4_lib_rcv(struct sk_buff *skb, struct udp_table *udptable, 229 int proto) 230 { 231 struct sock *sk; 232 struct udphdr *uh; 233 unsigned short ulen; 234 struct rtable *rt = skb_rtable(skb); 235 __be32 saddr, daddr; 236 struct net *net = dev_net(skb->dev); 237 238 /* 239 * Validate the packet. 240 */ 241 if (!pskb_may_pull(skb, sizeof(struct udphdr))) 242 goto drop; /* No space for header. */ 243 244 uh = udp_hdr(skb); 245 ulen = ntohs(uh->len); 246 saddr = ip_hdr(skb)->saddr; 247 daddr = ip_hdr(skb)->daddr; 248 249 if (ulen > skb->len) 250 goto short_packet; 251 252 if (proto == IPPROTO_UDP) { 253 /* UDP validates ulen. 
*/ 254 if (ulen < sizeof(*uh) || pskb_trim_rcsum(skb, ulen)) 255 goto short_packet; 256 uh = udp_hdr(skb); 257 } 258 259 if (udp4_csum_init(skb, uh, proto)) 260 goto csum_error; 261 262 if (rt->rt_flags & (RTCF_BROADCAST|RTCF_MULTICAST)) 263 return my__udp4_lib_mcast_deliver(net, skb, uh, saddr, daddr, udptable); 264 265 sk = __udp4_lib_lookup_skb(skb, uh->source, uh->dest, udptable); 266 267 if (sk != NULL) { 268 int ret; 269 270 if (inet_get_convert_csum(sk) && uh->check && !IS_UDPLITE(sk)) 271 skb_checksum_try_convert(skb, IPPROTO_UDP, uh->check, 272 inet_compute_pseudo); 273 274 ret = ori_udp_queue_rcv_skb(sk, skb); 275 sock_put(sk); 276 277 /* a return value > 0 means to resubmit the input, but 278 * it wants the return to be -protocol, or 0 279 */ 280 if (ret > 0) 281 return -ret; 282 return 0; 283 } 284 285 if (!xfrm4_policy_check(NULL, XFRM_POLICY_IN, skb)) 286 goto drop; 287 nf_reset(skb); 288 289 /* No socket. Drop packet silently, if checksum is wrong */ 290 if (udp_lib_checksum_complete(skb)) 291 goto csum_error; 292 293 UDP_INC_STATS_BH(net, UDP_MIB_NOPORTS, proto == IPPROTO_UDPLITE); 294 icmp_send(skb, ICMP_DEST_UNREACH, ICMP_PORT_UNREACH, 0); 295 296 /* 297 * Hmm. We got an UDP packet to a port to which we 298 * don't wanna listen. Ignore it. 299 */ 300 kfree_skb(skb); 301 return 0; 302 303 short_packet: 304 LIMIT_NETDEBUG(KERN_DEBUG "UDP%s: short packet: From %pI4:%u %d/%d to %pI4:%u\n", 305 proto == IPPROTO_UDPLITE ? "Lite" : "", 306 &saddr, ntohs(uh->source), 307 ulen, skb->len, 308 &daddr, ntohs(uh->dest)); 309 goto drop; 310 311 csum_error: 312 /* 313 * RFC1122: OK. Discards the bad packet silently (as far as 314 * the network is concerned, anyway) as per 4.1.3.4 (MUST). 315 */ 316 LIMIT_NETDEBUG(KERN_DEBUG "UDP%s: bad checksum. From %pI4:%u to %pI4:%u ulen %d\n", 317 proto == IPPROTO_UDPLITE ? 
"Lite" : "", 318 &saddr, ntohs(uh->source), &daddr, ntohs(uh->dest), 319 ulen); 320 UDP_INC_STATS_BH(net, UDP_MIB_CSUMERRORS, proto == IPPROTO_UDPLITE); 321 drop: 322 UDP_INC_STATS_BH(net, UDP_MIB_INERRORS, proto == IPPROTO_UDPLITE); 323 kfree_skb(skb); 324 return 0; 325 } 326 327 328 int my_udp_rcv(struct sk_buff *skb) 329 { 330 /* udp_table exported */ 331 return ori__udp4_lib_rcv(skb, &udp_table, IPPROTO_UDP); 332 } 333 334 335 static void disable_write_protection(void) 336 { 337 unsigned long cr0 = read_cr0(); 338 clear_bit(16, &cr0); 339 write_cr0(cr0); 340 } 341 342 static void enable_write_protection(void) 343 { 344 unsigned long cr0 = read_cr0(); 345 set_bit(16, &cr0); 346 write_cr0(cr0); 347 } 348 349 static inline int replace_udp_protocol_handler(struct net_protocol *proto, udp_rcv_t old_handler, udp_rcv_t new_handler) 350 { 351 if (!proto || !old_handler || !new_handler) { 352 printk(KERN_ERR "parameters maybe NULL.\n"); 353 return -EINVAL; 354 } 355 356 printk(KERN_ERR "proto = %p, proto->handler = %p, old_handler = %p, new_handler = %p\n", proto, proto->handler, old_handler, new_handler); 357 358 /* 359 * Atomic compare and exchange. Compare OLD with MEM, if identical, 360 * store NEW in MEM. Return the initial value in MEM. Success is 361 * indicated by comparing RETURN with OLD. 362 */ 363 364 return old_handler == cmpxchg((udp_rcv_t *)&(proto->handler), old_handler, new_handler) ? 
0 : -1; 365 } 366 367 368 static int __init udp_lib_mcast_init(void) 369 { 370 int ret = -1; 371 PRINT(KERN_ERR, "Load udp lib module"); 372 373 ori_udp_protocol = (struct net_protocol *)kallsyms_lookup_name("udp_protocol"); 374 if (!ori_udp_protocol) 375 { 376 PRINT(KERN_ERR, "ori_udp_protocol failed!"); 377 return -1; 378 } 379 380 ori_udp_queue_rcv_skb = (udp_queue_rcv_skb_t)kallsyms_lookup_name("udp_queue_rcv_skb"); 381 if(ori_udp_queue_rcv_skb == NULL) 382 { 383 PRINT(KERN_ERR, "ori_udp_queue_rcv_skb failed!"); 384 return -1; 385 } 386 387 ori_ip_mc_sf_allow = (ip_mc_sf_allow_t)kallsyms_lookup_name("ip_mc_sf_allow"); 388 if(ori_ip_mc_sf_allow == NULL) 389 { 390 PRINT(KERN_ERR, "ori_ip_mc_sf_allow failed!"); 391 return -1; 392 } 393 394 ori_udp_rcv = (udp_rcv_t)kallsyms_lookup_name("udp_rcv"); 395 if(ori_udp_rcv == NULL) 396 { 397 PRINT(KERN_ERR, "ori_udp_rcv failed!"); 398 return -1; 399 } 400 401 disable_write_protection(); 402 403 ret = replace_udp_protocol_handler(ori_udp_protocol, ori_udp_rcv, my_udp_rcv); 404 if(ret) 405 { 406 PRINT(KERN_ERR, "replace udp_protocol handler error!\n"); 407 goto udp_rcv_replace_error; 408 } 409 enable_write_protection(); 410 PRINT(KERN_ERR, "replace udp_protocol handler success!\n"); 411 412 PRINT(KERN_ERR, "Load udp lib module success!"); 413 414 return 0; 415 416 udp_rcv_replace_error: 417 enable_write_protection(); 418 return -1; 419 } 420 421 static void __exit udp_lib_mcast_exit(void) 422 { 423 int rc = -1; 424 PRINT(KERN_ERR, "Unload udp_lib_mcast module start"); 425 426 disable_write_protection(); 427 428 rc = replace_udp_protocol_handler(ori_udp_protocol, my_udp_rcv, ori_udp_rcv); 429 if(rc) 430 { 431 PRINT(KERN_ERR, "Oops: resume udp_protocol handler error!\n"); 432 //goto resume_handler_error; 433 } 434 435 enable_write_protection(); 436 PRINT(KERN_ERR, "Unload udp_lib_mcast module End\n"); 437 438 return; 439 } 440 441 module_init(udp_lib_mcast_init); 442 module_exit(udp_lib_mcast_exit); 443 444 
/* Module metadata reported by modinfo. */
MODULE_AUTHOR("liangjs");
MODULE_DESCRIPTION("Custom udp lib rcv for mcast 2020-06-24");
MODULE_VERSION("1.0.0");
MODULE_LICENSE("GPL");