/**
 * Runs a default-configured interceptor and checks that, afterwards,
 * looking up the servlet request's session without creating one still
 * yields {@code null}.
 */
@Test
public void doNotCauseSessionCreation() throws Exception {
    HttpSessionHandshakeInterceptor interceptor = new HttpSessionHandshakeInterceptor();
    WebSocketHandler handler = Mockito.mock(WebSocketHandler.class);
    Map<String, Object> handshakeAttributes = new HashMap<>();

    interceptor.beforeHandshake(this.request, this.response, handler, handshakeAttributes);

    assertNull(this.servletRequest.getSession(false));
}
/**
 * Copies data from the HTTP session (if one is available) into the
 * handshake attributes map.
 *
 * <p>The session id is stored under {@code HTTP_SESSION_ID_ATTR_NAME} when
 * {@code isCopyHttpSessionId()} is true. A session attribute is copied when
 * {@code isCopyAllAttributes()} is true or its name is contained in
 * {@code getAttributeNames()}.
 *
 * @param request the current request
 * @param response the current response (not used here)
 * @param wsHandler the target WebSocket handler (not used here)
 * @param attributes the map that receives the copied values
 * @return always {@code true}
 */
@Override
public boolean beforeHandshake(ServerHttpRequest request, ServerHttpResponse response,
        WebSocketHandler wsHandler, Map<String, Object> attributes) throws Exception {

    HttpSession httpSession = getSession(request);
    if (httpSession == null) {
        // Nothing to copy; the handshake still proceeds.
        return true;
    }

    if (isCopyHttpSessionId()) {
        attributes.put(HTTP_SESSION_ID_ATTR_NAME, httpSession.getId());
    }

    Enumeration<String> attributeNames = httpSession.getAttributeNames();
    while (attributeNames.hasMoreElements()) {
        String attributeName = attributeNames.nextElement();
        if (!isCopyAllAttributes() && !getAttributeNames().contains(attributeName)) {
            continue;
        }
        attributes.put(attributeName, httpSession.getAttribute(attributeName));
    }
    return true;
}
/**
 * Looks up the HTTP session behind the given request.
 *
 * @param request the current request
 * @return the session obtained from the underlying servlet request, using
 * {@code isCreateSession()} as the create flag; {@code null} when the
 * request is not a {@code ServletServerHttpRequest}
 */
@Nullable
private HttpSession getSession(ServerHttpRequest request) {
    if (!(request instanceof ServletServerHttpRequest)) {
        return null;
    }
    return ((ServletServerHttpRequest) request).getServletRequest().getSession(isCreateSession());
}
/**
 * With session-id copying switched off, only the session attribute
 * {@code "foo"} is expected in the handshake attributes.
 */
@Test
public void doNotCopyHttpSessionId() throws Exception {
    this.servletRequest.setSession(new MockHttpSession(null, "123"));
    this.servletRequest.getSession().setAttribute("foo", "bar");

    HttpSessionHandshakeInterceptor interceptor = new HttpSessionHandshakeInterceptor();
    interceptor.setCopyHttpSessionId(false);

    WebSocketHandler handler = Mockito.mock(WebSocketHandler.class);
    Map<String, Object> handshakeAttributes = new HashMap<>();
    interceptor.beforeHandshake(this.request, this.response, handler, handshakeAttributes);

    assertEquals(1, handshakeAttributes.size());
    assertEquals("bar", handshakeAttributes.get("foo"));
}
/**
 * Pass-through override: forwards all arguments to the superclass
 * implementation and returns its result unchanged.
 *
 * @return whatever the superclass {@code beforeHandshake} returns
 * @throws Exception if the superclass implementation throws
 */
@Override
public boolean beforeHandshake(ServerHttpRequest request, ServerHttpResponse response, WebSocketHandler wsHandler, Map<String, Object> attributes) throws Exception {
    return super.beforeHandshake(request, response, wsHandler, attributes);
}
/**
 * With copy-all-attributes switched off (and no attribute names configured),
 * only the session id is expected in the handshake attributes.
 */
