Java Agent

Java Agent

引言

在本篇文章中,我会通过几个简单的程序来说明 agent 的使用,最后在实战环节我会通过 asm 字节码框架来实现一个小工具,用于在程序运行中采集指定方法的参数和返回值。有关 asm 字节码的内容不是本文的重点,不会过多的阐述,不明白的同学可以自己 google 下。

简介

Java agent 提供了一种在加载字节码时,对字节码进行修改的方式。他共有两种方式执行,一种是在 main 方法执行之前,通过 premain 来实现,另一种是在程序运行中,通过 attach api 来实现。

在介绍 agent 之前,先给大家简单说下 Instrumentation 。它是 JDK1.5 提供的 API ,用于拦截类加载事件,并对字节码进行修改,它的主要方法如下:

public interface Instrumentation {
    //注册一个转换器,类加载事件会被注册的转换器所拦截
     void addTransformer(ClassFileTransformer transformer, boolean canRetransform);
    //重新触发类加载
     void retransformClasses(Class<?>... classes) throws UnmodifiableClassException;
    //直接替换类的定义
     void redefineClasses(ClassDefinition... definitions) throws  ClassNotFoundException, UnmodifiableClassException;
}

premain

premain 是在 main 方法之前运行的方法,也是最常见的 agent 方式。运行时需要将 agent 程序打包成 jar 包,并在启动时添加命令来执行,如下文所示:

java -javaagent:agent.jar=xunche HelloWorld

premain 共提供以下 2 种重载方法, Jvm 启动时会先尝试使用第一种方法,若没有会使用第二种方法:

public static void premain(String agentArgs, Instrumentation inst);
public static void premain(String agentArgs);

一个简单的例子

下面我们通过一个程序来简单说明下 premain 的使用,首先我们准备下测试代码,测试代码比较简单,运行 main 方法并输出 hello world 。

public class HelloWorld {
    public static void main(String[] args) {
        System.out.println("Hello World");
    }
}

接下来我们看下 agent 的代码,运行 premain 方法并输出我们传入的参数。

public class HelloAgent {
  public static void premain(String args) {
    System.out.println("Hello Agent:  " + args);
  }
}

为了能够 agent 能够运行,我们需要将 META-INF/MANIFEST.MF 文件中的 Premain- Class 为我们编写的 agent 路径,然后通过以下方式将其打包成 jar 包,当然你也可以使用 idea 直接导出 jar 包。

echo 'Premain-Class: org.xunche.agent.HelloAgent' > manifest.mf
javac org/xunche/agent/HelloAgent.java
javac org/xunche/app/HelloWorld.java
jar cvmf manifest.mf hello-agent.jar org/

接下来,我们编译下并运行下测试代码,这里为了测试简单,我将编译后的 class 和 agent 的 jar 包放在了同级目录下

java -javaagent:hello-agent.jar=xunche org/xunche/app/HelloWorld

可以看到输出结果如下,agent中的premain方法有限于main方法执行

Hello Agent: xuncheHello World

稍微复杂点的例子

通过上面的例子,是否对 agent 有个简单的了解呢?

下面我们来看个稍微复杂点,我们通过 agent 来实现一个方法监控的功能。思路大致是这样的,若是非 jdk 的方法,我们通过 asm 在方法的执行入口和执行出口处,植入几行记录时间戳的代码,当方法结束后,通过时间戳来获取方法的耗时。

首先还是下测试代码,逻辑很简单, main 方法执行时调用 sayHi 方法,输出 hi , xunche ,并随机睡眠一段时间。

public class HelloXunChe {
    public static void main(String[] args) throws InterruptedException {
        HelloXunChe helloXunChe = new HelloXunChe();
        helloXunChe.sayHi();
    }
    public void sayHi() throws InterruptedException {
        System.out.println("hi, xunche");
        sleep();
    }
    public void sleep() throws InterruptedException {
        Thread.sleep((long) (Math.random() * 200));
    }
}

接下来我们借助 asm 来植入我们自己的代码,在 jvm 加载类的时候,为类的每个方法加上统计方法调用耗时的代码,代码如下,这里的 asm 我使用了 jdk 自带的,当然你也可以使用官方的 asm 类库。

