1 #include <linux/module.h>
  2 #include <linux/init.h>
  3 #include <linux/types.h>
  4 #include <asm/uaccess.h>
  5 #include <asm/cacheflush.h>
  6 #include <linux/unistd.h>
  7 #include <linux/syscalls.h>
  8 #include <asm/syscall.h>
  9 #include <linux/delay.h>    /* loops_per_jiffy */
 10 #include <linux/net.h>      /* kernel_sendmsg */
 11 #include <linux/string.h>   /* memset */
 12 #include <linux/file.h>     /* fget */
 13 #include <linux/splice.h>
 14 #include <linux/pagemap.h>
 15 #include <linux/fs.h>
 16 #include <linux/version.h>
 17 #include <linux/kprobes.h>
 18 #include <net/tcp.h>
 19 #include <net/sock.h>
 20 #include <net/inet_common.h>
 21 #include <linux/skbuff.h>
 22 #include <net/protocol.h>
 23 #include <net/udp.h>
 24 #include <net/udplite.h>
 25 #include <net/xfrm.h>
 26 #include <linux/igmp.h>
 27 #include <net/icmp.h>
 28 
 29 #include "udp_lib_mcast.h"
 30 
 31 #define MODULE_NAME                 "[udp_mcast]"
 32 #define PRINT(level, fmt, ...)      printk(level MODULE_NAME fmt "\n", ## __VA_ARGS__)
 33 
 34 
 35 struct net_protocol *ori_udp_protocol = NULL;
 36 typedef int (* udp_rcv_t)(struct sk_buff *skb);
 37 udp_rcv_t ori_udp_rcv = NULL;
 38 
 39 typedef int (* udp_queue_rcv_skb_t)(struct sock *sk, struct sk_buff *skb);
 40 udp_queue_rcv_skb_t ori_udp_queue_rcv_skb = NULL;
 41 
 42 typedef int (* ip_mc_sf_allow_t)(struct sock *sk, __be32 local, __be32 rmt, int dif);
 43 ip_mc_sf_allow_t ori_ip_mc_sf_allow = NULL;
 44 
 45 
 46 static void flush_stack(struct sock **stack, unsigned int count,
 47             struct sk_buff *skb, unsigned int final)
 48 {
 49     unsigned int i;
 50     struct sk_buff *skb1 = NULL;
 51     struct sock *sk;
 52 
 53     for (i = 0; i < count; i++) {
 54         sk = stack[i];
 55         if (likely(skb1 == NULL))
 56             skb1 = (i == final) ? skb : skb_clone(skb, GFP_ATOMIC);
 57 
 58         if (!skb1) {
 59             atomic_inc(&sk->sk_drops);
 60             UDP_INC_STATS_BH(sock_net(sk), UDP_MIB_RCVBUFERRORS,
 61                      IS_UDPLITE(sk));
 62             UDP_INC_STATS_BH(sock_net(sk), UDP_MIB_INERRORS,
 63                      IS_UDPLITE(sk));
 64         }
 65 
 66         if (skb1 && ori_udp_queue_rcv_skb(sk, skb1) <= 0)
 67             skb1 = NULL;
 68     }
 69     if (unlikely(skb1))
 70         kfree_skb(skb1);
 71 }
 72 
 73 /* Initialize UDP checksum. If exited with zero value (success),
 74  * CHECKSUM_UNNECESSARY means, that no more checks are required.
 75  * Otherwise, csum completion requires chacksumming packet body,
 76  * including udp header and folding it to skb->csum.
 77  */
 78 static inline int udp4_csum_init(struct sk_buff *skb, struct udphdr *uh, int proto)
 79 {
 80     int err;
 81 
 82     UDP_SKB_CB(skb)->partial_cov = 0;
 83     UDP_SKB_CB(skb)->cscov = skb->len;
 84 
 85     if (proto == IPPROTO_UDPLITE) {
 86         err = udplite_checksum_init(skb, uh);
 87         if (err)
 88             return err;
 89     }
 90 
 91     return skb_checksum_init_zero_check(skb, proto, uh->check,
 92                         inet_compute_pseudo);
 93 }
 94 
 95 static inline struct sock *__udp4_lib_lookup_skb(struct sk_buff *skb,
 96                          __be16 sport, __be16 dport,
 97                          struct udp_table *udptable)
 98 {
 99     struct sock *sk;
100     const struct iphdr *iph = ip_hdr(skb);
101 
102     if (unlikely(sk = skb_steal_sock(skb)))
103         return sk;
104     else
105         return __udp4_lib_lookup(dev_net(skb_dst(skb)->dev), iph->saddr, sport,
106                      iph->daddr, dport, inet_iif(skb),
107                      udptable);
108 }
109 
110 static inline struct sock *udp_v4_mcast_next(struct net *net, struct sock *sk,
111                          __be16 loc_port, __be32 loc_addr,
112                          __be16 rmt_port, __be32 rmt_addr,
113                          int dif)
114 {
115     struct hlist_nulls_node *node;
116     struct sock *s = sk;
117     unsigned short hnum = ntohs(loc_port);
118 
119     sk_nulls_for_each_from(s, node) {
120         struct inet_sock *inet = inet_sk(s);
121 
122         if (!net_eq(sock_net(s), net) ||
123             udp_sk(s)->udp_port_hash != hnum ||
124             (inet->inet_daddr && inet->inet_daddr != rmt_addr) ||
125             (inet->inet_dport != rmt_port && inet->inet_dport) ||
126             (inet->inet_rcv_saddr &&
127              inet->inet_rcv_saddr != loc_addr) ||
128             ipv6_only_sock(s) ||
129             (s->sk_bound_dev_if && s->sk_bound_dev_if != dif))
130             continue;
131         if (!ori_ip_mc_sf_allow(s, loc_addr, rmt_addr, dif))
132             continue;
133         goto found;
134     }
135     s = NULL;
136 found:
137     return s;
138 }
139 
140 static unsigned int udp4_portaddr_hash(struct net *net, __be32 saddr, unsigned int port)
141 {
142     return jhash_1word((__force u32)saddr, net_hash_mix(net)) ^ port;
143 }
144 
/*
 * Iterate over a nulls-terminated hlist whose nodes are embedded in the type
 * of @tpos at byte @offset. This lets one loop walk either the primary
 * sk_nulls_node chain or the secondary skc_portaddr_node chain.
 *
 * Some kernel versions may already provide this macro via <net/sock.h>, so it
 * is only defined here when missing, to avoid a redefinition.
 */