@Test
public void doNotCopyAttributes() throws Exception {
    this.servletRequest.setSession(new MockHttpSession(null, "123"));
    this.servletRequest.getSession().setAttribute("foo", "bar");

    HttpSessionHandshakeInterceptor interceptor = new HttpSessionHandshakeInterceptor();
    interceptor.setCopyAllAttributes(false);

    WebSocketHandler handler = Mockito.mock(WebSocketHandler.class);
    Map<String, Object> handshakeAttributes = new HashMap<>();
    interceptor.beforeHandshake(this.request, this.response, handler, handshakeAttributes);

    assertEquals(1, handshakeAttributes.size());
    assertEquals("123", handshakeAttributes.get(HttpSessionHandshakeInterceptor.HTTP_SESSION_ID_ATTR_NAME));
}
/**
 * Registers the {@code /other} and {@code /chat} STOMP endpoints, in that
 * order, each with SockJS enabled and its own
 * {@code HttpSessionHandshakeInterceptor} instance.
 *
 * @param registry the registry to add the endpoints to
 */
public void registerStompEndpoints(StompEndpointRegistry registry) {
    String[] endpointPaths = {"/other", "/chat"};
    for (String endpointPath : endpointPaths) {
        registry.addEndpoint(endpointPath)
                .withSockJS()
                .setInterceptors(new HttpSessionHandshakeInterceptor());
    }
}
/**
 * Pass-through override: forwards all arguments to the superclass
 * implementation and adds no behavior of its own.
 */
@Override
public void afterHandshake(ServerHttpRequest request, ServerHttpResponse response, WebSocketHandler wsHandler, Exception exception) {
    super.afterHandshake(request, response, wsHandler, exception);
}
/**
 * Prints a fixed marker message to standard output and then delegates to
 * the superclass implementation.
 *
 * @return whatever the superclass {@code beforeHandshake} returns
 * @throws Exception if the superclass implementation throws
 */
@Override
public boolean beforeHandshake(ServerHttpRequest request, ServerHttpResponse response, WebSocketHandler wsHandler, Map<String, Object> attributes) throws Exception {
    // Trace output so the invocation of this interceptor is visible on the console.
    System.out.println("连接成功之前的拦截器1");
    return super.beforeHandshake(request, response, wsHandler, attributes);
}
/**
 * Registers a custom handshake handler and one interceptor for
 * {@code /foo} and checks the resulting single mapping: the request handler
 * carries that handshake handler, the given interceptor first, and an
 * {@code OriginHandshakeInterceptor} second.
 */
@Test
public void handshakeHandlerAndInterceptor() {
    HttpSessionHandshakeInterceptor interceptor = new HttpSessionHandshakeInterceptor();
    DefaultHandshakeHandler handshakeHandler = new DefaultHandshakeHandler();

    WebMvcStompWebSocketEndpointRegistration registration =
            new WebMvcStompWebSocketEndpointRegistration(new String[] {"/foo"}, this.handler, this.scheduler);
    registration.setHandshakeHandler(handshakeHandler).addInterceptors(interceptor);

    MultiValueMap<HttpRequestHandler, String> mappings = registration.getMappings();
    assertEquals(1, mappings.size());

    Map.Entry<HttpRequestHandler, List<String>> mapping = mappings.entrySet().iterator().next();
    assertEquals(Arrays.asList("/foo"), mapping.getValue());

    WebSocketHttpRequestHandler requestHandler = (WebSocketHttpRequestHandler) mapping.getKey();
    assertNotNull(requestHandler.getWebSocketHandler());
    assertSame(handshakeHandler, requestHandler.getHandshakeHandler());

    List<?> registeredInterceptors = requestHandler.getHandshakeInterceptors();
    assertEquals(2, registeredInterceptors.size());
    assertEquals(interceptor, registeredInterceptors.get(0));
    assertEquals(OriginHandshakeInterceptor.class, registeredInterceptors.get(1).getClass());
}
/**
 * Pass-through override: forwards all arguments to the superclass
 * implementation and adds no behavior of its own.
 */
@Override
public void afterHandshake(ServerHttpRequest request, ServerHttpResponse response, WebSocketHandler wsHandler, Exception ex) {
    super.afterHandshake(request, response, wsHandler, ex);
}
/**
 * Prints the remote peer's address to standard output and then delegates to
 * the superclass implementation.
 *
 * @return whatever the superclass {@code beforeHandshake} returns
 * @throws Exception if the superclass implementation throws
 */