import jdk.internal.org.objectweb.asm.*;
import jdk.internal.org.objectweb.asm.commons.AdviceAdapter;
import java.lang.instrument.ClassFileTransformer;
import java.lang.instrument.Instrumentation;
import java.security.ProtectionDomain;
public class TimeAgent {
    public static void premain(String args, Instrumentation instrumentation) {
        instrumentation.addTransformer(new TimeClassFileTransformer());
    }
    private static class TimeClassFileTransformer implements ClassFileTransformer {
        @Override
        public byte[] transform(ClassLoader loader, String className, Class<?> classBeingRedefined, ProtectionDomain protectionDomain, byte[] classfileBuffer) {
            if (className.startsWith("java") || className.startsWith("jdk") || className.startsWith("javax") || className.startsWith("sun") || className.startsWith("com/sun")|| className.startsWith("org/xunche/agent")) {
                //return null或者执行异常会执行原来的字节码
                return null;
            }
            System.out.println("loaded class: " + className);
            ClassReader reader = new ClassReader(classfileBuffer);
            ClassWriter writer = new ClassWriter(reader, ClassWriter.COMPUTE_FRAMES | ClassWriter.COMPUTE_MAXS);
            reader.accept(new TimeClassVisitor(writer), ClassReader.EXPAND_FRAMES);
            return writer.toByteArray();
        }
    }
    public static class TimeClassVisitor extends ClassVisitor {
        public TimeClassVisitor(ClassVisitor classVisitor) {
            super(Opcodes.ASM5, classVisitor);
        }
        @Override
        public MethodVisitor visitMethod(int methodAccess, String methodName, String methodDesc, String signature, String[] exceptions) {
            MethodVisitor methodVisitor = cv.visitMethod(methodAccess, methodName, methodDesc, signature, exceptions);
            return new TimeAdviceAdapter(Opcodes.ASM5, methodVisitor, methodAccess, methodName, methodDesc);
        }
    }
    public static class TimeAdviceAdapter extends AdviceAdapter {
        private String methodName;
        protected TimeAdviceAdapter(int api, MethodVisitor methodVisitor, int methodAccess, String methodName, String methodDesc) {
            super(api, methodVisitor, methodAccess, methodName, methodDesc);
            this.methodName = methodName;
        }
        @Override
        protected void onMethodEnter() {
            //在方法入口处植入
            if ("<init>".equals(methodName)|| "<clinit>".equals(methodName)) {
                return;
            }
            mv.visitTypeInsn(NEW, "java/lang/StringBuilder");
            mv.visitInsn(DUP);
            mv.visitMethodInsn(INVOKESPECIAL, "java/lang/StringBuilder", "<init>", "()V", false);
            mv.visitVarInsn(ALOAD, 0);
            mv.visitMethodInsn(INVOKEVIRTUAL, "java/lang/Object", "getClass", "()Ljava/lang/Class;", false);
            mv.visitMethodInsn(INVOKEVIRTUAL, "java/lang/Class", "getName", "()Ljava/lang/String;", false);
            mv.visitMethodInsn(INVOKEVIRTUAL, "java/lang/StringBuilder", "append", "(Ljava/lang/String;)Ljava/lang/StringBuilder;", false);
            mv.visitLdcInsn(".");
            mv.visitMethodInsn(INVOKEVIRTUAL, "java/lang/StringBuilder", "append", "(Ljava/lang/String;)Ljava/lang/StringBuilder;", false);
            mv.visitLdcInsn(methodName);
            mv.visitMethodInsn(INVOKEVIRTUAL, "java/lang/StringBuilder", "append", "(Ljava/lang/String;)Ljava/lang/StringBuilder;", false);
            mv.visitMethodInsn(INVOKEVIRTUAL, "java/lang/StringBuilder", "toString", "()Ljava/lang/String;", false);
            mv.visitMethodInsn(INVOKESTATIC, "org/xunche/agent/TimeHolder", "start", "(Ljava/lang/String;)V", false);
        }
        @Override
        protected void onMethodExit(int i) {
            //在方法出口植入
            if ("<init>".equals(methodName) || "<clinit>".equals(methodName)) {
                return;
            }
            mv.visitTypeInsn(NEW, "java/lang/StringBuilder");
            mv.visitInsn(DUP);
            mv.visitMethodInsn(INVOKESPECIAL, "java/lang/StringBuilder", "<init>", "()V", false);
            mv.visitVarInsn(ALOAD, 0);
            mv.visitMethodInsn(INVOKEVIRTUAL, "java/lang/Object", "getClass", "()Ljava/lang/Class;", false);
            mv.visitMethodInsn(INVOKEVIRTUAL, "java/lang/Class", "getName", "()Ljava/lang/String;", false);
            mv.visitMethodInsn(INVOKEVIRTUAL, "java/lang/StringBuilder", "append", "(Ljava/lang/String;)Ljava/lang/StringBuilder;", false);
            mv.visitLdcInsn(".");
            mv.visitMethodInsn(INVOKEVIRTUAL, "java/lang/StringBuilder", "append", "(Ljava/lang/String;)Ljava/lang/StringBuilder;", false);
            mv.visitLdcInsn(methodName);
            mv.visitMethodInsn(INVOKEVIRTUAL, "java/lang/StringBuilder", "append", "(Ljava/lang/String;)Ljava/lang/StringBuilder;", false);
            mv.visitMethodInsn(INVOKEVIRTUAL, "java/lang/StringBuilder", "toString", "()Ljava/lang/String;", false);
            mv.visitVarInsn(ASTORE, 1);
            mv.visitFieldInsn(GETSTATIC, "java/lang/System", "out", "Ljava/io/PrintStream;");
            mv.visitTypeInsn(NEW, "java/lang/StringBuilder");
            mv.visitInsn(DUP);
            mv.visitMethodInsn(INVOKESPECIAL, "java/lang/StringBuilder", "<init>", "()V", false);
            mv.visitVarInsn(ALOAD, 1);
            mv.visitMethodInsn(INVOKEVIRTUAL, "java/lang/StringBuilder", "append", "(Ljava/lang/String;)Ljava/lang/StringBuilder;", false);
            mv.visitLdcInsn(": ");
            mv.visitMethodInsn(INVOKEVIRTUAL, "java/lang/StringBuilder", "append", "(Ljava/lang/String;)Ljava/lang/StringBuilder;", false);
            mv.visitVarInsn(ALOAD, 1);
            mv.visitMethodInsn(INVOKESTATIC, "org/xunche/agent/TimeHolder", "cost", "(Ljava/lang/String;)J", false);
            mv.visitMethodInsn(INVOKEVIRTUAL, "java/lang/StringBuilder", "append", "(J)Ljava/lang/StringBuilder;", false);
            mv.visitMethodInsn(INVOKEVIRTUAL, "java/lang/StringBuilder", "toString", "()Ljava/lang/String;", false);
            mv.visitMethodInsn(INVOKEVIRTUAL, "java/io/PrintStream", "println", "(Ljava/lang/String;)V", false);
        }
    }
}

