
[Source Code Analysis] Greys, Alibaba's Online Diagnostic Tool

Let's roll up our sleeves and get straight to the point.

1. Downloading the Source

Clone the repository:

git clone https://github.com/oldmanpushcart/greys-anatomy.git

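2. Source Code Analysis

2.1 Directory Layout

Open the greys source in IntelliJ IDEA. The code is split into two main modules:

greys-core (the core package, which includes the main entry point)
  advisor - enhancement utilities, adapters, and listeners
  command - the commands available in the console and their formats
  exception - the program's custom exceptions
  manager - manager classes, such as reflection management and time-fragment management
  server - the server and handlers for command interaction; the command line talks to this component
  textui - display formats for output, including class, method, and object views
  util - utility classes for timing, locking, logging, object-size calculation, and so on

greys-agent (the actual agent implementation; this module has only three classes: a custom class loader, the method references for the different pointcuts, and the agent launcher)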

2.2 Startup Flow and Source Walkthrough

Startup begins in the main function of the greys.sh script:

# the main
main()
{

    while getopts "PUJC" ARG
    do
        case ${ARG} in
            P) OPTION_CHECK_PERMISSION=0;;
            U) OPTION_UPDATE_IF_NECESSARY=0;;
            J) OPTION_ATTACH_JVM=0;;
            C) OPTION_ACTIVE_CONSOLE=0;;
            ?) usage;exit 1;;
        esac
    done

    shift $((OPTIND-1));

    if [[ ${OPTION_CHECK_PERMISSION} -eq 1 ]]; then
        check_permission
    fi

    reset_for_env

    parse_arguments "${@}" \
        || exit_on_err 1 "$(usage)"

    if [[ ${OPTION_UPDATE_IF_NECESSARY} -eq 1 ]]; then
        update_if_necessary \
            || echo "update fail, ignore this update." 1>&2
    fi

    local greys_local_version=$(default $(get_local_version) ${DEFAULT_VERSION})

    if [[ ${greys_local_version} = ${DEFAULT_VERSION} ]]; then
        exit_on_err 1 "greys not found, please check your network."
    fi

    if [[ ${OPTION_ATTACH_JVM} -eq 1 ]]; then
        attach_jvm ${greys_local_version}\
            || exit_on_err 1 "attach to target jvm(${TARGET_PID}) failed."
    fi

    if [[ ${OPTION_ACTIVE_CONSOLE} -eq 1 ]]; then
        active_console ${greys_local_version}\
            || exit_on_err 1 "active console failed."
    fi

}


main "${@}"

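If none of the four options P, U, J, and C is given, each flag defaults to 1, meaning the step is enabled; passing an option sets its flag to 0 and skips that step. So by default the script checks permissions, checks for updates (if necessary), attaches the agent to the target JVM (starting the server), and opens the console (starting the client).

The first two steps have little to do with the source code, so they are skipped here.

Attaching to the JVM (starting the server):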

# attach greys to target jvm
# $1 : greys_local_version
attach_jvm()
{
    local greys_lib_dir=${GREYS_LIB_DIR}/${1}/greys

    # if [ ${TARGET_IP} = ${DEFAULT_TARGET_IP} ]; then
    if [ ! -z ${TARGET_PID} ]; then
        ${JAVA_HOME}/bin/java \
            ${BOOT_CLASSPATH} ${JVM_OPTS} \
            -jar ${greys_lib_dir}/greys-core.jar \
                -pid ${TARGET_PID} \
                -target ${TARGET_IP}":"${TARGET_PORT} \
                -core "${greys_lib_dir}/greys-core.jar" \
                -agent "${greys_lib_dir}/greys-agent.jar"
    fi
}

As the script shows, what actually runs is the jar built from the greys-core module.

The class that greys-core.jar runs is com.github.ompc.greys.core.GreysLauncher.

The com.github.ompc.greys.core.GreysLauncher constructor:

    public GreysLauncher(String[] args) throws Exception {
        // Parse the configuration (process ID, target IP:port, core jar, agent jar)
        Configure configure = analyzeConfigure(args);
        // Load the agent
        attachAgent(configure);
    }

The attachAgent method (this is where the agent jar is attached to the JVM):

    /*
     * Load the agent
     */
    private void attachAgent(Configure configure) throws Exception {

        final ClassLoader loader = Thread.currentThread().getContextClassLoader();
        final Class<?> vmdClass = loader.loadClass("com.sun.tools.attach.VirtualMachineDescriptor");
        final Class<?> vmClass = loader.loadClass("com.sun.tools.attach.VirtualMachine");

        Object attachVmdObj = null;
        for (Object obj : (List<?>) vmClass.getMethod("list", (Class<?>[]) null).invoke(null, (Object[]) null)) {
            if ((vmdClass.getMethod("id", (Class<?>[]) null).invoke(obj, (Object[]) null))
                    .equals(Integer.toString(configure.getJavaPid()))) {
                attachVmdObj = obj;
            }
        }

//        if (null == attachVmdObj) {
//            // throw new IllegalArgumentException("pid:" + configure.getJavaPid() + " not existed.");
//        }

        Object vmObj = null;
        try {
            if (null == attachVmdObj) { // fall back to attach(String pid)
                vmObj = vmClass.getMethod("attach", String.class).invoke(null, "" + configure.getJavaPid());
            } else {
                vmObj = vmClass.getMethod("attach", vmdClass).invoke(null, attachVmdObj);
            }
            vmClass.getMethod("loadAgent", String.class, String.class).invoke(vmObj, configure.getGreysAgent(), configure.getGreysCore() + ";" + configure.toString());
        } finally {
            if (null != vmObj) {
                vmClass.getMethod("detach", (Class<?>[]) null).invoke(vmObj, (Object[]) null);
            }
        }

    }
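
attachAgent never references VirtualMachine at compile time; it looks the class up by name and calls its static methods through reflection. The sketch below shows the same loadClass / getMethod / invoke(null, ...) pattern in isolation. It uses java.lang.Integer as a stand-in target, and the class name ReflectiveStaticCallSketch and the helper invokeStatic are made up for illustration; they are not part of greys.

```java
import java.lang.reflect.Method;

final class ReflectiveStaticCallSketch {

    /** Loads a class by name and invokes one of its public static methods. */
    static Object invokeStatic(String className, String methodName,
                               Class<?>[] paramTypes, Object... args) {
        try {
            // Same loader choice as attachAgent: the thread's context class loader
            ClassLoader loader = Thread.currentThread().getContextClassLoader();
            Class<?> clazz = Class.forName(className, true, loader);
            Method method = clazz.getMethod(methodName, paramTypes);
            // A null receiver means "static call", as in vmClass.getMethod("list").invoke(null)
            return method.invoke(null, args);
        } catch (ReflectiveOperationException e) {
            throw new IllegalStateException(
                    "reflective call failed: " + className + "#" + methodName, e);
        }
    }
}
```

Calling invokeStatic("java.lang.Integer", "toString", new Class<?>[]{int.class}, 42) goes through the same steps that attachAgent uses for VirtualMachine.attach, only against a class that is always present.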

The AgentLauncher class in the agent jar has two entry points, premain and agentmain; both delegate to the class's own main method:

    public static void premain(String args, Instrumentation inst) {
        main(args, inst);
    }

    public static void agentmain(String args, Instrumentation inst) {
        main(args, inst);
    }

This main method starts GaServer, which is essentially a socket server: it keeps receiving commands typed in the console, executes them, and returns the results:

    private static synchronized void main(final String args, final Instrumentation inst) {
        try {

            // args has two parts, agentJar and agentArgs:
            // the path of the agent JAR and the arguments to pass on to the server
            final int index = args.indexOf(';');
            final String agentJar = args.substring(0, index);
            final String agentArgs = args.substring(index, args.length());

            // Add Spy to the BootstrapClassLoader search path
            inst.appendToBootstrapClassLoaderSearch(
                    new JarFile(AgentLauncher.class.getProtectionDomain().getCodeSource().getLocation().getFile())
            );

            // Build a custom class loader to minimize Greys's intrusion into the host application
            final ClassLoader agentLoader = loadOrDefineClassLoader(agentJar);

            // The Configure class
            final Class<?> classOfConfigure = agentLoader.loadClass("com.github.ompc.greys.core.Configure");

            // The GaServer class
            final Class<?> classOfGaServer = agentLoader.loadClass("com.github.ompc.greys.core.server.GaServer");

            // Deserialize into a Configure instance
            final Object objectOfConfigure = classOfConfigure.getMethod("toConfigure", String.class)
                    .invoke(null, agentArgs);

            // JavaPid
            final int javaPid = (Integer) classOfConfigure.getMethod("getJavaPid").invoke(objectOfConfigure);

            // Get the GaServer singleton
            final Object objectOfGaServer = classOfGaServer
                    .getMethod("getInstance", int.class, Instrumentation.class)
                    .invoke(null, javaPid, inst);

            // gaServer.isBind()
            final boolean isBind = (Boolean) classOfGaServer.getMethod("isBind").invoke(objectOfGaServer);

            if (!isBind) {
                try {
                    classOfGaServer.getMethod("bind", classOfConfigure).invoke(objectOfGaServer, objectOfConfigure);
                } catch (Throwable t) {
                    classOfGaServer.getMethod("destroy").invoke(objectOfGaServer);
                    throw t;
                }

            }

        } catch (Throwable t) {
            t.printStackTrace();
        }

    }
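
The launcher and the agent communicate through a single string: attachAgent passes the core jar path and the serialized Configure joined by ';', and main splits it again at the first ';'. A minimal sketch of that round trip follows. The class AgentArgsSketch and the sample configuration text are hypothetical; the real format is whatever Configure.toString() produces.

```java
final class AgentArgsSketch {

    /** Joins the jar path and the serialized configuration, as the loadAgent call does. */
    static String join(String jarPath, String configText) {
        return jarPath + ";" + configText;
    }

    /**
     * Splits at the first ';' into {jarPath, rest}.
     * As in the quoted main method, the rest keeps its leading ';'.
     */
    static String[] split(String args) {
        int index = args.indexOf(';');
        if (index < 0) {
            throw new IllegalArgumentException("expected '<jar>;<config>' but got: " + args);
        }
        return new String[] { args.substring(0, index), args.substring(index) };
    }
}
```

Unlike the quoted code, the sketch rejects input without a ';' explicitly; in the original, a missing separator would make substring(0, -1) throw, and the outer catch would print the stack trace.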

GaServer receives messages over its socket connections; as long as the server stays bound, the selector thread keeps listening for incoming data:

    private void activeSelectorDaemon(final Selector selector, final Configure configure) {

        final ByteBuffer byteBuffer = ByteBuffer.allocate(BUFFER_SIZE);

        final Thread gaServerSelectorDaemon = new Thread("ga-selector-daemon") {
            @Override
            public void run() {

                while (!isInterrupted()
                        && isBind()) {

                    try {

                        while (selector.isOpen()
                                && selector.select() > 0) {
                            final Iterator<SelectionKey> it = selector.selectedKeys().iterator();
                            while (it.hasNext()) {
                                final SelectionKey key = it.next();
                                it.remove();

                                // do ssc accept
                                if (key.isValid() && key.isAcceptable()) {
                                    doAccept(key, selector, configure);
                                }

                                // do sc read
                                if (key.isValid() && key.isReadable()) {
                                    doRead(byteBuffer, key);
                                }

                            }
                        }

                    } catch (IOException e) {
                        logger.warn("selector failed.", e);
                    } catch (ClosedSelectorException e) {
                        logger.debug("selector closed.", e);
                    }


                }

            }
        };
        gaServerSelectorDaemon.setDaemon(true);
        gaServerSelectorDaemon.start();
    }
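
The heart of this daemon is the standard NIO pattern: select, iterate over the selected keys, remove each key, then dispatch on its readiness. The sketch below runs one such pass. To stay self-contained it uses a java.nio.channels.Pipe instead of a server socket, so there is no accept branch; the class SelectorLoopSketch and its methods are illustrative and not part of greys.

```java
import java.io.IOException;
import java.io.UncheckedIOException;
import java.nio.ByteBuffer;
import java.nio.channels.Pipe;
import java.nio.channels.ReadableByteChannel;
import java.nio.channels.SelectionKey;
import java.nio.channels.Selector;
import java.nio.charset.StandardCharsets;
import java.util.Iterator;

final class SelectorLoopSketch {

    /** Runs one select() pass and returns the text read from all ready channels. */
    static String selectOnce(Selector selector, ByteBuffer buffer) throws IOException {
        StringBuilder received = new StringBuilder();
        if (selector.select(1000) > 0) {
            Iterator<SelectionKey> it = selector.selectedKeys().iterator();
            while (it.hasNext()) {
                SelectionKey key = it.next();
                it.remove(); // the selected-key set is not cleared automatically
                if (key.isValid() && key.isReadable()) {
                    buffer.clear();
                    int n = ((ReadableByteChannel) key.channel()).read(buffer);
                    if (n > 0) {
                        buffer.flip();
                        received.append(StandardCharsets.UTF_8.decode(buffer));
                    }
                }
            }
        }
        return received.toString();
    }

    /** Writes one command into a pipe and reads it back through the selector. */
    static String demo() {
        try {
            Pipe pipe = Pipe.open();
            try (Selector selector = Selector.open();
                 Pipe.SinkChannel sink = pipe.sink();
                 Pipe.SourceChannel source = pipe.source()) {
                source.configureBlocking(false); // required before register()
                source.register(selector, SelectionKey.OP_READ);
                sink.write(ByteBuffer.wrap("version\n".getBytes(StandardCharsets.UTF_8)));
                return selectOnce(selector, ByteBuffer.allocate(4096));
            }
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }
}
```

GaServer wraps the same pass in an outer loop that keeps running while the server is bound, and it shares one ByteBuffer across all reads, as the sketch does.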

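After GaServer receives a message, it hands it to DefaultCommandHandler, which takes care of input validation, session maintenance, command execution, and so on: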

package com.github.ompc.greys.core.server;

import com.github.ompc.greys.core.advisor.AdviceListener;
import com.github.ompc.greys.core.advisor.AdviceWeaver;
import com.github.ompc.greys.core.advisor.Enhancer;
import com.github.ompc.greys.core.advisor.InvokeTraceable;
import com.github.ompc.greys.core.command.Command;
import com.github.ompc.greys.core.command.Command.Action;
import com.github.ompc.greys.core.command.Command.GetEnhancerAction;
import com.github.ompc.greys.core.command.Command.Printer;
import com.github.ompc.greys.core.command.Commands;
import com.github.ompc.greys.core.command.QuitCommand;
import com.github.ompc.greys.core.command.ShutdownCommand;
import com.github.ompc.greys.core.exception.CommandException;
import com.github.ompc.greys.core.exception.CommandInitializationException;
import com.github.ompc.greys.core.exception.CommandNotFoundException;
import com.github.ompc.greys.core.exception.GaExecuteException;
import com.github.ompc.greys.core.util.GaStringUtils;
import com.github.ompc.greys.core.util.LogUtil;
import com.github.ompc.greys.core.util.affect.Affect;
import com.github.ompc.greys.core.util.affect.EnhancerAffect;
import com.github.ompc.greys.core.util.affect.RowAffect;
import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;

import java.io.IOException;
import java.lang.instrument.Instrumentation;
import java.nio.ByteBuffer;
import java.nio.channels.ClosedChannelException;
import java.nio.channels.SocketChannel;
import java.nio.charset.Charset;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;

import static com.github.ompc.greys.core.util.GaCheckUtils.$;
import static com.github.ompc.greys.core.util.GaCheckUtils.$$;
import static com.github.ompc.greys.core.util.GaStringUtils.ABORT_MSG;
import static com.github.ompc.greys.core.util.GaStringUtils.getCauseMessage;
import static java.lang.String.format;
import static java.nio.ByteBuffer.wrap;
import static org.apache.commons.lang3.StringUtils.isBlank;

/**
 * Command handler
 * Created by oldmanpushcart@gmail.com on 15/5/2.
 */
public class DefaultCommandHandler implements CommandHandler {

    private final Logger logger = LogUtil.getLogger();

    private final GaServer gaServer;
    private final Instrumentation inst;

    public DefaultCommandHandler(GaServer gaServer, Instrumentation inst) {
        this.gaServer = gaServer;
        this.inst = inst;
    }

    @Override
    public void executeCommand(final String line, final Session session) throws IOException {

        final SocketChannel socketChannel = session.getSocketChannel();

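        // Only parse the command when the input contains real characters;
        // otherwise just redraw the prompt
        if (isBlank(line)) {

            // The prompt ends up being drawn twice because this path is hard to control:
            // once here, and once below (the redraw after a command finishes).
            // As a workaround, a \r is inserted between the two prompts so that both are drawn
            // at the same position and nobody notices that the prompt was drawn twice.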
            logger.debug("reDrawPrompt for blank line.");
            reDrawPrompt(session, socketChannel, session.getCharset(), session.prompt());
            return;
        }

        // don't ask why
        if ($(line)) {
            write(socketChannel, wrap($$()));
            reDrawPrompt(session, socketChannel, session.getCharset(), session.prompt());
            return;
        }

        try {
            final Command command = Commands.getInstance().newCommand(line);
            execute(session, command);

            // quit command: close the session
            if (command instanceof QuitCommand) {
                session.destroy();
            }

            // shutdown command: shut down the whole server
            else if (command instanceof ShutdownCommand) {
                DefaultCommandHandler.this.gaServer.destroy();
            }

            // any other command: redraw the prompt
            else {
                logger.debug("reDrawPrompt for command execute finished.");
                reDrawPrompt(session, socketChannel, session.getCharset(), session.prompt());
            }

        }

        // command preparation failed (argument validation, etc.)
        catch (CommandException t) {

            final String message;
            if (t instanceof CommandNotFoundException) {
                message = format("Command \"%s\" not found.", t.getCommand());
            } else if (t instanceof CommandInitializationException) {
                message = format("Command \"%s\" failed to initiate.", t.getCommand());
            } else {
                message = format("Command \"%s\" preprocessor failed : %s.", t.getCommand(), getCauseMessage(t));
            }

            write(socketChannel, message + "\n", session.getCharset());
            reDrawPrompt(session, socketChannel, session.getCharset(), session.prompt());

            logger.info(message, t);

        }

        // command execution failed
        catch (GaExecuteException e) {
            logger.warn("command execute failed.", e);

            final String cause = GaStringUtils.getCauseMessage(e);
            if (StringUtils.isNotBlank(cause)) {
                write(socketChannel, format("Command execution failed. cause : %s\n", cause), session.getCharset());
            } else {
                write(socketChannel, "Command execution failed.\n", session.getCharset());
            }

            reDrawPrompt(session, socketChannel, session.getCharset(), session.prompt());
        }

    }


    /*
     * Execute the command
     */
    private void execute(final Session session, final Command command) throws GaExecuteException, IOException {

        // Flag reference: has the command finished producing output?
        final AtomicBoolean isFinishRef = new AtomicBoolean(false);

        // Message sender
        final Printer printer = new Printer() {

            @Override
            public Printer print(boolean isF, String message) {

                if(isFinishRef.get()) {
                    return this;
                }

                final BlockingQueue<String> writeQueue = session.getWriteQueue();
                if (null != message) {
                    if (!writeQueue.offer(message)) {
                        logger.warn("offer message failed. write-queue.size() was {}", writeQueue.size());
                    }
                }

                if (isF) {
                    finish();
                }

                return this;

            }

            @Override
            public Printer println(boolean isF, String message) {
                return print(isF, message + "\n");
            }

            @Override
            public Printer print(String message) {
                return print(false, message);
            }

            @Override
            public Printer println(String message) {
                return println(false, message);
            }

            @Override
            public void finish() {
                isFinishRef.set(true);
            }

        };


        try {

            // Affect feedback
            final Affect affect;
            final Action action = command.getAction();

            // Action with no follow-up
            if (action instanceof Command.SilentAction) {
                affect = new Affect();
                ((Command.SilentAction) action).action(session, inst, printer);
            }

            // Action that reports the number of affected rows
            else if (action instanceof Command.RowAction) {
                affect = new RowAffect();
                final RowAffect rowAffect = ((Command.RowAction) action).action(session, inst, printer);
                ((RowAffect) affect).rCnt(rowAffect.rCnt());
            }

            // Action that requires class enhancement
            else if (action instanceof GetEnhancerAction) {

                affect = new EnhancerAffect();

                // Run the command action and obtain the enhancer
                final Command.GetEnhancer getEnhancer = ((GetEnhancerAction) action).action(session, inst, printer);
                final int lock = session.getLock();
                final AdviceListener listener = getEnhancer.getAdviceListener();
                final EnhancerAffect enhancerAffect = Enhancer.enhance(
                        inst,
                        lock,
                        listener instanceof InvokeTraceable,
                        getEnhancer.getPointCut()
                );

                // Compensation: if unLock was called while enhance was running, give up
                if (session.getLock() == lock) {
                    // Register the advice listener
                    AdviceWeaver.reg(lock, listener);

                    if (!session.isSilent()) {
                        printer.println(ABORT_MSG);
                    }

                    ((EnhancerAffect) affect).cCnt(enhancerAffect.cCnt());
                    ((EnhancerAffect) affect).mCnt(enhancerAffect.mCnt());
                    ((EnhancerAffect) affect).getClassDumpFiles().addAll(enhancerAffect.getClassDumpFiles());
                }
            }

            // Other custom actions
            else {
                // do nothing...
                affect = new Affect();
            }

            if (!session.isSilent()) {
                // Print the command's execution summary
                printer.print(false, affect.toString() + "\n");
            }

        }

        // Command execution errors must be recorded
        catch (Throwable t) {
            throw new GaExecuteException(format("execute failed. sessionId=%s", session.getSessionId()), t);
        }

        // Run the job
        jobRunning(session, isFinishRef);

    }

    private void jobRunning(Session session, AtomicBoolean isFinishRef) throws IOException, GaExecuteException {

        final Thread currentThread = Thread.currentThread();
        final BlockingQueue<String> writeQueue = session.getWriteQueue();
        try {

            while (!session.isDestroy()
                    && !currentThread.isInterrupted()
                    && session.isLocked()) {

                // touch the session
                session.touch();

                try {
                    final String segment = writeQueue.poll(200, TimeUnit.MILLISECONDS);

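                    // A null segment means there is currently nothing to write
                    if (null == segment) {

                        // If the queue is drained (EOF) and the sender is also marked as finished,
                        // the whole command is done: mark the session as not writable and leave the command loop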
                        if (isFinishRef.get()) {
                            session.unLock();
                            break;
                        }

                    }

                    // Got something to write
                    else {
                        write(session.getSocketChannel(), segment, session.getCharset());
                    }

                } catch (InterruptedException e) {
                    currentThread.interrupt();
                }

            }//while command running

        }

        // A closed connection can be ignored
        catch (ClosedChannelException e) {
            logger.debug("session[{}] write failed, because socket broken.",
                    session.getSessionId(), e);
        }

    }

    /*
     * Draw the prompt
     */
    private void reDrawPrompt(Session session, SocketChannel socketChannel, Charset charset, String prompt) throws IOException {
        if (!session.isSilent()) {
            write(socketChannel, prompt, charset);
        }
    }

    /*
     * Write to the network
     */
    private void write(SocketChannel socketChannel, String message, Charset charset) throws IOException {
        write(socketChannel, charset.encode(message));
    }

    private void write(SocketChannel socketChannel, ByteBuffer buffer) throws IOException {
        while (buffer.hasRemaining()
                && socketChannel.isConnected()) {
            if (-1 == socketChannel.write(buffer)) {
                // socket broken
                throw new IOException("write EOF");
            }
        }
    }

}
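
The Printer and jobRunning together form a small producer/consumer protocol: the command offers output segments to the session's write queue, and the handler thread polls the queue with a timeout, stopping only when the queue looks empty and the finish flag is set. A minimal sketch of that drain loop, with the socket replaced by a StringBuilder, might look like this. WriteQueueSketch and its methods are illustrative names, not greys classes.

```java
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;

final class WriteQueueSketch {

    /** Drains the queue until it is empty and the producer has signalled completion. */
    static String drain(BlockingQueue<String> queue, AtomicBoolean finished) {
        StringBuilder out = new StringBuilder();
        while (!Thread.currentThread().isInterrupted()) {
            try {
                String segment = queue.poll(200, TimeUnit.MILLISECONDS);
                if (segment == null) {
                    // Nothing to write right now; stop only if the producer is done
                    if (finished.get()) {
                        break;
                    }
                } else {
                    out.append(segment);
                }
            } catch (InterruptedException e) {
                // Restore the flag so the loop condition ends the loop
                Thread.currentThread().interrupt();
            }
        }
        return out.toString();
    }

    /** Single-threaded demo: produce two segments, mark finished, then drain. */
    static String demo() {
        BlockingQueue<String> queue = new LinkedBlockingQueue<>();
        AtomicBoolean finished = new AtomicBoolean(false);
        queue.offer("line 1\n");
        queue.offer("line 2\n");
        finished.set(true);
        return drain(queue, finished);
    }
}
```

Because the finish flag is checked only after a poll has timed out, segments already queued when the command finishes are still written before the loop exits.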

Opening the console (starting the client)

The greys.sh script shows that the console's entry class is com.github.ompc.greys.core.GreysConsole:

# active console
# $1 : greys_local_version
active_console()
{

    local greys_lib_dir=${GREYS_LIB_DIR}/${1}/greys

    if type ${JAVA_HOME}/bin/java 2>&1 >> /dev/null; then

        # use default console
        ${JAVA_HOME}/bin/java \
            -cp ${greys_lib_dir}/greys-core.jar \
            com.github.ompc.greys.core.GreysConsole \
                ${TARGET_IP} \
                ${TARGET_PORT}

    elif type telnet 2>&1 >> /dev/null; then

        # use telnet
        telnet ${TARGET_IP} ${TARGET_PORT}

    elif type nc 2>&1 >> /dev/null; then

        # use netcat
        nc ${TARGET_IP} ${TARGET_PORT}

    else

        echo "'telnet' or 'nc' is required." 1>&2
        return 1

    fi
}

The com.github.ompc.greys.core.GreysConsole constructor starts a reader thread that picks up the commands the user types in the console:

    public GreysConsole(InetSocketAddress address) throws IOException {

        this.console = initConsoleReader();
        this.history = initHistory();

        this.out = console.getOutput();
        this.history.moveToEnd();
        this.console.setHistoryEnabled(true);
        this.console.setHistory(history);
        this.console.setExpandEvents(false);
        this.socket = connect(address);

        // Turn off silent mode for the session
        disableSilentOfSession();

        // Initialize auto-completion
        initCompleter();

        this.isRunning = true;
        activeConsoleReader();


        socketWriter.write("version\n");
        socketWriter.flush();

        loopForWriter();

    }

The reader thread sends each command to GaServer, and GaServer's response is then echoed on the console:

    /**
     * Start the reader thread
     */
    private void activeConsoleReader() {
        final Thread socketThread = new Thread("ga-console-reader-daemon") {

            private StringBuilder lineBuffer = new StringBuilder();

            @Override
            public void run() {
                try {

                    while (isRunning) {

                        final String line = console.readLine();

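                        // A trailing backslash means the command continues on the next line, so the line break needs special handling
                        if (StringUtils.endsWith(line, "\\")) {
                            // Drop the trailing backslash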
                            lineBuffer.append(line.substring(0, line.length() - 1));
                            continue;
                        } else {
                            lineBuffer.append(line);
                        }

                        final String lineForWrite = lineBuffer.toString();
                        lineBuffer = new StringBuilder();

                        // replace ! to \!
                        // history.add(StringUtils.replace(lineForWrite, "!", "\\!"));

                        // flush if need
                        if (history instanceof Flushable) {
                            ((Flushable) history).flush();
                        }

                        console.setPrompt(EMPTY);
                        if (isNotBlank(lineForWrite)) {
                            socketWriter.write(lineForWrite + "\n");
                        } else {
                            socketWriter.write("\n");
                        }
                        socketWriter.flush();

                    }
                } catch (IOException e) {
                    err("read fail : %s", e.getMessage());
                    shutdown();
                }

            }

        };
        socketThread.setDaemon(true);
        socketThread.start();
    }
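
The continuation handling in the reader thread can be isolated into a pure function: lines ending in a backslash are buffered without the backslash, and the first line that does not end in one completes the command. The sketch below assumes non-null input lines; LineContinuationSketch is an illustrative name.

```java
import java.util.ArrayList;
import java.util.List;

final class LineContinuationSketch {

    /** Joins lines ending in '\' with the lines that follow and returns the completed commands. */
    static List<String> assemble(List<String> inputLines) {
        List<String> commands = new ArrayList<>();
        StringBuilder buffer = new StringBuilder();
        for (String line : inputLines) {
            if (line.endsWith("\\")) {
                // Keep everything except the trailing backslash and wait for more input
                buffer.append(line, 0, line.length() - 1);
                continue;
            }
            buffer.append(line);
            commands.add(buffer.toString());
            buffer.setLength(0); // start a fresh command
        }
        return commands;
    }
}
```

For the input lines "foo \", "bar" and "version", assemble returns the two commands "foo bar" and "version". A trailing continuation line with nothing after it stays in the buffer and is not emitted, which matches the quoted loop, where such input is only sent once a following line arrives.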

That is the overall interaction flow.

2.3 Core Enhancement Code

Next comes the core enhancement code.

The enhancement branch of DefaultCommandHandler's execute method:

            // Action that requires class enhancement
            else if (action instanceof GetEnhancerAction) {

                affect = new EnhancerAffect();

                // Run the command action and obtain the enhancer
                final Command.GetEnhancer getEnhancer = ((GetEnhancerAction) action).action(session, inst, printer);
                final int lock = session.getLock();
                final AdviceListener listener = getEnhancer.getAdviceListener();
                final EnhancerAffect enhancerAffect = Enhancer.enhance(
                        inst,
                        lock,
                        listener instanceof InvokeTraceable,
                        getEnhancer.getPointCut()
                );

                // Compensation: if unLock was called while enhance was running, give up
                if (session.getLock() == lock) {
                    // Register the advice listener
                    AdviceWeaver.reg(lock, listener);

                    if (!session.isSilent()) {
                        printer.println(ABORT_MSG);
                    }

                    ((EnhancerAffect) affect).cCnt(enhancerAffect.cCnt());
                    ((EnhancerAffect) affect).mCnt(enhancerAffect.mCnt());
                    ((EnhancerAffect) affect).getClassDumpFiles().addAll(enhancerAffect.getClassDumpFiles());
                }
            }
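
The compensation check works because the lock value doubles as a token: it is read before the slow enhance call and compared afterwards, so an unLock in between is detected and the listener is never registered. The sketch below illustrates the idea with an AtomicInteger. How Session actually implements tryLock, unLock, and getLock is not shown in this article, so the sentinel value -1 and the sequence counter are assumptions, and LockTokenSketch is an illustrative name.

```java
import java.util.concurrent.atomic.AtomicInteger;

final class LockTokenSketch {

    static final int UNLOCKED = -1; // assumed sentinel for "no lock held"

    private final AtomicInteger token = new AtomicInteger(UNLOCKED);
    private final AtomicInteger sequence = new AtomicInteger(0);

    /** Takes the lock with a fresh token; fails if the lock is already held. */
    boolean tryLock() {
        return token.compareAndSet(UNLOCKED, sequence.getAndIncrement());
    }

    void unLock() {
        token.set(UNLOCKED);
    }

    int getLock() {
        return token.get();
    }

    /** Runs slowWork, then runs register only if the lock token is unchanged. */
    boolean runThenRegister(Runnable slowWork, Runnable register) {
        int before = getLock();
        slowWork.run();
        if (getLock() != before) {
            return false; // unLock (or a re-lock) happened during slowWork: give up
        }
        register.run();
        return true;
    }
}
```

Comparing tokens rather than a plain boolean also catches the case where the session was unlocked and locked again by a newer command while the enhancement was still running.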

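The Enhancer.enhance method:

    /**
     * Enhance classes
     *
     * @param inst      inst
     * @param adviceId  advice ID
     * @param isTracing whether method invocations can be traced
     * @param pointCut  the pointcut to enhance
     * @return the scope affected by the enhancement
     * @throws UnmodifiableClassException if enhancement fails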
     */
    public static synchronized EnhancerAffect enhance(
            final Instrumentation inst,
            final int adviceId,
            final boolean isTracing,
            final PointCut pointCut) throws UnmodifiableClassException {

        final EnhancerAffect affect = new EnhancerAffect();


        final Map<Class<?>, Matcher<AsmMethod>> enhanceMap = toEnhanceMap(pointCut);

        // Build the enhancer
        final Enhancer enhancer = new Enhancer(adviceId, isTracing, enhanceMap, affect);
        try {
            inst.addTransformer(enhancer, true);

            // Batch enhancement
            if (GlobalOptions.isBatchReTransform) {
                final int size = enhanceMap.size();
                final Class<?>[] classArray = new Class<?>[size];
                arraycopy(enhanceMap.keySet().toArray(), 0, classArray, 0, size);
                if (classArray.length > 0) {
                    inst.retransformClasses(classArray);
                }
            }


            // Enhance class by class
            else {
                for (Class<?> clazz : enhanceMap.keySet()) {
                    try {
                        inst.retransformClasses(clazz);
                    } catch (Throwable t) {
                        logger.warn("reTransform {} failed.", clazz, t);
                        if (t instanceof UnmodifiableClassException) {
                            throw (UnmodifiableClassException) t;
                        } else if (t instanceof RuntimeException) {
                            throw (RuntimeException) t;
                        } else {
                            throw new RuntimeException(t);
                        }
                    }
                }
            }


        } finally {
            inst.removeTransformer(enhancer);
        }

        return affect;
    }
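
Stripped of the greys-specific details, enhance follows a fixed lifecycle: register a retransform-capable transformer, trigger retransformClasses, and remove the transformer in a finally block so it never stays installed. The sketch below shows that lifecycle. A real Instrumentation instance is only available inside an agent, so the demo substitutes a java.lang.reflect.Proxy that merely records the calls; RetransformLifecycleSketch and its methods are illustrative names.

```java
import java.lang.instrument.ClassFileTransformer;
import java.lang.instrument.Instrumentation;
import java.lang.instrument.UnmodifiableClassException;
import java.lang.reflect.Proxy;
import java.security.ProtectionDomain;
import java.util.ArrayList;
import java.util.List;

final class RetransformLifecycleSketch {

    /** Registers the transformer only for the duration of one retransform call. */
    static void retransformWith(Instrumentation inst, ClassFileTransformer transformer,
                                Class<?>... targets) throws UnmodifiableClassException {
        inst.addTransformer(transformer, true); // true = usable for retransformation
        try {
            if (targets.length > 0) {
                inst.retransformClasses(targets);
            }
        } finally {
            inst.removeTransformer(transformer);
        }
    }

    /** Runs retransformWith against a recording proxy and returns the method names called. */
    static List<String> recordCalls() {
        final List<String> calls = new ArrayList<>();
        Instrumentation recording = (Instrumentation) Proxy.newProxyInstance(
                RetransformLifecycleSketch.class.getClassLoader(),
                new Class<?>[] { Instrumentation.class },
                (proxy, method, args) -> {
                    calls.add(method.getName());
                    // removeTransformer returns a primitive boolean, so null is not allowed
                    return method.getReturnType() == boolean.class ? Boolean.TRUE : null;
                });
        ClassFileTransformer noOp = new ClassFileTransformer() {
            @Override
            public byte[] transform(ClassLoader loader, String className,
                                    Class<?> classBeingRedefined,
                                    ProtectionDomain protectionDomain,
                                    byte[] classfileBuffer) {
                return null; // null tells the JVM to leave the class unchanged
            }
        };
        try {
            retransformWith(recording, noOp, String.class);
        } catch (UnmodifiableClassException e) {
            throw new IllegalStateException(e);
        }
        return calls;
    }
}
```

Removing the transformer right after the retransform call means it only ever runs for the classes passed to retransformClasses, not for classes that happen to be loaded later.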

Enhancer implements the transform method of the ClassFileTransformer interface:

    @Override
    public byte[] transform(
            final ClassLoader inClassLoader,
            final String className,
            final Class<?> classBeingRedefined,
            final ProtectionDomain protectionDomain,
            final byte[] classfileBuffer) throws IllegalClassFormatException {

        // Skip classes that are not in the enhancement set
        if (!enhanceMap.containsKey(classBeingRedefined)) {
            return null;
        }

        final ClassReader cr;

        // First check whether the class bytecode is already in the cache,
        // because several users may be enhancing the same class at the same time
        final byte[] byteOfClassInCache = classBytesCache.get(classBeingRedefined);
        if (null != byteOfClassInCache) {
            cr = new ClassReader(byteOfClassInCache);
        }

        // On a cache miss, start enhancing from the original bytecode
        else {
            cr = new ClassReader(classfileBuffer);
        }

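        // Get the ASM method matcher for this class
        final Matcher<AsmMethod> asmMethodMatcher = enhanceMap.get(classBeingRedefined);

        // Bytecode enhancement
        final ClassWriter cw = new ClassWriter(cr, COMPUTE_FRAMES | COMPUTE_MAXS) {

            /*
             * Note: to compute stack map frames automatically, ASM sometimes has to find the common superclass of two classes.
             * By default, ClassWriter does this in getCommonSuperClass by loading both classes into the JVM and using the reflection API.
             * That becomes a problem when the classes being generated reference each other, because a referenced class may not exist yet.
             * In that case, getCommonSuperClass can be overridden to work around the problem.
             *
             * This override changes how the ClassLoader is obtained: the classes are resolved through the given ClassLoader,
             * which avoids the original code's use of Object.class.getClassLoader().
             */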
            @Override
            protected String getCommonSuperClass(String type1, String type2) {
                Class<?> c, d;
                try {
                    c = Class.forName(type1.replace('/', '.'), false, inClassLoader);
                    d = Class.forName(type2.replace('/', '.'), false, inClassLoader);
                } catch (Exception e) {
                    throw new RuntimeException(e);
                }
                if (c.isAssignableFrom(d)) {
                    return type1;
                }
                if (d.isAssignableFrom(c)) {
                    return type2;
                }
                if (c.isInterface() || d.isInterface()) {
                    return "java/lang/Object";
                } else {
                    do {
                        c = c.getSuperclass();
                    } while (!c.isAssignableFrom(d));
                    return c.getName().replace('.', '/');
                }
            }

        };

        try {

            // Generate the enhanced bytecode
            cr.accept(new AdviceWeaver(adviceId, isTracing, cr.getClassName(), asmMethodMatcher, affect, cw), EXPAND_FRAMES);
            final byte[] enhanceClassByteArray = cw.toByteArray();

            // Generation succeeded; put the result into the cache
            classBytesCache.put(classBeingRedefined, enhanceClassByteArray);

            // dump the class
            dumpClassIfNecessary(className, enhanceClassByteArray, affect);

            // Count the success
            affect.cCnt(1);

            // Dispatch the spy
            try {
                spy(inClassLoader);
            } catch (Throwable t) {
                logger.warn("print spy failed. classname={};loader={};", className, inClassLoader, t);
                throw t;
            }

            return enhanceClassByteArray;
        } catch (Throwable t) {
            logger.warn("transform loader[{}]:class[{}] failed.", inClassLoader, className, t);
        }

        return null;
    }
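
The overridden getCommonSuperClass is easier to follow outside the anonymous ClassWriter. The sketch below is the same walk as a static method with the class loader passed in explicitly; CommonSuperClassSketch is an illustrative name, and the only behavioral difference is that a missing class is reported as IllegalArgumentException.

```java
final class CommonSuperClassSketch {

    /** Returns the internal name of the closest common superclass of two types. */
    static String commonSuperClass(String type1, String type2, ClassLoader loader) {
        Class<?> c;
        Class<?> d;
        try {
            // Internal names use '/', Class.forName expects '.'; do not initialize the classes
            c = Class.forName(type1.replace('/', '.'), false, loader);
            d = Class.forName(type2.replace('/', '.'), false, loader);
        } catch (ClassNotFoundException e) {
            throw new IllegalArgumentException(e);
        }
        if (c.isAssignableFrom(d)) {
            return type1;
        }
        if (d.isAssignableFrom(c)) {
            return type2;
        }
        if (c.isInterface() || d.isInterface()) {
            return "java/lang/Object";
        }
        // Walk up c's superclass chain until a class that d also extends is found
        do {
            c = c.getSuperclass();
        } while (!c.isAssignableFrom(d));
        return c.getName().replace('.', '/');
    }
}
```

With the system class loader, java/util/ArrayList and java/util/LinkedList resolve to java/util/AbstractList, and an unrelated interface and class (for example java/util/List and java/lang/String) fall back to java/lang/Object.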

            // Generate the enhanced bytecode
            cr.accept(new AdviceWeaver(adviceId, isTracing, cr.getClassName(), asmMethodMatcher, affect, cw), EXPAND_FRAMES);

The AdviceWeaver used in this call extends ClassVisitor and overrides the visitMethod method:

    @Override
    public MethodVisitor visitMethod(
            final int access,
            final String name,
            final String desc,
            final String signature,
            final String[] exceptions) {

        final MethodVisitor mv = super.visitMethod(access, name, desc, signature, exceptions);

        if (isIgnore(mv, access, name, desc)) {
            return mv;
        }

        // Count the woven method
        affect.mCnt(1);

        return new AdviceAdapter(ASM5, new JSRInlinerAdapter(mv, access, name, desc, signature, exceptions), access, name, desc) {

            // -- Labels for the try...catch block
            private final Label beginLabel = new Label();
            private final Label endLabel = new Label();

            // -- KEY of advice --
            private final int KEY_GREYS_ADVICE_BEFORE_METHOD = 0;
            private final int KEY_GREYS_ADVICE_RETURN_METHOD = 1;
            private final int KEY_GREYS_ADVICE_THROWS_METHOD = 2;
            private final int KEY_GREYS_ADVICE_BEFORE_INVOKING_METHOD = 3;
            private final int KEY_GREYS_ADVICE_AFTER_INVOKING_METHOD = 4;
            private final int KEY_GREYS_ADVICE_THROW_INVOKING_METHOD = 5;


            // -- KEY of ASM_TYPE or ASM_METHOD --
            private final Type ASM_TYPE_SPY = Type.getType("Lcom/github/ompc/greys/agent/Spy;");
            private final Type ASM_TYPE_OBJECT = Type.getType(Object.class);
            private final Type ASM_TYPE_OBJECT_ARRAY = Type.getType(Object[].class);
            private final Type ASM_TYPE_CLASS = Type.getType(Class.class);
            private final Type ASM_TYPE_INTEGER = Type.getType(Integer.class);
            private final Type ASM_TYPE_CLASS_LOADER = Type.getType(ClassLoader.class);
            private final Type ASM_TYPE_STRING = Type.getType(String.class);
            private final Type ASM_TYPE_THROWABLE = Type.getType(Throwable.class);
            private final Type ASM_TYPE_INT = Type.getType(int.class);
            private final Type ASM_TYPE_METHOD = Type.getType(java.lang.reflect.Method.class);
            private final Method ASM_METHOD_METHOD_INVOKE = Method.getMethod("Object invoke(Object,Object[])");

            // code lock
            private final CodeLock codeLockForTracing = new TracingAsmCodeLock(this);


            private void _debug(final StringBuilder append, final String msg) {

                if (!isDebugForAsm) {
                    return;
                }

                // println msg
                visitFieldInsn(GETSTATIC, "java/lang/System", "out", "Ljava/io/PrintStream;");
                if (StringUtils.isBlank(append.toString())) {
                    visitLdcInsn(append.append(msg).toString());
                } else {
                    visitLdcInsn(append.append(" >> ").append(msg).toString());
                }

                visitMethodInsn(INVOKEVIRTUAL, "java/io/PrintStream", "println", "(Ljava/lang/String;)V", false);
            }

            /**
             * Load the advice method
             * @param keyOfMethod KEY of the advice method
             */
            private void loadAdviceMethod(int keyOfMethod) {

                switch (keyOfMethod) {

                    case KEY_GREYS_ADVICE_BEFORE_METHOD: {
                        getStatic(ASM_TYPE_SPY, "ON_BEFORE_METHOD", ASM_TYPE_METHOD);
                        break;
                    }

                    case KEY_GREYS_ADVICE_RETURN_METHOD: {
                        getStatic(ASM_TYPE_SPY, "ON_RETURN_METHOD", ASM_TYPE_METHOD);
                        break;
                    }

                    case KEY_GREYS_ADVICE_THROWS_METHOD: {
                        getStatic(ASM_TYPE_SPY, "ON_THROWS_METHOD", ASM_TYPE_METHOD);
                        break;
                    }

                    case KEY_GREYS_ADVICE_BEFORE_INVOKING_METHOD: {
                        getStatic(ASM_TYPE_SPY, "BEFORE_INVOKING_METHOD", ASM_TYPE_METHOD);
                        break;
                    }

                    case KEY_GREYS_ADVICE_AFTER_INVOKING_METHOD: {
                        getStatic(ASM_TYPE_SPY, "AFTER_INVOKING_METHOD", ASM_TYPE_METHOD);
                        break;
                    }

                    case KEY_GREYS_ADVICE_THROW_INVOKING_METHOD: {
                        getStatic(ASM_TYPE_SPY, "THROW_INVOKING_METHOD", ASM_TYPE_METHOD);
                        break;
                    }

                    default: {
                        throw new IllegalArgumentException("illegal keyOfMethod=" + keyOfMethod);
                    }

                }

            }

            /**
             * Load the ClassLoader<br/>
             * Obtaining the ClassLoader is handled separately for static methods and for instance methods,
             * mainly for performance reasons
             */
            private void loadClassLoader() {

                if (this.isStaticMethod()) {

//                    // fast enhance
//                    if (GlobalOptions.isEnableFastEnhance) {
//                        visitLdcInsn(Type.getType(String.format("L%s;", internalClassName)));
//                        visitMethodInsn(INVOKEVIRTUAL, "java/lang/Class", "getClassLoader", "()Ljava/lang/ClassLoader;", false);
//                    }

                    // normal enhance
//                    else {

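                    // We have to use the very slow Class.forName() to obtain the class here, because when this
                    // static method runs the current class may not have finished initializing, which would make
                    // the JVM's class file verification fail.
                    // I may look into optimizing this in the future, but for now functionality matters far more
                    // than performance, so I am not going to make this any more complicated.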
                    visitLdcInsn(javaClassName);
                    invokeStatic(ASM_TYPE_CLASS, Method.getMethod("Class forName(String)"));
                    invokeVirtual(ASM_TYPE_CLASS, Method.getMethod("ClassLoader getClassLoader()"));
//                    }

                } else {
                    loadThis();
                    invokeVirtual(ASM_TYPE_OBJECT, Method.getMethod("Class getClass()"));
                    invokeVirtual(ASM_TYPE_CLASS, Method.getMethod("ClassLoader getClassLoader()"));
                }

            }

            /**
             * Load the argument array for the before advice
             */
            private void loadArrayForBefore() {
                push(7);
                newArray(ASM_TYPE_OBJECT);

                dup();
                push(0);
                push(adviceId);
                box(ASM_TYPE_INT);
                arrayStore(ASM_TYPE_INTEGER);

                dup();
                push(1);
                loadClassLoader();
                arrayStore(ASM_TYPE_CLASS_LOADER);

                dup();
                push(2);
                push(tranClassName(javaClassName));
                arrayStore(ASM_TYPE_STRING);

                dup();
                push(3);
                push(name);
                arrayStore(ASM_TYPE_STRING);

                dup();
                push(4);
                push(desc);
                arrayStore(ASM_TYPE_STRING);

                dup();
                push(5);
                loadThisOrPushNullIfIsStatic();
                arrayStore(ASM_TYPE_OBJECT);

                dup();
                push(6);
                loadArgArray();
                arrayStore(ASM_TYPE_OBJECT_ARRAY);
            }


            @Override
            protected void onMethodEnter() {

                codeLockForTracing.lock(new CodeLock.Block() {
                    @Override
                    public void code() {

                        final StringBuilder append = new StringBuilder();
                        _debug(append, "debug:onMethodEnter()");

                        // load the before method
                        loadAdviceMethod(KEY_GREYS_ADVICE_BEFORE_METHOD);
                        _debug(append, "loadAdviceMethod()");

                        // push the first argument of Method.invoke()
                        pushNull();

                        // method arguments
                        loadArrayForBefore();
                        _debug(append, "loadArrayForBefore()");

                        // invoke the method
                        invokeVirtual(ASM_TYPE_METHOD, ASM_METHOD_METHOD_INVOKE);
                        pop();
                        _debug(append, "invokeVirtual()");

                    }
                });

                mark(beginLabel);

            }


            /*
             * Load the argument array for the return advice
             */
            private void loadReturnArgs() {
                dup2X1();
                pop2();
                push(2);
                newArray(ASM_TYPE_OBJECT);
                dup();
                dup2X1();
                pop2();
                push(0);
                swap();
                arrayStore(ASM_TYPE_OBJECT);

                dup();
                push(1);
                push(adviceId);
                box(ASM_TYPE_INT);
                arrayStore(ASM_TYPE_INTEGER);
            }

            @Override
            protected void onMethodExit(final int opcode) {

                if (!isThrow(opcode)) {
                    codeLockForTracing.lock(new CodeLock.Block() {
                        @Override
                        public void code() {

                            final StringBuilder append = new StringBuilder();
                            _debug(append, "debug:onMethodExit()");

                            // load the return value
                            loadReturn(opcode);
                            _debug(append, "loadReturn()");

                            // load the returning method
                            loadAdviceMethod(KEY_GREYS_ADVICE_RETURN_METHOD);
                            _debug(append, "loadAdviceMethod()");

                            // push the first argument of Method.invoke()
                            pushNull();

                            // load the argument array for the return advice
                            loadReturnArgs();
                            _debug(append, "loadReturnArgs()");

                            invokeVirtual(ASM_TYPE_METHOD, ASM_METHOD_METHOD_INVOKE);
                            pop();
                            _debug(append, "invokeVirtual()");

                        }
                    });
                }

            }


            /*
             * Build the argument array for the throwing advice
             */
            private void loadThrowArgs() {
                dup2X1();
                pop2();
                push(2);
                newArray(ASM_TYPE_OBJECT);
                dup();
                dup2X1();
                pop2();
                push(0);
                swap();
                arrayStore(ASM_TYPE_THROWABLE);

                dup();
                push(1);
                push(adviceId);
                box(ASM_TYPE_INT);
                arrayStore(ASM_TYPE_INTEGER);
            }

            @Override
            public void visitMaxs(int maxStack, int maxLocals) {

                mark(endLabel);
                visitTryCatchBlock(beginLabel, endLabel, mark(), ASM_TYPE_THROWABLE.getInternalName());
                // catchException(beginLabel, endLabel, ASM_TYPE_THROWABLE);

                codeLockForTracing.lock(new CodeLock.Block() {
                    @Override
                    public void code() {

                        final StringBuilder append = new StringBuilder();
                        _debug(append, "debug:catchException()");

                        // load the exception
                        loadThrow();
                        _debug(append, "loadAdviceMethod()");

                        // load the throwing method
                        loadAdviceMethod(KEY_GREYS_ADVICE_THROWS_METHOD);
                        _debug(append, "loadAdviceMethod()");

                        // push the first argument of Method.invoke()
                        pushNull();

                        // load the argument array for the throw advice
                        loadThrowArgs();
                        _debug(append, "loadThrowArgs()");

                        // invoke the method
                        invokeVirtual(ASM_TYPE_METHOD, ASM_METHOD_METHOD_INVOKE);
                        pop();
                        _debug(append, "invokeVirtual()");

                    }
                });

                throwException();

                super.visitMaxs(maxStack, maxLocals);
            }

            /**
             * Whether this is a static method
             * @return true: static method / false: non-static method
             */
            private boolean isStaticMethod() {
                return (methodAccess & ACC_STATIC) != 0;
            }

            /**
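             * Whether the method exits by throwing an exception (determined from the bytecode)
             * @param opcode the opcode
             * @return true: exits by throwing an exception / false: exits normally (return)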
             */
            private boolean isThrow(int opcode) {
                return opcode == ATHROW;
            }

            /**
             * Push NULL onto the stack
             */
            private void pushNull() {
                push((Type) null);
            }

            /**
             * Load this/null
             */
            private void loadThisOrPushNullIfIsStatic() {
                if (isStaticMethod()) {
                    pushNull();
                } else {
                    loadThis();
                }
            }

            /**
             * Load the return value
             * @param opcode the opcode
             */
            private void loadReturn(int opcode) {
                switch (opcode) {

                    case RETURN: {
                        pushNull();
                        break;
                    }

                    case ARETURN: {
                        dup();
                        break;
                    }

                    case LRETURN:
                    case DRETURN: {
                        dup2();
                        box(Type.getReturnType(methodDesc));
                        break;
                    }

                    default: {
                        dup();
                        box(Type.getReturnType(methodDesc));
                        break;
                    }

                }
            }

            /**
             * Load the exception
             */
            private void loadThrow() {
                dup();
            }


            /**
             * Load the argument array required by the method-invocation tracing advice (for before/after)
             */
            private void loadArrayForInvokeBeforeOrAfterTracing(String owner, String name, String desc) {
                push(5);
                newArray(ASM_TYPE_OBJECT);

                dup();
                push(0);
                push(adviceId);
                box(ASM_TYPE_INT);
                arrayStore(ASM_TYPE_INTEGER);

                if (null != currentLineNumber) {
                    dup();
                    push(1);
                    push(currentLineNumber);
                    box(ASM_TYPE_INT);
                    arrayStore(ASM_TYPE_INTEGER);
                }

                dup();
                push(2);
                push(owner);
                arrayStore(ASM_TYPE_STRING);

                dup();
                push(3);
                push(name);
                arrayStore(ASM_TYPE_STRING);

                dup();
                push(4);
                push(desc);
                arrayStore(ASM_TYPE_STRING);
            }

            /**
             * Load the argument array required by the method-invocation tracing advice (for throw)
             */
            private void loadArrayForInvokeThrowTracing(String owner, String name, String desc) {
                push(6);
                newArray(ASM_TYPE_OBJECT);

                dup();
                push(0);
                push(adviceId);
                box(ASM_TYPE_INT);
                arrayStore(ASM_TYPE_INTEGER);


                if (null != currentLineNumber) {
                    dup();
                    push(1);
                    push(currentLineNumber);
                    box(ASM_TYPE_INT);
                    arrayStore(ASM_TYPE_INTEGER);
                }

                dup();
                push(2);
                push(owner);
                arrayStore(ASM_TYPE_STRING);

                dup();
                push(3);
                push(name);
                arrayStore(ASM_TYPE_STRING);

                dup();
                push(4);
                push(desc);
                arrayStore(ASM_TYPE_STRING);

                dup2(); // e,a,e,a
                swap(); // e,a,a,e
                invokeVirtual(ASM_TYPE_OBJECT, Method.getMethod("Class getClass()"));
                invokeVirtual(ASM_TYPE_CLASS, Method.getMethod("String getName()"));

                // e,a,a,s
                push(5); // e,a,a,s,5
                swap();  // e,a,a,5,s
                arrayStore(ASM_TYPE_STRING);

                // e,a
            }


            @Override
            public void visitInsn(int opcode) {
                super.visitInsn(opcode);
                codeLockForTracing.code(opcode);
            }


            /*
             * Tracing code
             */
            private void tracing(final int tracingType, final String owner, final String name, final String desc) {

                final String label;
                switch (tracingType) {
                    case KEY_GREYS_ADVICE_BEFORE_INVOKING_METHOD: {
                        label = "beforeInvoking";
                        break;
                    }
                    case KEY_GREYS_ADVICE_AFTER_INVOKING_METHOD: {
                        label = "afterInvoking";
                        break;
                    }
                    case KEY_GREYS_ADVICE_THROW_INVOKING_METHOD: {
                        label = "throwInvoking";
                        break;
                    }
                    default: {
                        throw new IllegalStateException("illegal tracing type: " + tracingType);
                    }
                }

                codeLockForTracing.lock(new CodeLock.Block() {
                    @Override
                    public void code() {

                        final StringBuilder append = new StringBuilder();
                        _debug(append, "debug:" + label + "()");

                        if (tracingType == KEY_GREYS_ADVICE_THROW_INVOKING_METHOD) {
                            loadArrayForInvokeThrowTracing(owner, name, desc);
                        } else {
                            loadArrayForInvokeBeforeOrAfterTracing(owner, name, desc);
                        }
                        _debug(append, "loadArrayForInvokeTracing()");

                        loadAdviceMethod(tracingType);
                        swap();
                        _debug(append, "loadAdviceMethod()");

                        pushNull();
                        swap();

                        invokeVirtual(ASM_TYPE_METHOD, ASM_METHOD_METHOD_INVOKE);
                        pop();
                        _debug(append, "invokeVirtual()");

                    }
                });

            }

            private Integer currentLineNumber;

            @Override
            public void visitLineNumber(int line, Label start) {
                super.visitLineNumber(line, start);
                currentLineNumber = line;
            }

            @Override
            public void visitMethodInsn(final int opcode, final String owner, final String name, final String desc, final boolean itf) {

                if (!isTracing || codeLockForTracing.isLock()) {
                    super.visitMethodInsn(opcode, owner, name, desc, itf);
                    return;
                }

                // advice before the method invocation
                tracing(KEY_GREYS_ADVICE_BEFORE_INVOKING_METHOD, owner, name, desc);

                final Label beginLabel = new Label();
                final Label endLabel = new Label();
                final Label finallyLabel = new Label();

                // try
                // {

                mark(beginLabel);
                super.visitMethodInsn(opcode, owner, name, desc, itf);
                mark(endLabel);

                // advice after the method invocation
                tracing(KEY_GREYS_ADVICE_AFTER_INVOKING_METHOD, owner, name, desc);
                goTo(finallyLabel);

                // }
                // catch
                // {

                catchException(beginLabel, endLabel, ASM_TYPE_THROWABLE);
                tracing(KEY_GREYS_ADVICE_THROW_INVOKING_METHOD, owner, name, desc);

                throwException();

                // }
                // finally
                // {
                mark(finallyLabel);
                // }


            }

            // Used to reorder the try-catch blocks, so that the tracing try...catch comes first in the exception table
            private final Collection<AsmTryCatchBlock> asmTryCatchBlocks = new ArrayList<AsmTryCatchBlock>();

            @Override
            public void visitTryCatchBlock(Label start, Label end, Label handler, String type) {
                asmTryCatchBlocks.add(new AsmTryCatchBlock(start, end, handler, type));
            }

            @Override
            public void visitEnd() {
                for (AsmTryCatchBlock tcb : asmTryCatchBlocks) {
                    super.visitTryCatchBlock(tcb.start, tcb.end, tcb.handler, tcb.type);
                }
                super.visitEnd();
            }
        };

    }
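
Read as Java source, the instructions emitted by onMethodEnter, onMethodExit and visitMaxs give every woven method roughly the shape below: a reflective call to the before advice, then the original body wrapped in a catch-all that fires the throws advice and rethrows, with the return advice fired just before each normal return. This is only a hand-written approximation, not output from greys. All names in it (WovenShapeSketch, fire, onBefore, onReturn, onThrows) are hypothetical stand-ins for the static Method fields that the listing reads from Spy, and fire wraps Method.invoke so the sketch does not have to declare checked exceptions.

```java
import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.List;

// Sketch only: every name here is hypothetical; the real hooks live in the agent's Spy class.
class WovenShapeSketch {

    static final int ADVICE_ID = 1;
    static final List<String> EVENTS = new ArrayList<>();

    // Stand-ins for the static Method fields that the woven code reads.
    static final Method ON_BEFORE_METHOD = find("onBefore", int.class, ClassLoader.class,
            String.class, String.class, String.class, Object.class, Object[].class);
    static final Method ON_RETURN_METHOD = find("onReturn", Object.class, int.class);
    static final Method ON_THROWS_METHOD = find("onThrows", Throwable.class, int.class);

    static Method find(String name, Class<?>... types) {
        try {
            return WovenShapeSketch.class.getDeclaredMethod(name, types);
        } catch (NoSuchMethodException e) {
            throw new IllegalStateException(e);
        }
    }

    // Parameter lists mirror the arrays built by loadArrayForBefore, loadReturnArgs and loadThrowArgs.
    static void onBefore(int adviceId, ClassLoader loader, String className, String methodName,
                         String methodDesc, Object target, Object[] args) {
        EVENTS.add("before:" + methodName);
    }

    static void onReturn(Object returnValue, int adviceId) {
        EVENTS.add("return:" + returnValue);
    }

    static void onThrows(Throwable t, int adviceId) {
        EVENTS.add("throws:" + t.getClass().getSimpleName());
    }

    // Method.invoke(null, array): the receiver is null because the advice methods are static.
    static void fire(Method advice, Object[] args) {
        try {
            advice.invoke(null, args);
        } catch (ReflectiveOperationException e) {
            throw new IllegalStateException(e);
        }
    }

    // What a woven static method "int divide(int, int)" would roughly look like.
    static int divide(int a, int b) {
        fire(ON_BEFORE_METHOD, new Object[]{ADVICE_ID, WovenShapeSketch.class.getClassLoader(),
                "WovenShapeSketch", "divide", "(II)I", null, new Object[]{a, b}});
        try {
            int result = a / b;                                        // original method body
            fire(ON_RETURN_METHOD, new Object[]{result, ADVICE_ID});   // before each normal return
            return result;
        } catch (Throwable t) {
            fire(ON_THROWS_METHOD, new Object[]{t, ADVICE_ID});        // catch-all added in visitMaxs
            throw t;
        }
    }

    public static void main(String[] args) {
        divide(6, 3);
        try {
            divide(1, 0);
        } catch (ArithmeticException expected) {
            // rethrown by the catch-all after the throws advice has fired
        }
        System.out.println(EVENTS);
    }
}
```

Going through Method.invoke instead of a direct call means the woven class only needs to resolve the Spy class itself, which is presumably why the listing loads a Method object and an Object[] rather than emitting a typed call.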

The anonymous AdviceAdapter returned by visitMethod is where the class bytecode is actually manipulated.
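
The stack shuffling in loadReturnArgs (dup2X1, pop2, swap) is the hardest part of the listing to follow by eye. The toy simulator below replays the same sequence of operations on a plain list, to show how the return value that sits under the Method reference and the null receiver ends up inside a fresh two-element array on top of them. It is only a sketch and does not use ASM: OperandStackSketch and simulateReturnArgs are hypothetical names, every value is treated as a single stack slot, the Method reference is represented by a string, and the box(ASM_TYPE_INT) step is covered by Java autoboxing.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

// Sketch only: a toy operand stack, not ASM. All values are treated as single-slot.
class OperandStackSketch {

    private final List<Object> stack = new ArrayList<>(); // top of stack = end of list

    void push(Object v) { stack.add(v); }

    Object pop() { return stack.remove(stack.size() - 1); }

    void pop2() { pop(); pop(); }

    void dup() { Object v = pop(); push(v); push(v); }

    void swap() { Object v1 = pop(); Object v2 = pop(); push(v1); push(v2); }

    // DUP2_X1: ..., v3, v2, v1 -> ..., v2, v1, v3, v2, v1
    void dup2X1() {
        Object v1 = pop();
        Object v2 = pop();
        Object v3 = pop();
        push(v2); push(v1); push(v3); push(v2); push(v1);
    }

    // Pops the length and pushes a new Object[] of that length.
    void newArray() { int n = (Integer) pop(); push(new Object[n]); }

    // Pops value, index and array reference, then stores value at array[index].
    void arrayStore() {
        Object value = pop();
        int index = (Integer) pop();
        Object[] array = (Object[]) pop();
        array[index] = value;
    }

    // Replays onMethodExit up to, but not including, the Method.invoke call.
    static List<Object> simulateReturnArgs(Object returnValue, int adviceId) {
        OperandStackSketch s = new OperandStackSketch();
        s.push(returnValue);         // loadReturn(opcode):  ret
        s.push("ON_RETURN_METHOD");  // loadAdviceMethod():  ret, M
        s.push(null);                // pushNull():          ret, M, null

        // loadReturnArgs()
        s.dup2X1(); s.pop2();        // M, null, ret
        s.push(2); s.newArray();     // M, null, ret, arr
        s.dup();                     // M, null, ret, arr, arr
        s.dup2X1(); s.pop2();        // M, null, arr, arr, ret
        s.push(0); s.swap();         // M, null, arr, arr, 0, ret
        s.arrayStore();              // M, null, arr          (arr[0] = ret)
        s.dup(); s.push(1); s.push(adviceId);
        s.arrayStore();              // M, null, arr          (arr[1] = adviceId)
        return new ArrayList<>(s.stack);
    }

    public static void main(String[] args) {
        List<Object> s = simulateReturnArgs("result", 7);
        // prints: ON_RETURN_METHOD, null, [result, 7]
        System.out.println(s.get(0) + ", " + s.get(1) + ", " + Arrays.toString((Object[]) s.get(2)));
    }
}
```

The final stack (Method, null receiver, argument array) is exactly the operand layout that the following invokeVirtual of Method.invoke(Object, Object[]) expects.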

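3. Summary

The key to this source code is how the proxying is done, and it takes a fair amount of time to understand when reading it. Working through this analysis should teach you quite a lot. Thanks for reading!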

posted @ 2021-05-19 20:36  罗西施