#ifndef sk_nulls_for_each_entry_offset
#define sk_nulls_for_each_entry_offset(tpos, pos, head, offset)               \
    for (pos = (head)->first;                                                   \
         (!is_a_nulls(pos)) &&                                                  \
        ({ tpos = (typeof(*tpos) *)((void *)pos - offset); 1;});                \
         pos = pos->next)
#endif
150          
151 static inline bool __udp_is_mcast_sock(struct net *net, struct sock *sk,
152                        __be16 loc_port, __be32 loc_addr,
153                        __be16 rmt_port, __be32 rmt_addr,
154                        int dif, unsigned short hnum)
155 {
156     struct inet_sock *inet = inet_sk(sk);
157 
158     if (!net_eq(sock_net(sk), net) ||
159         udp_sk(sk)->udp_port_hash != hnum ||
160         (inet->inet_daddr && inet->inet_daddr != rmt_addr) ||
161         (inet->inet_dport != rmt_port && inet->inet_dport) ||
162         (inet->inet_rcv_saddr && inet->inet_rcv_saddr != loc_addr) ||
163         ipv6_only_sock(sk) ||
164         (sk->sk_bound_dev_if && sk->sk_bound_dev_if != dif))
165         return false;
166     if (!ori_ip_mc_sf_allow(sk, loc_addr, rmt_addr, dif))
167         return false;
168     return true;
169 }         
170 
171 static int my__udp4_lib_mcast_deliver(struct net *net, struct sk_buff *skb,
172                     struct udphdr  *uh,
173                     __be32 saddr, __be32 daddr,
174                     struct udp_table *udptable)
175 {
176     struct sock *sk, *stack[256 / sizeof(struct sock *)];
177     struct hlist_nulls_node *node;
178     unsigned short hnum = ntohs(uh->dest);
179     struct udp_hslot *hslot = udp_hashslot(udptable, net, hnum);
180     int dif = skb->dev->ifindex;
181     unsigned int count = 0, offset = offsetof(typeof(*sk), sk_nulls_node);
182     unsigned int hash2 = 0, hash2_any = 0, use_hash2 = (hslot->count > 10);
183 
184     if (use_hash2) {
185         hash2_any = udp4_portaddr_hash(net, htonl(INADDR_ANY), hnum) &
186                 udp_table.mask;
187         hash2 = udp4_portaddr_hash(net, daddr, hnum) & udp_table.mask;
188 start_lookup:
189         hslot = &udp_table.hash2[hash2];
190         offset = offsetof(typeof(*sk), __sk_common.skc_portaddr_node);
191     }
192 
193     spin_lock(&hslot->lock);
194     sk_nulls_for_each_entry_offset(sk, node, &hslot->head, offset) {
195         if (__udp_is_mcast_sock(net, sk,
196                     uh->dest, daddr,
197                     uh->source, saddr,
198                     dif, hnum)) {
199             if (unlikely(count == ARRAY_SIZE(stack))) {
200                 flush_stack(stack, count, skb, ~0);
201                 count = 0;
202             }
203             stack[count++] = sk;
204             sock_hold(sk);
205         }
206     }
207 
208     spin_unlock(&hslot->lock);
209 
210     /* Also lookup *:port if we are using hash2 and haven't done so yet. */
211     if (use_hash2 && hash2 != hash2_any) {
212         hash2 = hash2_any;
213         goto start_lookup;
214     }
215 
216     /*
217      * do the slow work with no lock held
218      */
219     if (count) {
220         flush_stack(stack, count, skb, count - 1);
221     } else {
222         kfree_skb(skb);
223     }
224     return 0;
225 }
226 
227 
228 int ori__udp4_lib_rcv(struct sk_buff *skb, struct udp_table *udptable,
229            int proto)
230 {
231     struct sock *sk;
232     struct udphdr *uh;
233     unsigned short ulen;
234     struct rtable *rt = skb_rtable(skb);
235     __be32 saddr, daddr;
236     struct net *net = dev_net(skb->dev);
237 
238     /*
239      *  Validate the packet.
240      */
241     if (!pskb_may_pull(skb, sizeof(struct udphdr)))
242         goto drop;        /* No space for header. */
243 
244     uh   = udp_hdr(skb);
245     ulen = ntohs(uh->len);
246     saddr = ip_hdr(skb)->saddr;
247     daddr = ip_hdr(skb)->daddr;
248 
249     if (ulen > skb->len)
250         goto short_packet;
251 
252     if (proto == IPPROTO_UDP) {
253         /* UDP validates ulen. */
254         if (ulen < sizeof(*uh) || pskb_trim_rcsum(skb, ulen))
255             goto short_packet;
256         uh = udp_hdr(skb);
257     }
258 
259     if (udp4_csum_init(skb, uh, proto))
260         goto csum_error;
261 
262     if (rt->rt_flags & (RTCF_BROADCAST|RTCF_MULTICAST))
263         return my__udp4_lib_mcast_deliver(net, skb, uh, saddr, daddr, udptable);
264 
265     sk = __udp4_lib_lookup_skb(skb, uh->source, uh->dest, udptable);
266 
267     if (sk != NULL) {
268         int ret;
269 
270         if (inet_get_convert_csum(sk) && uh->check && !IS_UDPLITE(sk))
271             skb_checksum_try_convert(skb, IPPROTO_UDP, uh->check,
272                          inet_compute_pseudo);
273 
274         ret = ori_udp_queue_rcv_skb(sk, skb);
275         sock_put(sk);
276 
277         /* a return value > 0 means to resubmit the input, but
278          * it wants the return to be -protocol, or 0
279          */
280         if (ret > 0)
281             return -ret;
282         return 0;
283     }
284 
285     if (!xfrm4_policy_check(NULL, XFRM_POLICY_IN, skb))
286         goto drop;
287     nf_reset(skb);
288 
289     /* No socket. Drop packet silently, if checksum is wrong */
290     if (udp_lib_checksum_complete(skb))
291         goto csum_error;
292 
293     UDP_INC_STATS_BH(net, UDP_MIB_NOPORTS, proto == IPPROTO_UDPLITE);
294     icmp_send(skb, ICMP_DEST_UNREACH, ICMP_PORT_UNREACH, 0);
295 
296     /*
297      * Hmm.  We got an UDP packet to a port to which we
298      * don't wanna listen.  Ignore it.
299      */
300     kfree_skb(skb);
301     return 0;
302 
303 short_packet:
304     LIMIT_NETDEBUG(KERN_DEBUG "UDP%s: short packet: From %pI4:%u %d/%d to %pI4:%u\n",
305                proto == IPPROTO_UDPLITE ? "Lite" : "",
306                &saddr, ntohs(uh->source),
307                ulen, skb->len,
308                &daddr, ntohs(uh->dest));
309     goto drop;
310 
311 csum_error:
312     /*
313      * RFC1122: OK.  Discards the bad packet silently (as far as
314      * the network is concerned, anyway) as per 4.1.3.4 (MUST).
315      */
316     LIMIT_NETDEBUG(KERN_DEBUG "UDP%s: bad checksum. From %pI4:%u to %pI4:%u ulen %d\n",
317                proto == IPPROTO_UDPLITE ? "Lite" : "",
318                &saddr, ntohs(uh->source), &daddr, ntohs(uh->dest),
319                ulen);
320     UDP_INC_STATS_BH(net, UDP_MIB_CSUMERRORS, proto == IPPROTO_UDPLITE);
321 drop:
322     UDP_INC_STATS_BH(net, UDP_MIB_INERRORS, proto == IPPROTO_UDPLITE);
323     kfree_skb(skb);
324     return 0;
325 }
326 
327 
328 int my_udp_rcv(struct sk_buff *skb)
329 {
330     /* udp_table exported */
331     return ori__udp4_lib_rcv(skb, &udp_table, IPPROTO_UDP);
332 }
333 
334 
335 static void disable_write_protection(void)
336 {
337   unsigned long cr0 = read_cr0();
338   clear_bit(16, &cr0);
339   write_cr0(cr0);
340 }
341 
342 static  void enable_write_protection(void)
343 {
344   unsigned long cr0 = read_cr0();
345   set_bit(16, &cr0);
346   write_cr0(cr0);
347 }
348 
349 static inline int replace_udp_protocol_handler(struct net_protocol *proto, udp_rcv_t old_handler, udp_rcv_t new_handler)
350 {
351     if (!proto || !old_handler || !new_handler) {
352         printk(KERN_ERR "parameters maybe NULL.\n");
353         return -EINVAL;
354     }
355 
356     printk(KERN_ERR "proto = %p, proto->handler = %p, old_handler = %p, new_handler = %p\n", proto, proto->handler, old_handler, new_handler);
357 
358     /*
359      * Atomic compare and exchange.  Compare OLD with MEM, if identical,
360      * store NEW in MEM.  Return the initial value in MEM.  Success is
361      * indicated by comparing RETURN with OLD.
362      */
363 
364     return old_handler == cmpxchg((udp_rcv_t *)&(proto->handler), old_handler, new_handler) ? 0 : -1;
365 }
366 
367 
368 static int __init udp_lib_mcast_init(void)
369 {
370     int ret = -1;
371     PRINT(KERN_ERR, "Load udp lib module");
372    
373     ori_udp_protocol = (struct net_protocol *)kallsyms_lookup_name("udp_protocol");
374     if (!ori_udp_protocol)
375     {
376         PRINT(KERN_ERR, "ori_udp_protocol failed!");
377         return -1;
378     }
379 
380     ori_udp_queue_rcv_skb = (udp_queue_rcv_skb_t)kallsyms_lookup_name("udp_queue_rcv_skb");
381     if(ori_udp_queue_rcv_skb == NULL)
382     {
383         PRINT(KERN_ERR, "ori_udp_queue_rcv_skb failed!");
384         return -1;
385     }
386     
387     ori_ip_mc_sf_allow = (ip_mc_sf_allow_t)kallsyms_lookup_name("ip_mc_sf_allow");
388     if(ori_ip_mc_sf_allow == NULL)
389     {
390         PRINT(KERN_ERR, "ori_ip_mc_sf_allow failed!");
391         return -1;
392     }
393     
394     ori_udp_rcv = (udp_rcv_t)kallsyms_lookup_name("udp_rcv");
395     if(ori_udp_rcv == NULL)
396     {
397         PRINT(KERN_ERR, "ori_udp_rcv failed!");
398         return -1;
399     }
400         
401     disable_write_protection();
402 
403     ret = replace_udp_protocol_handler(ori_udp_protocol, ori_udp_rcv, my_udp_rcv);
404     if(ret)
405     {
406         PRINT(KERN_ERR, "replace udp_protocol handler error!\n");
407         goto udp_rcv_replace_error;
408     }
409     enable_write_protection();
410     PRINT(KERN_ERR, "replace udp_protocol handler success!\n");
411 
412     PRINT(KERN_ERR, "Load udp lib module success!");
413     
414     return 0;
415 
416 udp_rcv_replace_error:
417     enable_write_protection();
418     return -1;
419 }
420 
421 static void __exit udp_lib_mcast_exit(void)
422 {
423     int rc = -1;
424     PRINT(KERN_ERR, "Unload udp_lib_mcast module start");
425     
426     disable_write_protection();
427 
428     rc = replace_udp_protocol_handler(ori_udp_protocol, my_udp_rcv, ori_udp_rcv);
429     if(rc)
430     {
431         PRINT(KERN_ERR, "Oops: resume udp_protocol handler error!\n");
432         //goto resume_handler_error;
433     }
434     
435     enable_write_protection();
436     PRINT(KERN_ERR, "Unload udp_lib_mcast module End\n");
437     
438     return;
439 }
440 
441 module_init(udp_lib_mcast_init);
442 module_exit(udp_lib_mcast_exit);
443 
444 MODULE_AUTHOR("liangjs");
445 MODULE_DESCRIPTION("Custom udp lib rcv for mcast 2020-06-24");
446 MODULE_VERSION("1.0.0"); 
447 MODULE_LICENSE("GPL");