上述的代码略长, asm 的部分可以略过。我们通过 instrumentation.addTransformer 注册一个转换器,转换器重写了 transform 方法,方法入参中的 classfileBuffer 表示的是原始的字节码,方法返回值表示的是真正要进行加载的字节码。

onMethodEnter 方法中的代码含义是调用 TimeHolder 的 start 方法并传入当前的方法名。

onMethodExit 方法中的代码含义是调用 TimeHolder 的 cost 方法并传入当前的方法名,并打印 cost 方法的返回值。

来看下 TimeHolder 的代码:

package org.xunche.agent;
import java.util.HashMap;
import java.util.Map;
public class TimeHolder {
    private static Map<String, Long> timeCache = new HashMap<>();
    public static void start(String method) {
        timeCache.put(method, System.currentTimeMillis());
    }
    public static long cost(String method) {
        return System.currentTimeMillis() - timeCache.get(method);
    }
}

至此,agent 的代码编写完成,有关 asm 的部分不是本章的重点,日后再单独推出一篇有关 asm 的文章。通过在类加载时植入我们监控的代码后,下面我们来看看,经过 asm 修改后的代码是怎样的。可以看到,与最开始的测试代码相比,每个方法都加入了我们统计方法耗时的代码。

package org.xunche.app;
import org.xunche.agent.TimeHolder;
public class HelloXunChe {
    public HelloXunChe() {
    }
    public static void main(String[] args) throws InterruptedException {
        TimeHolder.start(args.getClass().getName() + "." + "main");
        HelloXunChe helloXunChe = new HelloXunChe();
        helloXunChe.sayHi();
        HelloXunChe helloXunChe = args.getClass().getName() + "." + "main";
        System.out.println(helloXunChe + ": " + TimeHolder.cost(helloXunChe));
    }
    public void sayHi() throws InterruptedException {
        TimeHolder.start(this.getClass().getName() + "." + "sayHi");
        System.out.println("hi, xunche");
        this.sleep();
        String var1 = this.getClass().getName() + "." + "sayHi";
        System.out.println(var1 + ": " + TimeHolder.cost(var1));
    }
    public void sleep() throws InterruptedException {
        TimeHolder.start(this.getClass().getName() + "." + "sleep");
        Thread.sleep((long)(Math.random() * 200.0D));
        String var1 = this.getClass().getName() + "." + "sleep";
        System.out.println(var1 + ": " + TimeHolder.cost(var1));
    }
}

