openGauss源码解析(187)

openGauss源码解析:AI技术(34)

8.7.4 基于MADlib框架的扩展

前文展示了MADlib各个模块的功能和作用,从结构上看,用户可以针对自己的算法进行扩展。前文中提到的XGBoost、GBDT和Prophet三个算法是我们在原来基础上扩展的算法。本小节将以自研的GBDT模块为例,介绍基于MADlib框架的扩展。

GBDT目录结构

GBDT文件结构如表8-17所示。

表8-17 GBDT算法

文件结构

说明

gbdt/gbdt.py_in

python代码

gbdt/gbdt.sql_in

存储过程代码

gbdt/test/gbdt.sql

测试代码

在sql_in文件中,定义上层SQL-like接口,使用PL/pgSQL或者PL/python实现。

在SQL层中定义UDF函数,下述代码实现了类似重载的功能。

CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.gbdt_train(

training_table_name TEXT,

output_table_name TEXT,

id_col_name TEXT,

dependent_variable TEXT,

list_of_features TEXT,

list_of_features_to_exclude TEXT,

weights TEXT

)

RETURNS VOID AS $$

SELECT MADLIB_SCHEMA.gbdt_train($1, $2, $3, $4, $5, $6, $7, 30::INTEGER);

$$ LANGUAGE sql VOLATILE;

CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.gbdt_train(

training_table_name TEXT,

output_table_name TEXT,

id_col_name TEXT,

dependent_variable TEXT,

list_of_features TEXT,

list_of_features_to_exclude TEXT

)

RETURNS VOID AS $$

SELECT MADLIB_SCHEMA.gbdt_train($1, $2, $3, $4, $5, $6, NULL::TEXT);

$$ LANGUAGE sql VOLATILE;

CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.gbdt_train(

training_table_name TEXT,

output_table_name TEXT,

id_col_name TEXT,

dependent_variable TEXT,

list_of_features TEXT

)

RETURNS VOID AS $$

SELECT MADLIB_SCHEMA.gbdt_train($1, $2, $3, $4, $5, NULL::TEXT);

$$ LANGUAGE sql VOLATILE;

其中,输入表、输出表、特征等必备信息需要用户指定。其他参数提供缺省的参数,比如权重weights,如果用户没有指定自定义参数,程序会用默认的参数进行运算。

在SQL层定义PL/python接口,代码如下:

CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.gbdt_train(

training_table_name TEXT,

output_table_name TEXT,

id_col_name TEXT,

dependent_variable TEXT,

list_of_features TEXT,

list_of_features_to_exclude TEXT,

weights TEXT,

num_trees INTEGER,

num_random_features INTEGER,

max_tree_depth INTEGER,

min_split INTEGER,

min_bucket INTEGER,

num_bins INTEGER,

null_handling_params TEXT,

is_classification BOOLEAN,

predict_dt_prob TEXT,

learning_rate DOUBLE PRECISION,

verbose BOOLEAN,

sample_ratio DOUBLE PRECISION

)

RETURNS VOID AS $$

PythonFunction(gbdt, gbdt, gbdt_fit)

$$ LANGUAGE plpythonu VOLATILE;

PL/pgSQL或者SQL函数最终会调用到一个PL/python函数。

“PythonFunction(gbdt, gbdt, gbdt_fit)”是固定的用法,这也是一个封装的m4宏,会在编译安装的时候,会进行宏替换。

PythonFunction中,第一个参数是文件夹名,第二个参数是文件名,第三个参数是函数名。PythonFunction宏会被替换为“from gdbt.gdbt import gbdt_fit”语句。所以要保证文件路径和函数正确。

在python层中,实现训练函数,代码如下:

def gbdt_fit(schema_madlib,training_table_name, output_table_name,

id_col_name, dependent_variable, list_of_features,

list_of_features_to_exclude, weights,

num_trees, num_random_features,

max_tree_depth, min_split, min_bucket, num_bins,

null_handling_params, is_classification,

predict_dt_prob = None, learning_rate = None,

verbose=False, **kwargs):

plpy.execute("""ALTER TABLE {training_table_name} DROP COLUMN IF EXISTS gradient CASCADE

""".format(training_table_name=training_table_name))

create_summary_table(output_table_name, null_proxy, bins['cat_features'],

bins['con_features'], learning_rate, is_classification, predict_dt_prob,

num_trees, training_table_name)

在python层实现预测函数,代码如下:

def gbdt_predict(schema_madlib, test_table_name, model_table_name, output_table_name, id_col_name, **kwargs):

num_tree = plpy.execute("""SELECT COUNT(*) AS count FROM {model_table_name}""".format(**locals()))[0]['count']

if num_tree == 0:

plpy.error("The GBDT-method has no trees")

elements = plpy.execute("""SELECT * FROM {model_table_name}_summary""".format(**locals()))[0]

在py_in文件中,定义相应的业务代码,用python实现相应处理逻辑。

在安装阶段,sql_in和py_in会被GNU m4解析为正常的python和sql文件。这里需要指出的是,当前MADlib框架只支持python2版本,因此,上述代码实现也是基于python2完成的。

posted @ 2024-05-06 10:45  openGauss-bot  阅读(2)  评论(0编辑  收藏  举报