golang 防SQL注入 基于反射、TAG标记实现的不定参数检查器

  收到一个任务,所有http的handler要对入参检查,防止SQL注入。刚开始笨笨的,打算为所有的结构体写一个方法,后来统计了下,要写几十上百,随着业务增加,以后还会重复这个无脑力的机械劳作。想想就low。

  直接做一个不定参数的自动检测函数不就ok了么?

  磨刀不误砍柴工,用了一个下午的时间,调教出一个算法:把不定结构体对象扔进去,这个函数自动检查。

  普通场景还好,不比电信级业务,比如FRR快切,要求50ms以内刷新百万路由。

  先说说我的想法,然后把代码贴后面。

 

  这里犹豫,要不要做并发?就要看需求了。
  需求0:调用者传入一个结构体对象,要检查这个对象有没有变量注入脚本,没必要并发;
  需求1:调用者传入多个结构体对象,要检查这些对象有没有变量注入脚本,不能明确是哪个对象的变量有误,需要并发。这种情况的话,感觉简单的SQL检查用不上,牛刀杀鸡了;
  拓展0:如果想对TAG的长度做限制,没必要在const里面多定义几个限制长度的变量,直接把TAG改成TAG.len就ok了,在迭代器里面把小数点后面的长度提出来,扔到具体的检查函数里去。
  拓展1:这个函数可以改装成自动映射器:工具自动生成映射代码,插入检查器中对basic type的switch里,根据tag自动映射。省去程序员的机械编码,映射部分全自动。
  拓展2:只要一个结构体中,对变量进行tag自定义,就可以对这个结构体的所有变量进行任意处理。生产工具发展生产力!

 

  函数缺陷:高并发场景下,可能会有性能瓶颈,毕竟用了递归,栈空间吃紧,并且影响程序的可理解性,对程序的测试也有一定影响。

 

  个人的想象力总归是有限的,读者如果有什么更烧脑,异想天开的想法,可以留言,一起分析,一起进步。

   

  好了,贴代码吧:

定义了一个三层的结构体。

type ZZZStu struct {
	ggg int    `sql:"int"`
	hhh string `sql:"email"`
}

type YYYStu struct {
	ddd int    `sql:"int"`
	eee string `sql:"alphaandnum"`
	zzz ZZZStu
}

type XXXStu struct {
	aaa int      `sql:"int"`
	bbb []string `sql:"num"`
	yyy YYYStu
}

 

然后在main里面定义了这个结构体对象实例,在里面随意添加一些非法字符,调试测试使用。

func main() {
	var temp XXXStu
	temp.aaa = 1
	//temp.bbb[0] = "123"
	bbb_tmp := "1"
	temp.bbb = append(temp.bbb, bbb_tmp)
	bbb_tmp = "2"
	temp.bbb = append(temp.bbb, bbb_tmp)
	bbb_tmp = "3"
	temp.bbb = append(temp.bbb, bbb_tmp)
	temp.yyy.ddd = 3
	temp.yyy.eee = "123qwe"
	temp.yyy.zzz.ggg = 5
	temp.yyy.zzz.hhh = `123456789@xxxxxxx.com`
	addMsg, ret := CheckSqlInject(temp)
	fmt.Println("main:", addMsg, ret)
	return
}

 

 下面是这个防SQL检查器的最外层封装。

/*****************************************************************************
* \author          pxx
* \date             2018/07/05
* \brief		防sql注入检查递归迭代器
* \param[in]	不定参数
* \return		给定结构体对象内的变量,有非法字符
* \ingroup
* \remarks
******************************************************************************/
func CheckSqlInject(args ...interface{}) (addMsg string, ret int) {

	for _, arg := range args {
		name := reflect.TypeOf(arg).Name()
		fmt.Printf("Recursioner %s (%T):\n", name, arg)
		addMsg, ret = Recursioner(reflect.ValueOf(arg), name, name)
	}

	return
}

  

  

下面的Recursioner就是整个递归检查器的核心部分了。可以看出来,我是从struct类型起始的,因为reflect包里面只有structfield有TAG。

如果想拓展,就得在这个本包里面实现,或者在公司内部的库包里面做。这里不用三方库,有开源代码安全性问题的考量。

基本把所有类型涵盖了:指针,接口,channel,数组,切片,结构体,map,baisc type