@Override
public boolean beforeHandshake(ServerHttpRequest request, ServerHttpResponse response, WebSocketHandler wsHandler, Map<String, Object> attributes) throws Exception {
    InetSocketAddress remoteAddress = request.getRemoteAddress();
    InetAddress address = remoteAddress.getAddress();
    System.out.println(address);
    /*
     * The superclass method must be called explicitly at the end: the
     * superclass beforeHandshake copies the values held in the HTTP session
     * of the ServerHttpRequest into the WebSocketSession.
     * Without that call, the session values are not available in the
     * handler that finally processes the connection.
     * To try it out, comment out the line below: the "username" from the
     * session is then always empty in the custom handler.
     * NOTE(review): the superclass is not visible here - presumably
     * HttpSessionHandshakeInterceptor; confirm.
     */
    return super.beforeHandshake(request, response, wsHandler, attributes);
}
}
/**
 * An interceptor constructed with the single attribute name {@code "foo"}
 * is expected to copy that attribute plus the session id, and to leave the
 * session attribute {@code "bar"} out.
 */
@Test
public void constructorWithAttributeNames() throws Exception {
    this.servletRequest.setSession(new MockHttpSession(null, "123"));
    this.servletRequest.getSession().setAttribute("foo", "bar");
    this.servletRequest.getSession().setAttribute("bar", "baz");

    HttpSessionHandshakeInterceptor interceptor =
            new HttpSessionHandshakeInterceptor(Collections.singleton("foo"));

    WebSocketHandler handler = Mockito.mock(WebSocketHandler.class);
    Map<String, Object> handshakeAttributes = new HashMap<>();
    interceptor.beforeHandshake(this.request, this.response, handler, handshakeAttributes);

    assertEquals(2, handshakeAttributes.size());
    assertEquals("bar", handshakeAttributes.get("foo"));
    assertEquals("123", handshakeAttributes.get(HttpSessionHandshakeInterceptor.HTTP_SESSION_ID_ATTR_NAME));
}
/**
 * Transfers the HTTP session id and selected session attributes into the
 * handshake attributes map, provided a session can be obtained.
 *
 * <p>The id goes under {@code HTTP_SESSION_ID_ATTR_NAME} when
 * {@code isCopyHttpSessionId()} is set; an attribute is transferred when
 * {@code isCopyAllAttributes()} is set or {@code getAttributeNames()}
 * contains its name.
 *
 * @param request the current request
 * @param response the current response (not used here)
 * @param wsHandler the target WebSocket handler (not used here)
 * @param attributes the map that receives the copied values
 * @return always {@code true}
 */
@Override
public boolean beforeHandshake(ServerHttpRequest request, ServerHttpResponse response,
        WebSocketHandler wsHandler, Map<String, Object> attributes) throws Exception {

    HttpSession current = getSession(request);
    if (current == null) {
        return true;
    }

    if (isCopyHttpSessionId()) {
        attributes.put(HTTP_SESSION_ID_ATTR_NAME, current.getId());
    }

    for (Enumeration<String> keys = current.getAttributeNames(); keys.hasMoreElements(); ) {
        String key = keys.nextElement();
        if (isCopyAllAttributes() || getAttributeNames().contains(key)) {
            attributes.put(key, current.getAttribute(key));
        }
    }
    return true;
}
/**
 * Same scenario as registering a handshake handler and an interceptor for
 * {@code /foo}, but additionally with an allowed origin configured; the
 * resulting mapping is checked for the handshake handler, the given
 * interceptor first, and an {@code OriginHandshakeInterceptor} second.
 */
@Test
public void handshakeHandlerAndInterceptorWithAllowedOrigins() {
    HttpSessionHandshakeInterceptor interceptor = new HttpSessionHandshakeInterceptor();
    DefaultHandshakeHandler handshakeHandler = new DefaultHandshakeHandler();

    WebMvcStompWebSocketEndpointRegistration registration =
            new WebMvcStompWebSocketEndpointRegistration(new String[] {"/foo"}, this.handler, this.scheduler);
    registration.setHandshakeHandler(handshakeHandler)
            .addInterceptors(interceptor)
            .setAllowedOrigins("http://mydomain.com");

    MultiValueMap<HttpRequestHandler, String> mappings = registration.getMappings();
    assertEquals(1, mappings.size());

    Map.Entry<HttpRequestHandler, List<String>> mapping = mappings.entrySet().iterator().next();
    assertEquals(Arrays.asList("/foo"), mapping.getValue());

    WebSocketHttpRequestHandler requestHandler = (WebSocketHttpRequestHandler) mapping.getKey();
    assertNotNull(requestHandler.getWebSocketHandler());
    assertSame(handshakeHandler, requestHandler.getHandshakeHandler());

    List<?> registeredInterceptors = requestHandler.getHandshakeInterceptors();
    assertEquals(2, registeredInterceptors.size());
    assertEquals(interceptor, registeredInterceptors.get(0));
    assertEquals(OriginHandshakeInterceptor.class, registeredInterceptors.get(1).getClass());
}
/**
 * Logs that the handshake phase has finished and then delegates to the
 * superclass implementation.
 */