agentmain

上面的 premain 是通过 agetn 在应用启动前,对字节码进行修改,来实现我们想要的功能。实际上 jdk 提供了 attach api ,通过这个 api ,我们可以访问已经启动的 Java 进程。并通过 agentmain 方法来拦截类加载。下面我们来通过实战来具体说明下 agentmain 。

实战

本次实战的目标是实现一个小工具,其目标是能远程采集已经处于运行中的 Java 进程的方法调用信息。听起来像不像 BTrace ,实际上 BTrace 也是这么实现的。只不过因为时间关系,本次的实战代码写的比较简陋,大家不必关注细节,看下实现的思路就好。

具体的实现思路如下:

  • agent 对指定类的方法进行字节码的修改,采集方法的入参和返回值。并通过 socket 将请求和返回发送到服务端
  • 服务端通过 attach api 访问运行中的 Java 进程,并加载 agent ,使 agent 程序能对目标进程生效
  • 服务端加载 agent 时指定需要采集的类和方法
  • 服务端开启一个端口,接受目标进程的请求信息

老规矩,先看测试代码,测试代码很简单,每隔 100ms 运行一次 sayHi 方法,并随机随眠一段时间。

package org.xunche.app;
public class HelloTraceAgent {
    public static void main(String[] args) throws InterruptedException {
        HelloTraceAgent helloTraceAgent = new HelloTraceAgent();
        while (true) {
            helloTraceAgent.sayHi("xunche");
            Thread.sleep(100);
        }
    }
    public String sayHi(String name) throws InterruptedException {
        sleep();
        String hi = "hi, " + name + ", " + System.currentTimeMillis();
        return hi;
    }
    public void sleep() throws InterruptedException {
        Thread.sleep((long) (Math.random() * 200));
    }
}

接下看 agent 代码,思路同监控方法耗时差不多,在方法出口处,通过 asm 植入采集方法入参和返回值的代码,并通过 Sender 将信息通过 socket 发送到服务端,代码如下:

package org.xunche.agent;
import jdk.internal.org.objectweb.asm.*;
import jdk.internal.org.objectweb.asm.commons.AdviceAdapter;
import java.lang.instrument.ClassFileTransformer;
import java.lang.instrument.Instrumentation;
import java.lang.instrument.UnmodifiableClassException;
import java.security.ProtectionDomain;
public class TraceAgent {
    public static void agentmain(String args, Instrumentation instrumentation) throws ClassNotFoundException, UnmodifiableClassException {
        if (args == null) {
            return;
        }
        int index = args.lastIndexOf(".");
        if (index != -1) {
            String className = args.substring(0, index);
            String methodName = args.substring(index + 1);
            //目标代码已经加载,需要重新触发加载流程,才会通过注册的转换器进行转换
            instrumentation.addTransformer(new TraceClassFileTransformer(className.replace(".", "/"), methodName), true);
            instrumentation.retransformClasses(Class.forName(className));
        }
    }
    public static class TraceClassFileTransformer implements ClassFileTransformer {
        private String traceClassName;
        private String traceMethodName;
        public TraceClassFileTransformer(String traceClassName, String traceMethodName) {
            this.traceClassName = traceClassName;
            this.traceMethodName = traceMethodName;
        }
        @Override
        public byte[] transform(ClassLoader loader, String className, Class<?> classBeingRedefined, ProtectionDomain protectionDomain, byte[] classfileBuffer) {
            //过滤掉Jdk、agent、非指定类的方法
            if (className.startsWith("java") || className.startsWith("jdk") || className.startsWith("javax") || className.startsWith("sun")
                    || className.startsWith("com/sun") || className.startsWith("org/xunche/agent") || !className.equals(traceClassName)) {
                //return null会执行原来的字节码
                return null;
            }
            ClassReader reader = new ClassReader(classfileBuffer);
            ClassWriter writer = new ClassWriter(reader, ClassWriter.COMPUTE_FRAMES | ClassWriter.COMPUTE_MAXS);
            reader.accept(new TraceVisitor(className, traceMethodName, writer), ClassReader.EXPAND_FRAMES);
            return writer.toByteArray();
        }
    }
    public static class TraceVisitor extends ClassVisitor {
        private String className;
        private String traceMethodName;
        public TraceVisitor(String className, String traceMethodName, ClassVisitor classVisitor) {
            super(Opcodes.ASM5, classVisitor);
            this.className = className;
            this.traceMethodName = traceMethodName;
        }
        @Override
        public MethodVisitor visitMethod(int methodAccess, String methodName, String methodDesc, String signature, String[] exceptions) {
            MethodVisitor methodVisitor = cv.visitMethod(methodAccess, methodName, methodDesc, signature, exceptions);
            if (traceMethodName.equals(methodName)) {
                return new TraceAdviceAdapter(className, methodVisitor, methodAccess, methodName, methodDesc);
            }
            return methodVisitor;
        }
    }
    private static class TraceAdviceAdapter extends AdviceAdapter {
        private final String className;
        private final String methodName;
        private final Type[] methodArgs;
        private final String[] parameterNames;
        private final int[] lvtSlotIndex;
        protected TraceAdviceAdapter(String className, MethodVisitor methodVisitor, int methodAccess, String methodName, String methodDesc) {
            super(Opcodes.ASM5, methodVisitor, methodAccess, methodName, methodDesc);
            this.className = className;
            this.methodName = methodName;
            this.methodArgs = Type.getArgumentTypes(methodDesc);
            this.parameterNames = new String[this.methodArgs.length];
            this.lvtSlotIndex = computeLvtSlotIndices(isStatic(methodAccess), this.methodArgs);
        }
        @Override
        public void visitLocalVariable(String name, String description, String signature, Label start, Label end, int index) {
            for (int i = 0; i < this.lvtSlotIndex.length; ++i) {
                if (this.lvtSlotIndex[i] == index) {
                    this.parameterNames[i] = name;
                }
            }
        }
        @Override
        protected void onMethodExit(int opcode) {
            //排除构造方法和静态代码块
            if ("<init>".equals(methodName) || "<clinit>".equals(methodName)) {
                return;
            }
            if (opcode == RETURN) {
                push((Type) null);
            } else if (opcode == LRETURN || opcode == DRETURN) {
                dup2();
                box(Type.getReturnType(methodDesc));
            } else {
                dup();
                box(Type.getReturnType(methodDesc));
            }
            Type objectType = Type.getObjectType("java/lang/Object");
            push(lvtSlotIndex.length);
            newArray(objectType);
            for (int j = 0; j < lvtSlotIndex.length; j++) {
                int index = lvtSlotIndex[j];
                Type type = methodArgs[j];
                dup();
                push(j);
                mv.visitVarInsn(ALOAD, index);
                box(type);
                arrayStore(objectType);
            }
            visitLdcInsn(className.replace("/", "."));
            visitLdcInsn(methodName);
            mv.visitMethodInsn(INVOKESTATIC, "org/xunche/agent/Sender", "send", "(Ljava/lang/Object;[Ljava/lang/Object;Ljava/lang/String;Ljava/lang/String;)V", false);
        }
        private static int[] computeLvtSlotIndices(boolean isStatic, Type[] paramTypes) {
            int[] lvtIndex = new int[paramTypes.length];
            int nextIndex = isStatic ? 0 : 1;
            for (int i = 0; i < paramTypes.length; ++i) {
                lvtIndex[i] = nextIndex;
                if (isWideType(paramTypes[i])) {
                    nextIndex += 2;
                } else {
                    ++nextIndex;
                }
            }
            return lvtIndex;
        }
        private static boolean isWideType(Type aType) {
            return aType == Type.LONG_TYPE || aType == Type.DOUBLE_TYPE;
        }
        private static boolean isStatic(int access) {
            return (access & 8) > 0;
        }
    }
}