func Recursioner(FieldValue reflect.Value, Path, FieldName string) (addMsg string, ret int) {
	switch FieldValue.Kind() {
	case reflect.Invalid:
		fmt.Printf("%s = invalid\n", Path)

	//struct为起点(只有StructField有TAG),暂时满足需求。如果以其他类型起始,需要对底层库函数拓展,有时间再搞
	case reflect.Struct:
		for i := 0; i < FieldValue.NumField(); i++ {
			fieldInfo := FieldValue.Type().Field(i)
			tag := fieldInfo.Tag //  reflect.StructTag(string)
			name := tag.Get("sql")
			fieldPath := fmt.Sprintf("%s.%s (%s)", Path, FieldValue.Type().Field(i).Name, FieldValue.Type().Field(i).Type)
			addMsg, ret = Recursioner(FieldValue.Field(i), fieldPath, name)
			if ret != 0 {
				return
			}
		}

	case reflect.Slice, reflect.Array:
		for i := 0; i < FieldValue.Len(); i++ {
			addMsg, ret = Recursioner(FieldValue.Index(i), fmt.Sprintf("%s[%d]", Path, i), FieldName)
			if ret != 0 {
				return
			}
		}

	case reflect.Map:
		for _, key := range FieldValue.MapKeys() {
			addMsg, ret = Recursioner(FieldValue.MapIndex(key), fmt.Sprintf("%s[%s]", Path,
				formatAtom(key)), FieldName)
			if ret != 0 {
				return
			}
		}

	case reflect.Ptr:
		if FieldValue.IsNil() {
			fmt.Printf("%s = nil\n", Path)
		} else {
			addMsg, ret = Recursioner(FieldValue.Elem(), fmt.Sprintf("(*%s)", Path), FieldName)
			if ret != 0 {
				return
			}
		}

	case reflect.Interface:
		if FieldValue.IsNil() {
			fmt.Printf("%s = nil\n", Path)
		} else {
			fmt.Printf("%s.type = %s\n", Path, FieldValue.Elem().Type())
			addMsg, ret = Recursioner(FieldValue.Elem(), Path+".value", FieldName)
			if ret != 0 {
				return
			}
		}

	default: // basic types, channels, funcs
		fmt.Printf("%s = %s\n", Path, formatAtom(FieldValue))

		field_name := FieldValue.Type().Name()
		if field_name == "string" {

			//获取该属性的tag
			fmt.Println("tag_value=", FieldName)
			switch FieldName {
			case "alphaandnum":
				addMsg, ret = CheckAlphaAndNum(CheckAlphaAndNumLen, formatAtom(FieldValue))
				if ret != 0 {
					return
				}

			case "email":
				addMsg, ret = CheckEmail(CheckEmailLen, formatAtom(FieldValue))
				if ret != 0 {
					return addMsg, ret
				}

			case "num":
				addMsg, ret = CheckNum(CheckNumLen, formatAtom(FieldValue))
				if ret != 0 {
					return
				}
			}
			fmt.Println()
		}
	}
	return
}

  

 

  

下面这个函数,格式化数据。

func formatAtom(FieldValue reflect.Value) string {
	switch FieldValue.Kind() {
	case reflect.Invalid:
		return "invalid"

	case reflect.String:
		return FieldValue.String()

	case reflect.Int, reflect.Int8, reflect.Int16,
		reflect.Int32, reflect.Int64:
		return strconv.FormatInt(FieldValue.Int(), 10)

	case reflect.Uint, reflect.Uint8, reflect.Uint16,
		reflect.Uint32, reflect.Uint64, reflect.Uintptr:
		return strconv.FormatUint(FieldValue.Uint(), 10)

	// ...floating-point and complex cases omitted for brevity...
	case reflect.Bool:
		return strconv.FormatBool(FieldValue.Bool())

	case reflect.Chan, reflect.Func, reflect.Ptr, reflect.Slice, reflect.Map:
		return FieldValue.Type().String() + " 0x" +
			strconv.FormatUint(uint64(FieldValue.Pointer()), 16)

	default: // reflect.Array, reflect.Struct, reflect.Interface
		return FieldValue.Type().String() + " value"
	}
}

  

具体的字段检查函数,我贴一个就好,意思到了就行。原理也很简单,用golang自带的map(java里是hashmap,python里的dict,之前做的芯片驱动层表项管理,也是hash)。

func CheckAlphaAndNum(lenLimit int, str string) (addMsg string, ret int) {
	ret = 0
	var lenStr int = len(str)
	if lenStr > lenLimit {
		ret = -1
		return
	}
	for i := 0; i < lenStr; i++ {
		r := str[i]
		if _, ok := CheckAlphaAndNumMap[r]; !ok {
			ret = -1
			addMsg = "字母数字组合类型字符串包含非法字符,请检查!"
			return
		}
	}
	return
}

  

posted @ 2018-07-09 16:59  morning_sun  阅读(1357)  评论(0编辑  收藏  举报