@Override
public void afterHandshake(ServerHttpRequest request, ServerHttpResponse response, WebSocketHandler wsHandler, Exception ex) {
    // After a successful handshake; typically the place to register user information.
    logger.info("WebSocket Handshake After");
    super.afterHandshake(request, response, wsHandler, ex);
}
/**
 * Resolves the HTTP session for the given request.
 *
 * @param request the current request
 * @return for a {@code ServletServerHttpRequest}, the result of
 * {@code getSession(isCreateSession())} on the underlying servlet request;
 * otherwise {@code null}
 */
@Nullable
private HttpSession getSession(ServerHttpRequest request) {
    return (request instanceof ServletServerHttpRequest
            ? ((ServletServerHttpRequest) request).getServletRequest().getSession(isCreateSession())
            : null);
}
@Override public boolean beforeHandshake(ServerHttpRequest request, ServerHttpResponse response, WebSocketHandler wsHandler, Map<String, Object> attributes) { logger.info("WebSocket Handshake Before: " + request.getURI()); try { if (request instanceof ServletServerHttpRequest) { ServletServerHttpRequest servletRequest = (ServletServerHttpRequest) request; // 获取凭证ID,检查用户是否登录 String principalId = getPrincipalId(servletRequest.getServletRequest()); // 从Session中获取CookieId(无法从Header中获取) String token = LoginUserUtils.getTokenFromSession(servletRequest.getServletRequest()); logger.debug("WebSocket PrincipalId [{}]", principalId); logger.debug("WebSocket Token [{}]", token); attributes.put("principalId", principalId); attributes.put("token", token); return super.beforeHandshake(request, response, wsHandler, attributes); } } catch (LoginInvalidException e) { logger.warn("WebSocket Handshake Error(登录失效,握手失败)"); return false; } catch (Exception e) { logger.warn("WebSocket Handshake Error({})", e.getMessage()); logger.error(e.getMessage(), e); return false; } return false; }
/**
 * A default-constructed interceptor is expected to copy both session
 * attributes and the session id into the handshake attributes.
 */
@Test
public void defaultConstructor() throws Exception {
    this.servletRequest.setSession(new MockHttpSession(null, "123"));
    this.servletRequest.getSession().setAttribute("foo", "bar");
    this.servletRequest.getSession().setAttribute("bar", "baz");

    HttpSessionHandshakeInterceptor interceptor = new HttpSessionHandshakeInterceptor();
    WebSocketHandler handler = Mockito.mock(WebSocketHandler.class);
    Map<String, Object> handshakeAttributes = new HashMap<>();
    interceptor.beforeHandshake(this.request, this.response, handler, handshakeAttributes);

    assertEquals(3, handshakeAttributes.size());
    assertEquals("bar", handshakeAttributes.get("foo"));
    assertEquals("baz", handshakeAttributes.get("bar"));
    assertEquals("123", handshakeAttributes.get(HttpSessionHandshakeInterceptor.HTTP_SESSION_ID_ATTR_NAME));
}
/**
 * Populates the handshake attributes from the HTTP session when one is
 * available.
 *
 * <p>Stores the session id under {@code HTTP_SESSION_ID_ATTR_NAME} if
 * {@code isCopyHttpSessionId()} allows it, then copies each session
 * attribute that is selected either by {@code isCopyAllAttributes()} or by
 * membership in {@code getAttributeNames()}.
 *
 * @param request the current request
 * @param response the current response (not used here)
 * @param wsHandler the target WebSocket handler (not used here)
 * @param attributes the map that receives the copied values
 * @return always {@code true}
 */
@Override
public boolean beforeHandshake(ServerHttpRequest request, ServerHttpResponse response,
        WebSocketHandler wsHandler, Map<String, Object> attributes) throws Exception {

    HttpSession source = getSession(request);
    if (source != null) {
        if (isCopyHttpSessionId()) {
            String sessionId = source.getId();
            attributes.put(HTTP_SESSION_ID_ATTR_NAME, sessionId);
        }
        Enumeration<String> pending = source.getAttributeNames();
        while (pending.hasMoreElements()) {
            String attributeName = pending.nextElement();
            boolean selected = isCopyAllAttributes() || getAttributeNames().contains(attributeName);
            if (selected) {
                Object value = source.getAttribute(attributeName);
                attributes.put(attributeName, value);
            }
        }
    }
    return true;
}