ControlNet-trt Optimization Notes 4: ONNX Graph Modification and Reconstruction
This section summarizes the network-level optimizations around operator plugins, focusing on two points:
- Modifying the ONNX graph to add plugins for operators that are not natively supported
- Exporting the pre/post-processing steps as ONNX graphs of their own
ONNX graph surgery
The original graph contains a large number of GroupNorm (GN) operations. They work fine in fp32, but in fp16 the pow, exp and similar operations inside GN overflow, making the results inaccurate.
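To see why fp16 is fragile here, consider a quick sketch (my own illustration, not from the project): fp16 tops out around 65504, so the Pow inside a decomposed normalization can already overflow on activations of only a few hundred.

```python
import numpy as np

x = np.float16(300.0)
print(x * x)               # inf -- 300^2 = 90000 exceeds the fp16 max (~65504)
print(np.float32(x) ** 2)  # 90000.0 -- the same computation is fine in fp32
```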
One approach is to rewrite the graph by hand and add a GN plugin operator: perform surgery on the ONNX graph and insert a GN node. However, because the ONNX opset exports GN as IN+MM (InstanceNormalization plus elementwise scale/shift), the whole process takes two passes: first decompose IN into the mean-sub-pow pattern, then fuse that pattern back into a single GN operator.
The example code below performs the decomposition in three steps: first, extract the original node's input tensors and attributes; second, build the list of new nodes that replaces the original computation; third, disconnect the original node and wire the new nodes into its place, much like manipulating a linked list.
```python
import numpy as np
import onnx_graphsurgeon as gs


def decompose_instancenorms(graph):
    nRemoveInstanceNorm = 0
    for node in graph.nodes:
        if node.op == "InstanceNormalization":
            name = node.name + "/"
            # Step 1: grab the original node's input/output tensors and attributes
            input_tensor = node.inputs[0]
            output_tensor = node.outputs[0]
            # Step 2: build the replacement chain mean -> sub -> pow -> mean -> add(eps) -> sqrt -> div -> mul -> add
            mean_out = gs.Variable(name=name + "mean_out")
            mean_node = gs.Node(op="ReduceMean", name=name + "mean_node", attrs={"axes": [-1]}, inputs=[input_tensor], outputs=[mean_out])
            sub_out = gs.Variable(name=name + "sub_out")
            sub_node = gs.Node(op="Sub", name=name + "sub_node", attrs={}, inputs=[input_tensor, mean_out], outputs=[sub_out])
            pow_out = gs.Variable(name=name + "pow_out")
            pow_const = gs.Constant(name=name + "pow_const", values=np.array([2.0], dtype=np.float32))
            pow_node = gs.Node(op="Pow", name=name + "pow_node", attrs={}, inputs=[sub_out, pow_const], outputs=[pow_out])
            mean2_out = gs.Variable(name=name + "mean2_out")
            mean2_node = gs.Node(op="ReduceMean", name=name + "mean2_node", attrs={"axes": [-1]}, inputs=[pow_out], outputs=[mean2_out])
            epsilon_out = gs.Variable(name=name + "epsilon_out")
            epsilon_const = gs.Constant(name=name + "epsilon_const", values=np.array([node.attrs["epsilon"]], dtype=np.float32))
            epsilon_node = gs.Node(op="Add", name=name + "epsilon_node", attrs={}, inputs=[mean2_out, epsilon_const], outputs=[epsilon_out])
            sqrt_out = gs.Variable(name=name + "sqrt_out")
            sqrt_node = gs.Node(op="Sqrt", name=name + "sqrt_node", attrs={}, inputs=[epsilon_out], outputs=[sqrt_out])
            div_out = gs.Variable(name=name + "div_out")
            div_node = gs.Node(op="Div", name=name + "div_node", attrs={}, inputs=[sub_out, sqrt_out], outputs=[div_out])
            constantScale = gs.Constant("InstanceNormScaleV-" + str(nRemoveInstanceNorm), np.ascontiguousarray(node.inputs[1].inputs[0].attrs["value"].values.reshape(1, 32, 1)))
            constantBias = gs.Constant("InstanceBiasV-" + str(nRemoveInstanceNorm), np.ascontiguousarray(node.inputs[2].inputs[0].attrs["value"].values.reshape(1, 32, 1)))
            mul_out = gs.Variable(name=name + "mul_out")
            mul_node = gs.Node(op="Mul", name=name + "mul_node", attrs={}, inputs=[div_out, constantScale], outputs=[mul_out])
            add_node = gs.Node(op="Add", name=name + "add_node", attrs={}, inputs=[mul_out, constantBias], outputs=[output_tensor])
            graph.nodes.extend([mean_node, sub_node, pow_node, mean2_node, epsilon_node, sqrt_node, div_node, mul_node, add_node])
            # Step 3: detach the original InstanceNormalization node so cleanup() removes it
            node.inputs = []
            node.outputs = []
            nRemoveInstanceNorm += 1
    graph.cleanup().toposort()
    print("removed InstanceNormalization nodes:", nRemoveInstanceNorm)
    return graph
```
Fusing the operators back together mirrors the decomposition, just in reverse. The important point is that the plugin node's attributes and input/output signature must match the CUDA plugin exactly; otherwise TensorRT will fail to find the plugin at engine-build time.
```python
from collections import OrderedDict
from copy import deepcopy


def insert_groupnorm_plugin(graph):
    nGroupNormPlugin = 0
    for node in graph.nodes:
        # Match the Reshape -> ReduceMean/Sub -> ... -> Mul -> Add pattern left by the decomposition
        if node.op == "Reshape" and node.outputs != [] and \
                node.o().op == "ReduceMean" and node.o(1).op == "Sub" and node.o().o() == node.o(1) and \
                node.o().o().o().o().o().o().o().o().o().o().o().op == "Mul" and \
                node.o().o().o().o().o().o().o().o().o().o().o().o().op == "Add" and \
                len(node.o().o().o().o().o().o().o().o().inputs[1].values.shape) == 3:
            assert len(node.outputs) == 1
            inputTensor = node.inputs[0]
            # Collect gamma, beta and epsilon from the matched subgraph
            gammaNode = node.o().o().o().o().o().o().o().o().o().o().o()
            index = [type(i) == gs.ir.tensor.Constant for i in gammaNode.inputs].index(True)
            gamma = np.array(deepcopy(gammaNode.inputs[index].values.tolist()), dtype=np.float32)
            constantGamma = gs.Constant("groupNormGamma-" + str(nGroupNormPlugin), np.ascontiguousarray(gamma.reshape(-1)))  # MUST use np.ascontiguousarray, or TRT will regard the shape of this Constant as (0)!
            betaNode = gammaNode.o()
            index = [type(i) == gs.ir.tensor.Constant for i in betaNode.inputs].index(True)
            beta = np.array(deepcopy(betaNode.inputs[index].values.tolist()), dtype=np.float32)
            constantBeta = gs.Constant("groupNormBeta-" + str(nGroupNormPlugin), np.ascontiguousarray(beta.reshape(-1)))
            epsilon = node.o().o().o().o().o().inputs[1].values.tolist()[0]
            # Fuse the trailing Sigmoid + Mul (Swish) into the plugin when present
            if betaNode.o().op == "Sigmoid":  # need Swish
                bSwish = True
                lastNode = betaNode.o().o()  # Mul node of Swish
            else:
                bSwish = False
                lastNode = betaNode  # Cast node after GroupNorm
            if lastNode.o().op == "Cast":
                lastNode = lastNode.o()
            # Build the GroupNorm plugin node; attributes must match the CUDA plugin definition
            inputList = [inputTensor, constantGamma, constantBeta]
            groupNormV = gs.Variable("GroupNormV-" + str(nGroupNormPlugin), np.dtype(np.float16), inputTensor.shape)
            groupNormN = gs.Node("GroupNorm", "GroupNormN-" + str(nGroupNormPlugin), inputs=inputList, outputs=[groupNormV], attrs=OrderedDict([("epsilon", epsilon), ("bSwish", int(bSwish))]))
            graph.nodes.append(groupNormN)
            # Rewire every consumer of the old subgraph output to the plugin output, then detach the old chain
            for subNode in graph.nodes:
                if lastNode.outputs[0] in subNode.inputs:
                    index = subNode.inputs.index(lastNode.outputs[0])
                    subNode.inputs[index] = groupNormV
            lastNode.outputs = []
            nGroupNormPlugin += 1
    graph.cleanup().toposort()
    print("inserted GroupNorm plugin nodes:", nGroupNormPlugin)
    return graph
```
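A minimal driver tying the two passes together might look like this (a sketch; the file names are placeholders, not the project's actual paths):

```python
import onnx
import onnx_graphsurgeon as gs

# Hypothetical file names, for illustration only
onnx_model = onnx.load("controlnet.onnx")
graph = gs.import_onnx(onnx_model)

graph = decompose_instancenorms(graph)   # IN -> mean/sub/pow/... pattern
graph = insert_groupnorm_plugin(graph)   # pattern -> single GroupNorm plugin node

onnx.save(gs.export_onnx(graph), "controlnet_gn.onnx")
```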
Another way to handle fp16 overflow is to keep the affected layer and the layer feeding it in higher precision. The example below targets Softmax overflow by forcing the Softmax layer and its preceding layer to run in fp32.
```python
# `indices` lists the layer indices of the TensorRT network and `pairwise`
# yields consecutive (current, next) pairs -- see the sketch below.
for i, i_next in pairwise(indices):
    layer = trt_network.get_layer(i)
    next_layer = trt_network.get_layer(i_next)
    # Skip layers whose outputs are not execution tensors (e.g. shape tensors)
    if not all([
        layer.get_output(i).is_execution_tensor
        for i in range(layer.num_outputs)
    ]):
        continue
    if layer.get_output_type(0) != trt.float32:
        continue
    # Pin the Softmax layer and its producer to fp32 to avoid overflow
    if next_layer.type == trt.LayerType.SOFTMAX:
        layer.precision = trt.DataType.FLOAT
        next_layer.precision = trt.DataType.FLOAT
```
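The loop above assumes that trt_network (the INetworkDefinition being built), indices and pairwise already exist; a minimal sketch of those pieces, as I would write them, is:

```python
import itertools
import tensorrt as trt

def pairwise(iterable):
    # Yields consecutive pairs: (a, b), (b, c), (c, d), ...
    a, b = itertools.tee(iterable)
    next(b, None)
    return zip(a, b)

# Walk every layer of the already-populated network in order
indices = list(range(trt_network.num_layers))
```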
A third overflow case comes from operator attributes or constants whose values are too extreme for fp16. In that situation the infinite value has to be replaced with a finite number of smaller magnitude; in the example below every -np.inf is replaced with a large negative constant (-1e5 in the code):
```python
import onnx
import numpy as np

# change onnx -inf to a large finite negative value
# new_onnx_model is the loaded onnx.ModelProto to be patched
for node in new_onnx_model.graph.node:
    if node.op_type == "ConstantOfShape":
        attr = node.attribute[0]
        if attr.name == "value" and attr.t.data_type == onnx.TensorProto.FLOAT:
            np_array = np.frombuffer(attr.t.raw_data, dtype=np.float32).copy()
            np_array[np_array == -np.inf] = -100000  # replace every -inf with -100000
            attr.t.raw_data = np_array.tobytes()
```
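For completeness, new_onnx_model is simply the model loaded from disk; the surrounding load/save could be as simple as the following (placeholder file names):

```python
import onnx

new_onnx_model = onnx.load("controlnet_gn.onnx")      # placeholder path
# ... run the ConstantOfShape patch loop above ...
onnx.save(new_onnx_model, "controlnet_gn_patched.onnx")
```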
Exporting pre/post-processing as an ONNX graph
This is not a huge speedup, but it is a refreshing trick. In the DDIM loop there is a chunk of post-processing after the ControlNet call; by replacing that torch computation with an ONNX graph, it too can be accelerated with TensorRT, i.e. the post-processing becomes its own small "postnet" graph. One catch is that different numbers of sampling steps use different schedule constants, so a better approach is to combine them into one larger graph rather than carry extra sets of parameters.
```python
import numpy as np
import torch
import torch.nn as nn


class PostNet(nn.Module):
    def __init__(self):
        super().__init__()
        # step = 20
        # self.alphas = torch.from_numpy(np.array([0.9983, 0.9505, 0.8930, 0.8264, 0.7521, 0.6722, 0.5888, 0.5048, 0.4229, 0.3456, 0.2750,
        #                                          0.2128, 0.1598, 0.1163, 0.0819, 0.0557, 0.0365, 0.0231, 0.0140, 0.0082]))
        # self.alphas_prev = torch.from_numpy(np.array([0.99914998, 0.99829602, 0.95052433, 0.89298052, 0.82639927, 0.75214338,
        #                                               0.67215145, 0.58881873, 0.50481856, 0.42288151, 0.34555823, 0.27499905,
        #                                               0.21278252, 0.15981644, 0.11632485, 0.08191671, 0.05571903, 0.03654652,
        #                                               0.02307699, 0.0140049]))
        # self.sqrt_one_minus_alphas = torch.from_numpy(np.array([0.0413, 0.2224, 0.3271, 0.4167, 0.4979, 0.5726, 0.6412, 0.7037, 0.7597,
        #                                                         0.8090, 0.8515, 0.8873, 0.9166, 0.9400, 0.9582, 0.9717, 0.9816, 0.9884,
        #                                                         0.9930, 0.9959]))
        # self.sigmas = torch.from_numpy(np.array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]))
        # self.time_range = [951, 901, 851, 801, 751, 701, 651, 601, 551, 501, 451, 401, 351, 301, 251, 201, 151, 101, 51, 1]
        # step = 10
        self.alphas = torch.from_numpy(np.array([0.9983, 0.8930, 0.7521, 0.5888, 0.4229, 0.2750, 0.1598, 0.0819, 0.0365, 0.0140]))
        self.alphas_prev = torch.from_numpy(np.array([0.99914998, 0.99829602, 0.89298052, 0.75214338, 0.58881873, 0.42288151, 0.27499905, 0.15981644, 0.08191671, 0.03654652]))
        self.sqrt_one_minus_alphas = torch.from_numpy(np.array([0.0413, 0.3271, 0.4979, 0.6412, 0.7597, 0.8515, 0.9166, 0.9582, 0.9816, 0.9930]))
        self.sigmas = torch.from_numpy(np.array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]))

    def forward(self, x, image, unconditional_guidance_scale, index):
        # Classifier-free guidance: combine the two noise predictions in the batch
        e_t = image[1].unsqueeze(0) + unconditional_guidance_scale * (image[0].unsqueeze(0) - image[1].unsqueeze(0))
        a_t = self.alphas[index]
        a_prev = self.alphas_prev[index]
        sqrt_one_minus_at = self.sqrt_one_minus_alphas[index]
        # DDIM update (eta = 0, so the sigma term drops out)
        pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
        dir_xt = (1. - a_prev).sqrt() * e_t
        x_prev = a_prev.sqrt() * pred_x0 + dir_xt
        return x_prev, pred_x0
```
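To run this inside TensorRT, the module is exported to ONNX like any other network. A minimal export sketch follows; the file name, latent shape and opset version are my own assumptions, not the project's actual export script:

```python
import torch

postnet = PostNet().eval()

# Dummy inputs matching the DDIM loop: current latent, stacked cond/uncond
# noise predictions, guidance scale and step index (shapes are assumptions).
x = torch.randn(1, 4, 32, 48)
image = torch.randn(2, 4, 32, 48)
scale = torch.tensor(9.0)
index = torch.tensor(0, dtype=torch.int64)

torch.onnx.export(
    postnet,
    (x, image, scale, index),
    "postnet.onnx",
    input_names=["x", "image", "unconditional_guidance_scale", "index"],
    output_names=["x_prev", "pred_x0"],
    opset_version=17,
)
```

The resulting postnet.onnx can then be built into a TensorRT engine alongside the ControlNet/UNet graphs, so the whole DDIM step stays on the GPU without falling back to torch.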