以上就是 agent 的代码, onMethodExit 方法中的代码含义是获取请求参数和返回参数并调用 Sender.send 方法。这里的访问本地变量表的代码参考了 Spring 的 LocalVariableTableParameterNameDiscoverer ,感兴趣的同学可以自己研究下。接下来看

public class Sender {
    private static final int SERVER_PORT = 9876;
    public static void send(Object response, Object[] request, String className, String methodName) {
        Message message = new Message(response, request, className, methodName);
        try {
            Socket socket = new Socket("localhost", SERVER_PORT);
            socket.getOutputStream().write(message.toString().getBytes());
            socket.close();
        } catch (IOException e) {
            e.printStackTrace();
        }
    }
    private static class Message {
        private Object response;
        private Object[] request;
        private String className;
        private String methodName;
        public Message(Object response, Object[] request, String className, String methodName) {
            this.response = response;
            this.request = request;
            this.className = className;
            this.methodName = methodName;
        }
        @Override
        public String toString() {
            return "Message{" +
                    "response=" + response +
                    ", request=" + Arrays.toString(request) +
                    ", className='" + className + '\'' +
                    ", methodName='" + methodName + '\'' +
                    '}';
        }
    }
}

Sender 中的代码不复杂,一看就懂,就不多说了。下面我们来看下服务端的代码,服务端要实现开启一个端口监听,接受请求信息,以及使用 attach api 加载 agent 。

package org.xunche.app;
import com.sun.tools.attach.AgentInitializationException;
import com.sun.tools.attach.AgentLoadException;
import com.sun.tools.attach.AttachNotSupportedException;
import com.sun.tools.attach.VirtualMachine;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.net.ServerSocket;
import java.net.Socket;
public class TraceAgentMain {
    private static final int SERVER_PORT = 9876;
    public static void main(String[] args) throws IOException, AttachNotSupportedException, AgentLoadException, AgentInitializationException {
        new Server().start();
        //attach的进程
        VirtualMachine vm = VirtualMachine.attach("85241");
        //加载agent并指明需要采集信息的类和方法
        vm.loadAgent("trace-agent.jar", "org.xunche.app.HelloTraceAgent.sayHi");
        vm.detach();
    }
    private static class Server implements Runnable {
        @Override
        public void run() {
            try {
                ServerSocket serverSocket = new ServerSocket(SERVER_PORT);
                while (true) {
                    Socket socket = serverSocket.accept();
                    InputStream input = socket.getInputStream();
                    BufferedReader reader = new BufferedReader(new InputStreamReader(input));
                    System.out.println("receive message:" + reader.readLine());
                }
            } catch (IOException e) {
                e.printStackTrace();
            }
        }
        public void start() {
            Thread thread = new Thread(this);
            thread.start();
        }
    }
}

运行上面的程序,可以看到服务端收到了 org.xunche.app.HelloTraceAgent.sayHi 的请求和返回信息。

receive message:Message{response=hi, xunche, 1581599464436, request=[xunche], className='org.xunche.app.HelloTraceAgent', methodName='sayHi'}

小结

在本章内容中,为大家介绍了 agent 的基本使用包括 premain 和 agentmain 。并通过 agentmain 实现了一个采集运行时方法调用信息的小工具,当然由于篇幅和时间问题,代码写的比较随意,大家多体会体会思路。实际上, agent 的作用远不止文章中介绍的这些,像 springloaded 等中也都有用到 agent 。

posted @ 2020-06-24 10:29  崔安兵  阅读(707)  评论(0编辑  收藏  举报