diff --git a/pom.xml b/pom.xml index 4a86f09..e14e232 100644 --- a/pom.xml +++ b/pom.xml @@ -25,6 +25,16 @@ 4.12 test + + com.h2database + h2 + 1.4.195 + + + org.apache.sshd + sshd-core + 1.7.0 + diff --git a/src/deploy/java/testperfix/Main.java b/src/deploy/java/testperfix/Main.java index 5881828..816c023 100644 --- a/src/deploy/java/testperfix/Main.java +++ b/src/deploy/java/testperfix/Main.java @@ -1,18 +1,17 @@ package testperfix; -import perfix.Registry; - +import java.sql.*; import java.util.concurrent.TimeUnit; public class Main { public static void main(String[] args) { - System.out.print("start me with -javaagent:target/agent-0.1-SNAPSHOT.jar"); - System.out.println(" and preferrably: -Dperfix.excludes=com,java,sun,org"); - Runtime.getRuntime().addShutdownHook(new Thread(() -> Registry.report())); + System.out.println("Start me with -javaagent:target/agent-0.1-SNAPSHOT.jar -Dperfix.includes=testperfix"); + System.out.println("Then start putty (or other telnet client) and telnet to localhost:2048"); run(); } public static void run() { + someJdbcMethod(); someOtherMethod(); try { TimeUnit.SECONDS.sleep(1); @@ -21,6 +20,20 @@ public class Main { } } + private static void someJdbcMethod() { + try { + Class.forName("org.h2.Driver"); + Connection connection = DriverManager.getConnection("jdbc:h2:mem:default", "sa", ""); + Statement statement = connection.createStatement(); + ResultSet resultSet = statement.executeQuery("select CURRENT_DATE()"); + while (resultSet.next()){ + System.out.println("today is "+resultSet.getObject(1)); + } + } catch (ClassNotFoundException | SQLException e) { + e.printStackTrace(); + } + } + private static void someOtherMethod() { try { TimeUnit.NANOSECONDS.sleep(1); diff --git a/src/main/java/perfix/Agent.java b/src/main/java/perfix/Agent.java old mode 100755 new mode 100644 index 2acbecc..667e38b --- a/src/main/java/perfix/Agent.java +++ b/src/main/java/perfix/Agent.java @@ -1,14 +1,15 @@ package perfix; import javassist.*; +import perfix.server.SSHServer; import java.io.IOException; import java.lang.instrument.Instrumentation; import java.util.ArrayList; -import java.util.Arrays; import java.util.List; import static java.util.Arrays.asList; +import static java.util.Arrays.stream; public class Agent { @@ -18,16 +19,20 @@ public class Agent { private static final String DEFAULT_PORT = "2048"; private static final String MESSAGE = "Perfix agent active"; - private static final String PERFIX_METHOD_CLASS = "perfix.Method"; + private static final String PERFIX_METHODINVOCATION_CLASS = "perfix.MethodInvocation"; + + private static ClassPool classpool; public static void premain(String agentArgs, Instrumentation inst) { System.out.println(MESSAGE); int port = Integer.parseInt(System.getProperty(PORT_PROPERTY, DEFAULT_PORT)); + classpool = ClassPool.getDefault(); + instrumentCode(inst); - new Server().startListeningOnSocket(port); + new SSHServer().startListeningOnSocket(port); } private static void instrumentCode(Instrumentation inst) { @@ -38,28 +43,59 @@ public class Agent { } private static byte[] createByteCode(List includes, String resource, byte[] uninstrumentedByteCode) { - if (!isInnerClass(resource) && shouldInclude(resource, includes)) { + if (!isInnerClass(resource)) { try { - byte[] instrumentedBytecode = instrumentMethod(resource); - if (instrumentedBytecode != null) { - return instrumentedBytecode; + CtClass ctClass = getCtClassForResource(resource); + if (isJdbcStatement(resource, ctClass)) { + return instrumentJdbcCalls(ctClass); + } + if (shouldInclude(resource, includes)) { + byte[] instrumentedBytecode = instrumentMethod(ctClass); + if (instrumentedBytecode != null) { + return instrumentedBytecode; + } } } catch (Exception ex) { //suppress } + } return uninstrumentedByteCode; } - private static byte[] instrumentMethod(String resource) throws - NotFoundException, IOException, CannotCompileException { - ClassPool cp = ClassPool.getDefault(); - CtClass methodClass = cp.get(PERFIX_METHOD_CLASS); + private static byte[] instrumentJdbcCalls(CtClass classToInstrument) throws IOException, CannotCompileException { + try { + stream(classToInstrument.getDeclaredMethods("executeQuery")).forEach(m -> { + try { + m.insertBefore("System.out.println($1);"); + } catch (CannotCompileException e) { + e.printStackTrace(); + } + }); + } catch (Exception e) { + e.printStackTrace(); + } + byte[] byteCode = classToInstrument.toBytecode(); + classToInstrument.detach(); + return byteCode; + } + + private static boolean isJdbcStatement(String resource, CtClass ctClass) throws NotFoundException { + if (!resource.startsWith("java/sql")) { + return stream(ctClass.getInterfaces()) + .anyMatch(i -> i.getName().equals("java.sql.Statement") && !i.getName().equals("java.sql.PreparedStatement")); + } + return false; + } + + private static byte[] instrumentMethod(CtClass classToInstrument) throws + NotFoundException, IOException, CannotCompileException { + + CtClass perfixMethodInvocationClass = getCtClass(PERFIX_METHODINVOCATION_CLASS); - CtClass classToInstrument = cp.get(resource.replaceAll("/", ".")); if (!classToInstrument.isInterface()) { - Arrays.stream(classToInstrument.getDeclaredMethods()).forEach(m -> { - instrumentMethod(methodClass, m); + stream(classToInstrument.getDeclaredMethods()).forEach(m -> { + instrumentMethod(perfixMethodInvocationClass, m); }); byte[] byteCode = classToInstrument.toBytecode(); classToInstrument.detach(); @@ -69,11 +105,19 @@ public class Agent { } } - private static void instrumentMethod(CtClass methodClass, CtMethod m) { + private static CtClass getCtClassForResource(String resource) throws NotFoundException { + return getCtClass(resource.replaceAll("/", ".")); + } + + private static CtClass getCtClass(String classname) throws NotFoundException { + return classpool.get(classname); + } + + private static void instrumentMethod(CtClass methodClass, CtMethod methodToinstrument) { try { - m.addLocalVariable("perfixmethod", methodClass); - m.insertBefore("perfixmethod = perfix.Method.start(\"" + m.getLongName() + "\");"); - m.insertAfter("perfixmethod.stop();"); + methodToinstrument.addLocalVariable("perfixmethod", methodClass); + methodToinstrument.insertBefore("perfixmethod = perfix.MethodInvocation.start(\"" + methodToinstrument.getLongName() + "\");"); + methodToinstrument.insertAfter("perfixmethod.stop();"); } catch (CannotCompileException e) { throw new RuntimeException(e); } diff --git a/src/main/java/perfix/Method.java b/src/main/java/perfix/Method.java deleted file mode 100755 index 68a058f..0000000 --- a/src/main/java/perfix/Method.java +++ /dev/null @@ -1,29 +0,0 @@ -package perfix; - -public class Method { - private final long t0; - private final String name; - private long t1; - - private Method(String name) { - t0 = System.nanoTime(); - this.name = name; - } - - public static Method start(String name) { - return new Method(name); - } - - public void stop() { - t1 = System.nanoTime(); - Registry.add(this); - } - - public String getName() { - return name; - } - - long getDuration() { - return t1 - t0; - } -} diff --git a/src/main/java/perfix/Registry.java b/src/main/java/perfix/Registry.java old mode 100755 new mode 100644 index e517556..7e44d73 --- a/src/main/java/perfix/Registry.java +++ b/src/main/java/perfix/Registry.java @@ -8,14 +8,14 @@ import java.util.concurrent.atomic.LongAdder; public class Registry { - private static final Map> methods = new ConcurrentHashMap<>(); + private static final Map> methods = new ConcurrentHashMap<>(); private static final double NANO_2_MILLI = 1000000D; private static final String HEADER1 = "Invoked methods, by duration desc:"; - private static final String HEADER2 = "Method name;#Invocations;Total duration;Average Duration"; + private static final String HEADER2 = "MethodInvocation name;#Invocations;Total duration;Average Duration"; private static final String FOOTER = "----------------------------------------"; - static void add(Method method) { - methods.computeIfAbsent(method.getName(), key -> new ArrayList<>()).add(method); + static void add(MethodInvocation methodInvocation) { + methods.computeIfAbsent(methodInvocation.getName(), key -> new ArrayList<>()).add(methodInvocation); } public static void report(PrintStream out) { @@ -25,6 +25,7 @@ public class Registry { .map(entry -> createReportLine(entry.getValue())) .forEach(out::println); out.println(FOOTER); + out.flush(); } private static String createReportLine(Report report) { diff --git a/src/main/java/perfix/server/SSHServer.java b/src/main/java/perfix/server/SSHServer.java new file mode 100644 index 0000000..23c4ef2 --- /dev/null +++ b/src/main/java/perfix/server/SSHServer.java @@ -0,0 +1,45 @@ +package perfix.server; + +import org.apache.sshd.common.PropertyResolverUtils; +import org.apache.sshd.server.ServerFactoryManager; +import org.apache.sshd.server.SshServer; +import org.apache.sshd.server.keyprovider.SimpleGeneratorHostKeyProvider; + +import java.io.IOException; +import java.nio.file.Paths; + +public class SSHServer implements Server, Runnable { + private static final String BANNER = "\n\nWelcome to Perfix!\n\n"; + + private int port; + + public void startListeningOnSocket(int port) { + this.port=port; + new Thread(this).start(); + } + + @Override + public void run() { + SshServer sshd = SshServer.setUpDefaultServer(); + + PropertyResolverUtils.updateProperty(sshd, ServerFactoryManager.WELCOME_BANNER, BANNER); + sshd.setPasswordAuthenticator((s, s1, serverSession) -> true); + sshd.setPort(port); + sshd.setShellFactory(new SshSessionFactory()); + sshd.setKeyPairProvider(new SimpleGeneratorHostKeyProvider(Paths.get("hostkey.ser"))); + + try { + sshd.start(); + } catch (IOException e) { + e.printStackTrace(); + } + + for (;;){ + try { + Thread.sleep(500); + } catch (Exception e) { + e.printStackTrace(); + } + } + } +} diff --git a/src/main/java/perfix/server/Server.java b/src/main/java/perfix/server/Server.java new file mode 100644 index 0000000..92b7c27 --- /dev/null +++ b/src/main/java/perfix/server/Server.java @@ -0,0 +1,5 @@ +package perfix.server; + +public interface Server { + void startListeningOnSocket(int port); +} diff --git a/src/main/java/perfix/server/SshSessionFactory.java b/src/main/java/perfix/server/SshSessionFactory.java new file mode 100644 index 0000000..dab7773 --- /dev/null +++ b/src/main/java/perfix/server/SshSessionFactory.java @@ -0,0 +1,20 @@ +package perfix.server; + +import org.apache.sshd.common.Factory; +import org.apache.sshd.server.Command; +import org.apache.sshd.server.CommandFactory; + +public class SshSessionFactory + implements CommandFactory, Factory { + + @Override + public Command createCommand(String command) { + return new SshSessionInstance(); + } + + @Override + public Command create() { + return createCommand("none"); + } + +} \ No newline at end of file diff --git a/src/main/java/perfix/server/SshSessionInstance.java b/src/main/java/perfix/server/SshSessionInstance.java new file mode 100644 index 0000000..f597858 --- /dev/null +++ b/src/main/java/perfix/server/SshSessionInstance.java @@ -0,0 +1,78 @@ +package perfix.server; + +import org.apache.sshd.server.Command; +import org.apache.sshd.server.Environment; +import org.apache.sshd.server.ExitCallback; +import perfix.Registry; + +import java.io.InputStream; +import java.io.OutputStream; +import java.io.PrintStream; + +public class SshSessionInstance implements Command, Runnable { + + private static final String ANSI_LOCAL_ECHO = "\u001B[12l"; + private static final String ANSI_NEWLINE_CRLF = "\u001B[20h"; + + private InputStream is; + private OutputStream os; + + private ExitCallback callback; + private Thread sshThread; + + @Override + public void start(Environment env) { + sshThread = new Thread(this, "EchoShell"); + sshThread.start(); + } + + @Override + public void run() { + try { + os.write("press [enter] for report or [q] to quit\n".getBytes()); + os.write((ANSI_LOCAL_ECHO + ANSI_NEWLINE_CRLF).getBytes()); + os.flush(); + + boolean exit = false; + while (!exit) { + char c = (char) is.read(); + if (c == 'q') { + exit = true; + } else if (c == '\n') { + Registry.report(new PrintStream(os)); + } + } + } catch (Exception e) { + e.printStackTrace(); + } finally { + callback.onExit(0); + } + } + + @Override + public void destroy() { + sshThread.interrupt(); + } + + @Override + public void setErrorStream(OutputStream errOS) { + } + + @Override + public void setExitCallback(ExitCallback ec) { + callback = ec; + } + + @Override + public void setInputStream(InputStream is) { + this.is = is; + } + + @Override + public void setOutputStream(OutputStream os) { + this.os = os; + } + +} + + diff --git a/src/main/java/perfix/Server.java b/src/main/java/perfix/server/TelnetServer.java similarity index 91% rename from src/main/java/perfix/Server.java rename to src/main/java/perfix/server/TelnetServer.java index ae88e28..12223c8 100644 --- a/src/main/java/perfix/Server.java +++ b/src/main/java/perfix/server/TelnetServer.java @@ -1,4 +1,6 @@ -package perfix; +package perfix.server; + +import perfix.Registry; import java.io.BufferedReader; import java.io.IOException; @@ -7,8 +9,8 @@ import java.io.PrintStream; import java.net.ServerSocket; import java.net.Socket; -public class Server { - void startListeningOnSocket(int port) { +public class TelnetServer implements Server { + public void startListeningOnSocket(int port) { try { ServerSocket serverSocket = new ServerSocket(port); new Thread(() -> { diff --git a/src/test/java/TestMethod.java b/src/test/java/TestMethod.java index 0f7de7c..cd63de8 100644 --- a/src/test/java/TestMethod.java +++ b/src/test/java/TestMethod.java @@ -1,16 +1,16 @@ import org.junit.Test; -import perfix.Method; +import perfix.MethodInvocation; import perfix.Registry; public class TestMethod { @Test public void testAddMethodToRegistry() { - Method method = Method.start("somename"); + MethodInvocation method = MethodInvocation.start("somename"); method.stop(); - Method method2 = Method.start("somename"); + MethodInvocation method2 = MethodInvocation.start("somename"); method2.stop(); - Registry.report(); + Registry.report(System.out); } }