/**
 * Mock answer that simulates a successful proxy link.
 *
 * Reads the second invocation argument as a {@link TagInjectionOutputStream}, notifies it
 * through {@code proxyLinked} with the dummy proxy and the linker, and returns the dummy
 * proxy as the result of the stubbed call.
 */
@Override
public Object answer(InvocationOnMock invocation) throws Throwable {
    Object[] callArguments = invocation.getArguments();
    // index 1 holds the stream instance that is about to be proxied
    TagInjectionOutputStream streamToLink = (TagInjectionOutputStream) callArguments[1];
    streamToLink.proxyLinked(dummyStreamProxy, linker);
    return dummyStreamProxy;
}
});
/** * Proxy for {@link javax.servlet.ServletResponse#getOutputStream()}. * * @return the instrumented stream * @throws IOException * if an exception getting the original stream occurs. */ @ProxyMethod(returnType = "javax.servlet.ServletOutputStream") public OutputStream getOutputStream() throws IOException { commitHeaderData(); if (wrappedStream == null) { OutputStream originalStream = wrappedResponse.getOutputStream(); // avoid rewrapping or unnecessary wrapping if (isNonHtmlContentTypeSet() || linker.isProxyInstance(originalStream, TagInjectionOutputStream.class)) { wrappedStream = originalStream; } else { TagInjectionOutputStream resultStr = new TagInjectionOutputStream(originalStream, tagToInject.printTags()); resultStr.setEncoding(wrappedResponse.getCharacterEncoding()); ClassLoader cl = wrappedResponse.getWrappedElement().getClass().getClassLoader(); wrappedStream = (OutputStream) linker.createProxy(TagInjectionOutputStream.class, resultStr, cl); if (wrappedStream == null) { // fallback to the normal stream if it can not be linked wrappedStream = originalStream; } } } return wrappedStream; }
@SuppressWarnings("unchecked") @Test public void testPlainTextNoInjection() throws IOException { ArgumentCaptor<TagInjectionOutputStream> streamCaptor = ArgumentCaptor.forClass(TagInjectionOutputStream.class); respWrapper.getOutputStream(); verify(linker, times(1)).createProxy(any(Class.class), streamCaptor.capture(), any(ClassLoader.class)); TagInjectionOutputStream stream = streamCaptor.getValue(); byte[] bytes = NON_HTML_TEST_CASE_A.getBytes(CHARACTER_ENCODING); int pos = 0; while (pos < bytes.length) { stream.write(bytes, pos, Math.min(3, bytes.length - pos)); pos += 3; // write 3 bytes at once } String result = new String(streamResult.toByteArray(), CHARACTER_ENCODING); assertThat(result, equalTo(NON_HTML_TEST_CASE_A)); }
@SuppressWarnings("unchecked") @Test public void testInvalidMarkupNoInjection() throws IOException { ArgumentCaptor<TagInjectionOutputStream> streamCaptor = ArgumentCaptor.forClass(TagInjectionOutputStream.class); respWrapper.getOutputStream(); verify(linker, times(1)).createProxy(any(Class.class), streamCaptor.capture(), any(ClassLoader.class)); TagInjectionOutputStream stream = streamCaptor.getValue(); byte[] bytes = NON_HTML_TEST_CASE_B.getBytes(CHARACTER_ENCODING); int pos = 0; while (pos < bytes.length) { stream.write(bytes, pos, Math.min(3, bytes.length - pos)); pos += 3; // write 3 bytes at once } String result = new String(streamResult.toByteArray(), CHARACTER_ENCODING); assertThat(result, equalTo(NON_HTML_TEST_CASE_B)); }
@SuppressWarnings("unchecked") @Test public void testBodyInjection() throws IOException { ArgumentCaptor<TagInjectionOutputStream> streamCaptor = ArgumentCaptor.forClass(TagInjectionOutputStream.class); respWrapper.getOutputStream(); verify(linker, times(1)).createProxy(any(Class.class), streamCaptor.capture(), any(ClassLoader.class)); TagInjectionOutputStream stream = streamCaptor.getValue(); byte[] bytes = HTML_TEST_CASE_B.getBytes(CHARACTER_ENCODING); int pos = 0; while (pos < bytes.length) { stream.write(bytes, pos, Math.min(3, bytes.length - pos)); pos += 3; // write 3 bytes at once } String result = new String(streamResult.toByteArray(), CHARACTER_ENCODING); assertThat(result, equalTo(HTML_TEST_CASE_B_REFERENCE)); }
@SuppressWarnings("unchecked") @Test public void testHeadInjection() throws IOException { ArgumentCaptor<TagInjectionOutputStream> streamCaptor = ArgumentCaptor.forClass(TagInjectionOutputStream.class); respWrapper.getOutputStream(); verify(linker, times(1)).createProxy(any(Class.class), streamCaptor.capture(), any(ClassLoader.class)); TagInjectionOutputStream stream = streamCaptor.getValue(); byte[] bytes = HTML_TEST_CASE_A.getBytes(CHARACTER_ENCODING); int pos = 0; while (pos < bytes.length) { stream.write(bytes, pos, Math.min(3, bytes.length - pos)); pos += 3; // write 3 bytes at once } String result = new String(streamResult.toByteArray(), CHARACTER_ENCODING); assertThat(result, equalTo(HTML_TEST_CASE_A_REFERENCE)); }