Spring Security + JWT

By | 2024년 3월 16일
Table of Contents

Spring Security + JWT

Spring Security 와 JWT 를 이용해 API 접근제어를 구현해 봅니다.

파일의 수는 많지만 실제로는,
JwtAuthenticationFilter 를 호출하는 부분과 JwtAuthenticationFilter 의 내부 작동만 확인하면 Security 가 작동하는 방식을 대부분 확인할 수 있습니다.

그밖에 토큰을 생성하고 갱신하는 부분은 RestController 에서 직접 호출하므로 이해하기 쉽습니다.

주의!!!!
secret 만 있으면 누구나 토큰을 만들어낼 수 있습니다.
소스코드를 복사해 붙여 넣는 것은 좋지만, secret 은 반드시 변경하세요.

또한, Brute Force 대응코드도 없습니다.
(github 에는 대응 코드 넣었습니다.)

github

https://github.com/skyer9/spring-boot-rest-api-example

의존성 추가

Spring Security, JWT, H2 를 설정해 줍니다.

// Gradle dependencies used by the Spring Security + JWT example.
// NOTE(review): no H2 driver dependency appears in this list although the datasource
// configuration uses org.h2.Driver — presumably it is declared elsewhere; confirm.
dependencies {
    // Spring Security starter
    implementation 'org.springframework.boot:spring-boot-starter-security'
    // Bean Validation starter (the DTOs use @NotNull / @Size and controllers use @Valid)
    implementation 'org.springframework.boot:spring-boot-starter-validation'
    // Spring Security test support (test scope only)
    testImplementation 'org.springframework.security:spring-security-test'
    // JJWT: the API is needed at compile time, implementation and Jackson binding only at runtime
    implementation group: 'io.jsonwebtoken', name: 'jjwt-api', version: '0.11.5'
    runtimeOnly group: 'io.jsonwebtoken', name: 'jjwt-impl', version: '0.11.5'
    runtimeOnly group: 'io.jsonwebtoken', name: 'jjwt-jackson', version: '0.11.5'
}
# Spring Boot settings: in-memory H2 database plus JPA/Hibernate behaviour.
spring:
  h2:
    console:
      # Enables the H2 web console (SecurityConfig permits access to it)
      enabled: true
  datasource:
    # In-memory H2 database; NON_KEYWORDS=USER allows "user" to be used as a table name
    url: jdbc:h2:mem:testdb;NON_KEYWORDS=USER
    driver-class-name: org.h2.Driver
    username: sa
    password:
  jpa:
    open-in-view: false
    hibernate:
      # Schema is created on startup and dropped on shutdown
      ddl-auto: create-drop
    properties:
      hibernate:
        format_sql: true
        # Was misspelled as "ture", which does not evaluate to true
        show_sql: true
    # Run the SQL seed script after Hibernate has created the schema
    defer-datasource-initialization: true

# JWT settings read by JwtTokenProvider via @Value.
jwt:
  # Base64-encoded HMAC signing key (decoded with Decoders.BASE64 in JwtTokenProvider).
  # Anyone who knows this value can forge tokens — replace it before using this sample.
  secret: a2FyaW10b2thcmltdG9rYXJpbXRva2FyaW10b2thcmltdG9rYXJpbXRva2FyaW10b2thcmltdG9rYXJpbXRva2FyaW10b2thcmltdG9rYXJpbXRva2FyaW10b2thcmltdG9rYXJpbXRva2FyaW10b2thcmltdG9rYXJpbXRva2FyaW10b2thcmltdG9rYXJpbQ==
  # Access token lifetime: 1,800,000 ms = 30 minutes
  access-token-validity-in-milliseconds: 1800000
  # Refresh token lifetime: 86,400,000 ms = 24 hours
  refresh-token-validity-in-milliseconds: 86400000
-- Seed data for the example schema.
-- The admin user insert below is disabled; MyUserService.createAdminUser creates the
-- first account through the API instead.
-- insert into user (USERNAME, PASSWORD, NICKNAME, ACTIVATED)
-- values ('admin', '$2a$08$lDnHPz7eUkSi6ao14Twuau08mzhWrL4kyZGGU5xfiGALO/Vxd5DOi', 'admin', 1);

-- The two authorities referenced by createAdminUser and signup.
insert into AUTHORITY (AUTHORITY_NAME) values ('USER');
insert into AUTHORITY (AUTHORITY_NAME) values ('ADMIN');

-- Disabled together with the admin user insert above.
-- insert into USER_AUTHORITY (USERNAME, AUTHORITY_NAME) values ('admin', 'USER');
-- insert into USER_AUTHORITY (USERNAME, AUTHORITY_NAME) values ('admin', 'ADMIN');

Util

ResponseDto.java

/**
 * Generic response envelope carrying a status, a message and an optional payload.
 *
 * @param <T> type of the payload stored in {@code data}
 */
@Getter
@Setter
@Builder
public class ResponseDto<T> {
    private HttpStatus status;
    private String message;
    private T data;

    /**
     * Creates an envelope without a payload ({@code data} stays {@code null}).
     *
     * @param status  status to report in the body
     * @param message human-readable message
     * @return the populated envelope
     */
    public static <T> ResponseDto<T> res(final HttpStatus status, final String message) {
        final T noData = null;
        return res(status, message, noData);
    }

    /**
     * Creates an envelope with a payload.
     *
     * @param status  status to report in the body
     * @param message human-readable message
     * @param data    payload, may be {@code null}
     * @return the populated envelope
     */
    public static <T> ResponseDto<T> res(final HttpStatus status, final String message, final T data) {
        return ResponseDto.<T>builder()
                .data(data)
                .message(message)
                .status(status)
                .build();
    }
}

SecurityUtil.java

/**
 * Static helpers shared by the security filters and handlers.
 */
public class SecurityUtil {

    /** Shared mapper; ObjectMapper is thread-safe once configured, so it need not be rebuilt per call. */
    private static final ObjectMapper MAPPER = new ObjectMapper();

    private SecurityUtil() {}

    /**
     * Returns the username of the currently authenticated principal.
     *
     * @return the username when the principal is a {@link UserDetails} or a {@code String};
     *         otherwise an empty Optional (also when there is no authentication)
     */
    public static Optional<String> getCurrentUsername() {
        final Authentication authentication = SecurityContextHolder.getContext().getAuthentication();
        if (authentication == null) {
            return Optional.empty();
        }

        final Object principal = authentication.getPrincipal();
        if (principal instanceof UserDetails springSecurityUser) {
            return Optional.ofNullable(springSecurityUser.getUsername());
        }
        if (principal instanceof String name) {
            return Optional.of(name);
        }
        return Optional.empty();
    }

    /**
     * Writes the exception message to the response as a JSON {@link ResponseDto}.
     * The HTTP status stays 200; the failure is reported as BAD_REQUEST inside the body.
     *
     * @param response response to write to
     * @param ex       exception whose message is reported
     * @throws IOException if the body cannot be serialized or written
     */
    public static void setResponse(HttpServletResponse response, Exception ex) throws IOException {
        response.setStatus(HttpServletResponse.SC_OK);
        response.setContentType(MediaType.APPLICATION_JSON_VALUE);
        // Must be set before getWriter(); otherwise the writer defaults to ISO-8859-1
        // and non-ASCII messages are garbled.
        response.setCharacterEncoding("UTF-8");
        response.getWriter().print(
                MAPPER.writeValueAsString(ResponseDto.res(HttpStatus.BAD_REQUEST, ex.getMessage())));
    }

    /**
     * Resolves the client IP address, preferring the first entry of {@code X-FORWARDED-FOR}.
     *
     * @param request current request
     * @return the first forwarded address (trimmed), or the remote address when the header
     *         is missing, blank, or has an empty first entry
     */
    public static String getClientIpAddress(HttpServletRequest request) {
        final String forwarded = request.getHeader("X-FORWARDED-FOR");
        if (forwarded == null || forwarded.isBlank()) {
            return request.getRemoteAddr();
        }

        // indexOf/substring instead of split(",")[0]: split returns an empty array for
        // a value such as ",", which made the original throw.
        final int comma = forwarded.indexOf(',');
        final String first = (comma >= 0 ? forwarded.substring(0, comma) : forwarded).trim();
        return first.isEmpty() ? request.getRemoteAddr() : first;
    }
}

Config 레이어

SecurityConfig.java 와 JwtTokenProvider.java 를 중점적으로 확인해 줍니다.

generateAccessToken(), generateRefreshToken() 가 실제로 JWT 토큰을 생성하는 로직입니다.

JwtAccessDeniedHandler.java

/**
 * Access-denied handler that reports the failure as a JSON body through
 * {@link SecurityUtil#setResponse}.
 */
@Component
public class JwtAccessDeniedHandler implements AccessDeniedHandler {

    /**
     * Writes the access-denied exception's message to the response.
     *
     * @throws IOException if the response body cannot be written
     */
    @Override
    public void handle(
            HttpServletRequest request,
            HttpServletResponse response,
            AccessDeniedException denied
    ) throws IOException {
        SecurityUtil.setResponse(response, denied);
    }
}

JwtAuthenticationEntryPoint.java

/**
 * Authentication entry point that reports the failure as a JSON body through
 * {@link SecurityUtil#setResponse}.
 */
@Component
public class JwtAuthenticationEntryPoint implements AuthenticationEntryPoint {

    /**
     * Writes the authentication exception's message to the response.
     *
     * @throws IOException if the response body cannot be written
     */
    @Override
    public void commence(
            HttpServletRequest request,
            HttpServletResponse response,
            AuthenticationException failure
    ) throws IOException {
        SecurityUtil.setResponse(response, failure);
    }
}

JwtAuthenticationFilter.java

/**
 * Filter that reads a bearer token from the Authorization header and, when the token
 * validates, stores the resulting Authentication in the SecurityContext.
 */
@RequiredArgsConstructor
public class JwtAuthenticationFilter extends GenericFilterBean {
    /** Scheme prefix including the separating space, so the prefix length matches what is stripped. */
    private static final String BEARER_PREFIX = "Bearer ";

    private final JwtTokenProvider jwtTokenProvider;

    /**
     * Authenticates the request when a valid token is present, then continues the chain.
     * Requests without a token pass through unauthenticated; validateToken throws for
     * an invalid token.
     */
    @Override
    public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain) throws IOException, ServletException {
        String token = resolveToken((HttpServletRequest) request);
        if (token != null && jwtTokenProvider.validateToken(token)) {
            Authentication authentication = jwtTokenProvider.getAuthentication(token);
            SecurityContextHolder.getContext().setAuthentication(authentication);
        }
        chain.doFilter(request, response);
    }

    /**
     * Extracts the token from the Authorization header.
     *
     * @return the text after {@code "Bearer "}, or {@code null} when the header is absent
     *         or does not use the bearer scheme
     */
    private String resolveToken(HttpServletRequest request) {
        String bearerToken = request.getHeader(JwtTokenProvider.AUTHORIZATION_HEADER);
        // The prefix check includes the space: a header of exactly "Bearer" previously
        // caused substring(7) to throw StringIndexOutOfBoundsException.
        if (StringUtils.hasText(bearerToken) && bearerToken.startsWith(BEARER_PREFIX)) {
            return bearerToken.substring(BEARER_PREFIX.length());
        }
        return null;
    }
}

JwtExceptionFilter.java

/**
 * Filter that converts a {@link JwtException} thrown further down the chain into a
 * JSON error body written by {@link SecurityUtil#setResponse}.
 */
@Slf4j
@RequiredArgsConstructor
@Component
public class JwtExceptionFilter extends OncePerRequestFilter {

    /**
     * Runs the rest of the chain and reports any JwtException as a JSON response.
     * Other exception types are not caught here.
     */
    @Override
    protected void doFilterInternal(
            HttpServletRequest request,
            HttpServletResponse response,
            FilterChain filterChain
    ) throws ServletException, IOException {
        try {
            filterChain.doFilter(request, response);
        } catch (JwtException jwtFailure) {
            SecurityUtil.setResponse(response, jwtFailure);
        }
    }
}

소스코드 라인수는 많은데 의외로 내용은 간단합니다.

JwtTokenProvider.java

@Slf4j
@Component
@RequiredArgsConstructor
public class JwtTokenProvider implements InitializingBean {
    private Key key;

    public static final String AUTHORITIES_KEY = "auth";
    public static final String AUTHORIZATION_HEADER = "Authorization";

    @Value("${jwt.secret}")
    private String secret;
    @Value("${jwt.access-token-validity-in-milliseconds}")
    private long accessTokenValidityInMilliseconds;
    @Value("${jwt.refresh-token-validity-in-milliseconds}")
    private long refreshTokenValidityInMilliseconds;

    private final RefreshTokenRepository refreshTokenRepository;

    @Transactional
    public TokenDto generateToken(MyUser myUser) {
        long now = (new Date()).getTime();
        String authorities = getAuthorities(myUser);
        String accessToken = generateAccessToken(myUser.getUsername(), authorities, now);
        String refreshToken = generateRefreshToken(myUser.getUsername(), now);

        return TokenDto.builder()
                .grantType("Bearer")
                .accessToken(accessToken)
                .refreshToken(refreshToken)
                .build();
    }

    public TokenDto reissueToken(MyUser myUser, RefreshToken refreshToken) {
        validateToken(refreshToken.getToken());

        long now = (new Date()).getTime();
        String authorities = getAuthorities(myUser);
        String accessToken = generateAccessToken(refreshToken.getUsername(), authorities, now);

        return TokenDto.builder()
                .grantType("Bearer")
                .accessToken(accessToken)
                .refreshToken(refreshToken.getToken())
                .build();
    }

    public Authentication getAuthentication(String accessToken) {
        Claims claims = parseClaims(accessToken);

        if (claims.get(AUTHORITIES_KEY) == null) {
            throw new RuntimeException("권한 정보가 없는 토큰입니다.");
        }

        Collection<? extends GrantedAuthority> authorities =
                Arrays.stream(claims.get(AUTHORITIES_KEY).toString().split(","))
                        .map(SimpleGrantedAuthority::new)
                        .collect(Collectors.toList());

        UserDetails principal = new User(claims.getSubject(), "", authorities);
        return new UsernamePasswordAuthenticationToken(principal, "", authorities);
    }

    public boolean validateToken(String token) {
        try {
            Jwts.parserBuilder().setSigningKey(key).build().parseClaimsJws(token);
        } catch (io.jsonwebtoken.security.SecurityException | MalformedJwtException e) {
            throw new JwtException("Invalid JWT Token");
        } catch (ExpiredJwtException e) {
            throw new JwtException("Expired JWT Token");
        } catch (UnsupportedJwtException e) {
            throw new JwtException("Unsupported JWT Token");
        } catch (IllegalArgumentException e) {
            throw new JwtException("JWT claims string is empty");
        } catch (Exception e) {
            throw new JwtException("Unknown exception occurred", e);
        }
        return true;
    }

    private Claims parseClaims(String accessToken) {
        try {
            return Jwts.parserBuilder().setSigningKey(key).build().parseClaimsJws(accessToken).getBody();
        } catch (ExpiredJwtException e) {
            return e.getClaims();
        }
    }

    private String generateAccessToken(String username, String authorities, long now) {
        Date accessTokenExpiresIn = new Date(now + accessTokenValidityInMilliseconds);
        return Jwts.builder()
                .setSubject(username)
                .claim(AUTHORITIES_KEY, authorities)
                .setExpiration(accessTokenExpiresIn)
                .signWith(key, SignatureAlgorithm.HS512)
                .compact();
    }

    private String generateRefreshToken(String username, long now) {
        String refreshToken = Jwts.builder()
                .setSubject(username)
                .setExpiration(new Date(now + refreshTokenValidityInMilliseconds))
                .signWith(key, SignatureAlgorithm.HS512)
                .compact();

        Optional<RefreshToken> saved = refreshTokenRepository.findByToken(refreshToken);
        if (saved.isEmpty()) {
            refreshTokenRepository.save(RefreshToken
                    .builder()
                    .token(refreshToken)
                    .expiryDate(now + refreshTokenValidityInMilliseconds)
                    .username(username)
                    .build());
        }

        return refreshToken;
    }

    private String getAuthorities(MyUser myUser) {
        return myUser
                .getAuthorities()
                .stream()
                .map(Authority::addPrefix)
                .collect(Collectors.joining(","));
    }

    @Override
    public void afterPropertiesSet() {
        byte[] keyBytes = Decoders.BASE64.decode(secret);
        this.key = Keys.hmacShaKeyFor(keyBytes);
    }
}

SecurityConfig.java

/**
 * Spring Security configuration: stateless JWT authentication, URL authorization rules
 * and the JSON-writing entry point / access-denied handler.
 */
@Configuration
@EnableWebSecurity
@RequiredArgsConstructor
public class SecurityConfig {
    private final JwtTokenProvider jwtTokenProvider;
    private final JwtExceptionFilter jwtExceptionFilter;
    private final JwtAuthenticationEntryPoint jwtAuthenticationEntryPoint;
    private final JwtAccessDeniedHandler jwtAccessDeniedHandler;

    /**
     * Builds the security filter chain.
     * Matchers are evaluated in declaration order, so the specific rules must stay
     * ahead of {@code anyRequest()}.
     * NOTE(review): both handlers delegate to SecurityUtil.setResponse, which sets HTTP 200
     * and reports the error inside the body — confirm that this is intended.
     */
    @Bean
    public SecurityFilterChain filterChain(HttpSecurity http) throws Exception {
        http
                // CSRF protection is turned off for this API
                .csrf(AbstractHttpConfigurer::disable)
                // Frame options are turned off (the H2 console is permitted below)
                .headers((headerConfig) ->
                        headerConfig.frameOptions(HeadersConfigurer.FrameOptionsConfig::disable)
                )
                .authorizeHttpRequests((authorizeRequests) ->
                        authorizeRequests
                                // Publicly reachable endpoints
                                .requestMatchers(PathRequest.toH2Console()).permitAll()
                                .requestMatchers("/api/createAdminUser").permitAll()
                                .requestMatchers("/api/signin").permitAll()
                                .requestMatchers("/api/signup").permitAll()
                                .requestMatchers("/api/reissue").permitAll()
                                .requestMatchers("/favicon.ico").permitAll()
                                .requestMatchers("/swagger-ui/**").permitAll()
                                .requestMatchers("/v3/api-docs/**").permitAll()
                                // Own profile: USER or ADMIN; other users' profiles: ADMIN only
                                .requestMatchers("/api/user").hasAnyRole("ADMIN", "USER")
                                .requestMatchers("/api/user/**").hasRole("ADMIN")
                                .anyRequest().authenticated()
                )
                .exceptionHandling((exceptionConfig) ->
                        exceptionConfig
                                .authenticationEntryPoint(jwtAuthenticationEntryPoint) // handle 401 Error
                                .accessDeniedHandler(jwtAccessDeniedHandler)           // handle 403 Error
                )
                // No HTTP session: every request is authenticated from its token
                .sessionManagement(configurer -> configurer.sessionCreationPolicy(SessionCreationPolicy.STATELESS))
                // The JWT filter is registered first so that the exception filter can be
                // positioned relative to it and wrap it.
                .addFilterBefore(new JwtAuthenticationFilter(jwtTokenProvider), UsernamePasswordAuthenticationFilter.class)
                .addFilterBefore(jwtExceptionFilter, JwtAuthenticationFilter.class);
        return http.build();
    }

    /** Password encoder used by MyUserService for hashing and matching passwords. */
    @Bean
    public PasswordEncoder passwordEncoder() {
        return new BCryptPasswordEncoder();
    }
}

Domain 레이어

RefreshToken 을 데이터베이스에 저장하고, 갱신 요청이 들어올 때 RefreshToken 의 유효성을 검증합니다.
따라서, AccessToken 이 탈취당해도 최대한 접근을 제한할 수단이 생깁니다.
(RefreshToken 삭제 등)

/**
 * JPA entity for a role name, stored in the {@code authority} table and keyed by its name.
 */
@Entity
@Getter
@Setter
@Builder
@NoArgsConstructor
@AllArgsConstructor
@Table(name = "authority")
public class Authority {
    @Id
    @Column(name = "authority_name", length = 50)
    private String authorityName;

    /**
     * Returns the authority name in "ROLE_"-prefixed form.
     *
     * @return the name unchanged when it already starts with "ROLE_", otherwise "ROLE_" + name
     */
    public String addPrefix() {
        if (this.authorityName.startsWith("ROLE_")) {
            return authorityName;
        }
        return "ROLE_" + authorityName;
    }
}
/**
 * JPA entity for an application user, mapped to the {@code user} table and keyed by username.
 * Authorities are linked through the {@code user_authority} join table.
 */
@Entity
@Getter
@Builder
@NoArgsConstructor
@AllArgsConstructor
@Table(name= "user")
public class MyUser {

    // Primary key
    @Id
    @Column(name = "username", length = 128, unique = true)
    private String username;

    // MyUserService stores the encoder's output here; excluded from JSON output
    @JsonIgnore
    @Column(name = "password", length = 256)
    private String password;

    @Column(name = "nickname", length = 128)
    private String nickname;

    // Excluded from JSON output.
    // NOTE(review): no code in this file reads this flag when logging in — confirm intended use.
    @JsonIgnore
    @Column(name = "activated")
    private boolean activated;

    // Authorities granted to the user, fetched with a join
    @ManyToMany
    @Fetch(FetchMode.JOIN)
    @JoinTable(
            name = "user_authority",
            joinColumns = {@JoinColumn(name = "username", referencedColumnName = "username")},
            inverseJoinColumns = {@JoinColumn(name = "authority_name", referencedColumnName = "authority_name")})
    private Set<Authority> authorities;
}
/**
 * Spring Data repository for {@link MyUser}, keyed by username.
 */
public interface MyUserRepository extends JpaRepository<MyUser, String> {
    /** Looks up a user by username. */
    Optional<MyUser> findByUsername(String username);

    /**
     * Looks up a user by username; used by MyUserService wherever authorities are needed.
     * NOTE(review): no entity graph or query annotation is visible here, so authority loading
     * presumably relies on the fetch mode declared on MyUser.authorities — confirm.
     */
    Optional<MyUser> findOneWithAuthoritiesByUsername(String username);
}
/**
 * JPA entity holding an issued refresh token, stored in the {@code refresh_token} table.
 */
@Entity
@Getter
@Setter
@Builder
@AllArgsConstructor
@NoArgsConstructor
@Table(name= "refresh_token")
public class RefreshToken {

    @Id
    @GeneratedValue(strategy = GenerationType.IDENTITY)
    private Long id;

    // Compact JWT string. The default column length of 255 is too short: an HS512 token
    // for a username near the 128-character limit is longer than 300 characters.
    @Column(unique = true, length = 768)
    private String token;

    // Expiry as epoch milliseconds (JwtTokenProvider stores now + refresh validity)
    private Long expiryDate;

    // Username the token was issued for
    private String username;
}
/**
 * Spring Data repository for stored {@link RefreshToken} rows.
 */
@Repository
public interface RefreshTokenRepository extends CrudRepository<RefreshToken, Long> {
    /** Looks up a stored refresh token by its token string. */
    Optional<RefreshToken> findByToken(String token);
}

Service 레이어

/**
 * Request body for the sign-in endpoint.
 */
@Getter
@Setter
@NoArgsConstructor
public class LoginDto {
    // Required, 3 to 128 characters
    @NotNull
    @Size(min = 3, max = 128)
    private String username;

    // Required, 3 to 256 characters; WRITE_ONLY: accepted on input, never serialized
    @JsonProperty(access = JsonProperty.Access.WRITE_ONLY)
    @NotNull
    @Size(min = 3, max = 256)
    private String password;
}
/**
 * Request body for creating a user (used by the signup and createAdminUser endpoints).
 */
@Getter
@Setter
@Builder
@NoArgsConstructor
@AllArgsConstructor
public class MyUserDto {
    // Required, 3 to 128 characters
    @NotNull
    @Size(min = 3, max = 128)
    private String username;

    // Required, 3 to 256 characters; WRITE_ONLY: accepted on input, never serialized
    @JsonProperty(access = JsonProperty.Access.WRITE_ONLY)
    @NotNull
    @Size(min = 3, max = 256)
    private String password;

    // Required, 3 to 128 characters
    @NotNull
    @Size(min = 3, max = 128)
    private String nickname;
}
/**
 * Request body for the token re-issue endpoint.
 */
@Getter
@Setter
@AllArgsConstructor
@NoArgsConstructor
public class RefreshTokenRequestDTO {
    // Refresh token string previously returned at sign-in
    private String token;
}
/**
 * Token pair returned by the sign-in and re-issue endpoints.
 */
@Getter
@Setter
@Builder
@AllArgsConstructor
public class TokenDto {
    // Token scheme; JwtTokenProvider always sets "Bearer"
    private String grantType;
    // Signed JWT access token
    private String accessToken;
    // Signed JWT refresh token
    private String refreshToken;
}

아래 클래스(UserDetailsService 구현체)를 만들어 두지 않으면 Spring Boot 가 기본 사용자 계정을 자동으로 구성하는 등 불필요한 로직이 추가로 작동하므로 생성해 줍니다.

/**
 * Placeholder {@link UserDetailsService} bean. Authentication in this example is done
 * from JWTs, so no user lookup is implemented here.
 */
@Service
@RequiredArgsConstructor
public class CustomUserDetailsService implements UserDetailsService {

    /**
     * Always fails: this service does not support loading users.
     * The UserDetailsService contract forbids returning {@code null}, so the lookup is
     * reported as "not found" instead.
     *
     * @param username requested username
     * @return never returns normally
     * @throws UsernameNotFoundException always
     */
    @Override
    public UserDetails loadUserByUsername(String username) throws UsernameNotFoundException {
        throw new UsernameNotFoundException("User lookup is not supported: " + username);
    }
}
/**
 * User-facing operations: sign-in, token re-issue, account creation and user lookup.
 */
@Service
@RequiredArgsConstructor
public class MyUserService {
    /** Single message for every login failure, so responses do not reveal whether a username exists. */
    private static final String LOGIN_FAILED_MESSAGE = "username or password is incorrect";

    private final MyUserRepository myUserRepository;
    private final JwtTokenProvider jwtTokenProvider;
    private final PasswordEncoder passwordEncoder;

    /**
     * Authenticates a user and issues tokens. The request parameter is not used by this
     * implementation; it is kept for existing callers.
     *
     * @param request  current HTTP request (unused)
     * @param username login name
     * @param password raw password
     * @return access and refresh tokens
     * @throws UserNotFoundException if the credentials do not match a user
     */
    @Transactional(noRollbackFor = UserNotFoundException.class)
    public TokenDto login(HttpServletRequest request, String username, String password) {
        return authenticate(username, password);
    }

    /**
     * Authenticates a user and issues tokens.
     * This two-argument form is the one AuthController.signin calls.
     *
     * @param username login name
     * @param password raw password
     * @return access and refresh tokens
     * @throws UserNotFoundException if the credentials do not match a user
     */
    @Transactional(noRollbackFor = UserNotFoundException.class)
    public TokenDto login(String username, String password) {
        return authenticate(username, password);
    }

    /** Verifies the credentials and generates the token pair; shared by both login overloads. */
    private TokenDto authenticate(String username, String password) {
        MyUser myUser = myUserRepository.findOneWithAuthoritiesByUsername(username)
                .orElseThrow(() -> new UserNotFoundException(LOGIN_FAILED_MESSAGE));
        // A null password is rejected up front instead of being handed to the encoder.
        if (password == null || !passwordEncoder.matches(password, myUser.getPassword())) {
            throw new UserNotFoundException(LOGIN_FAILED_MESSAGE);
        }

        return jwtTokenProvider.generateToken(myUser);
    }

    /**
     * Issues a new access token for a stored refresh token.
     *
     * @param refreshToken stored refresh token
     * @return new access token plus the same refresh token
     * @throws UserNotFoundException if the token's user no longer exists
     */
    @Transactional
    public TokenDto reissue(RefreshToken refreshToken) {
        MyUser myUser = getUserWithAuthorities(refreshToken.getUsername());
        return jwtTokenProvider.reissueToken(myUser, refreshToken);
    }

    /**
     * Creates the first account with both ADMIN and USER authorities.
     *
     * @param myUserDto account data
     * @return the saved user
     * @throws RuntimeException if any user already exists
     */
    @Transactional
    public MyUser createAdminUser(MyUserDto myUserDto) {
        if (myUserRepository.count() > 0) {
            throw new RuntimeException("User account must be zero");
        }

        Set<Authority> authorities = new HashSet<>();
        authorities.add(Authority.builder().authorityName("ADMIN").build());
        authorities.add(Authority.builder().authorityName("USER").build());

        MyUser myUser = MyUser.builder()
                .username(myUserDto.getUsername())
                .password(passwordEncoder.encode(myUserDto.getPassword()))
                .nickname(myUserDto.getNickname())
                .authorities(authorities)
                .activated(true)
                .build();

        return myUserRepository.save(myUser);
    }

    /**
     * Registers a new account with the USER authority.
     *
     * @param myUserDto account data
     * @return the saved user
     * @throws RuntimeException if the username is already taken
     */
    @Transactional
    public MyUser signup(MyUserDto myUserDto) {
        if (myUserRepository.findOneWithAuthoritiesByUsername(myUserDto.getUsername()).isPresent()) {
            throw new RuntimeException("Unavailable username");
        }

        Authority authority = Authority.builder()
                .authorityName("USER")
                .build();

        MyUser myUser = MyUser.builder()
                .username(myUserDto.getUsername())
                .password(passwordEncoder.encode(myUserDto.getPassword()))
                .nickname(myUserDto.getNickname())
                .authorities(Collections.singleton(authority))
                .activated(true)
                .build();

        return myUserRepository.save(myUser);
    }

    /**
     * Loads a user by username.
     *
     * @param username username to look up
     * @return the user
     * @throws UserNotFoundException if no such user exists
     */
    @Transactional(readOnly = true)
    public MyUser getUserWithAuthorities(String username) {
        return myUserRepository.findOneWithAuthoritiesByUsername(username)
                .orElseThrow(() -> new UserNotFoundException("User not found"));
    }

    /**
     * Loads the user that belongs to the current security context.
     *
     * @return the current user
     * @throws UserNotFoundException if there is no current username or no matching user
     */
    @Transactional(readOnly = true)
    public MyUser getMyUserWithAuthorities() {
        return SecurityUtil
                .getCurrentUsername()
                .flatMap(myUserRepository::findOneWithAuthoritiesByUsername)
                .orElseThrow(() -> new UserNotFoundException("User not found"));
    }
}
/**
 * Thin service in front of {@link RefreshTokenRepository}.
 */
@Service
@RequiredArgsConstructor
public class RefreshTokenService {
    private final RefreshTokenRepository refreshTokenRepository;

    /**
     * Looks up a stored refresh token by its token string.
     *
     * @param token refresh token string
     * @return the stored entry, or an empty Optional when none matches
     */
    public Optional<RefreshToken> findByToken(String token) {
        final Optional<RefreshToken> stored = refreshTokenRepository.findByToken(token);
        return stored;
    }
}

Controller 레이어

/**
 * Translates exceptions thrown by controllers in {@code com.example.api.web} into a
 * {@link ResponseDto} body. The HTTP status stays 200; BAD_REQUEST is reported in the body.
 */
@RestControllerAdvice("com.example.api.web")
public class RestExceptionHandler {

    /** Handles any RuntimeException that has no more specific handler. */
    @ExceptionHandler(RuntimeException.class)
    public ResponseEntity<?> handleRuntimeException(RuntimeException e) {
        return badRequestEnvelope(e.getMessage());
    }

    /** Handles UserNotFoundException; produces the same envelope as the generic handler. */
    @ExceptionHandler(UserNotFoundException.class)
    public ResponseEntity<?> handleUserNotFoundException(UserNotFoundException e) {
        return badRequestEnvelope(e.getMessage());
    }

    /** Wraps the message in a BAD_REQUEST envelope and returns it with HTTP 200. */
    private static ResponseEntity<?> badRequestEnvelope(String message) {
        final ResponseDto<Object> body = ResponseDto.res(HttpStatus.BAD_REQUEST, message);
        return ResponseEntity.ok(body);
    }
}
/**
 * Unchecked exception raised when a user lookup or credential check fails.
 */
public class UserNotFoundException extends RuntimeException {

    /**
     * @param message detail message returned by {@link #getMessage()}
     */
    public UserNotFoundException(final String message) {
        super(message);
    }
}
/**
 * Authentication endpoints: first-admin creation, sign-in and access-token re-issue.
 */
@RestController
@RequestMapping("/api")
@RequiredArgsConstructor
public class AuthController {
    private final MyUserService myUserService;
    private final RefreshTokenService refreshTokenService;

    /**
     * Creates the first (admin) account.
     *
     * @param userDto validated account data
     * @return the created user
     */
    @PostMapping("/createAdminUser")
    public ResponseEntity<MyUser> createAdminUser(
            @Valid @RequestBody MyUserDto userDto
    ) {
        return ResponseEntity.ok(myUserService.createAdminUser(userDto));
    }

    /**
     * Signs a user in and returns the token pair in the body, with the access token
     * also placed in the Authorization header.
     *
     * @param loginDto credentials; {@code @Valid} enforces the constraints declared on LoginDto
     * @return the issued tokens
     */
    @PostMapping("/signin")
    public ResponseEntity<TokenDto> signin(@Valid @RequestBody LoginDto loginDto) {
        String username = loginDto.getUsername();
        String password = loginDto.getPassword();

        return tokenResponse(myUserService.login(username, password));
    }

    /**
     * Issues a new access token for a stored refresh token.
     *
     * @param requestDTO body carrying the refresh token string
     * @return the new access token together with the refresh token
     * @throws RuntimeException with message "Token not found" when the token is not stored
     */
    @PostMapping("/reissue")
    public ResponseEntity<TokenDto> reissue(@RequestBody RefreshTokenRequestDTO requestDTO) {
        RefreshToken refreshToken = refreshTokenService.findByToken(requestDTO.getToken())
                .orElseThrow(() -> new RuntimeException("Token not found"));

        return tokenResponse(myUserService.reissue(refreshToken));
    }

    /** Builds the 200 response carrying the tokens in the body and the access token in the header. */
    private static ResponseEntity<TokenDto> tokenResponse(TokenDto tokenDto) {
        HttpHeaders httpHeaders = new HttpHeaders();
        httpHeaders.add(JwtTokenProvider.AUTHORIZATION_HEADER, "Bearer " + tokenDto.getAccessToken());

        return new ResponseEntity<>(tokenDto, httpHeaders, HttpStatus.OK);
    }
}
/**
 * User endpoints: sign-up and profile lookup.
 * NOTE(review): the @PreAuthorize annotations only take effect when method security is
 * enabled; no such configuration is visible in this file — confirm.
 */
@Slf4j
@RestController
@RequiredArgsConstructor
@RequestMapping("/api")
public class MyUserController {
    private final MyUserService myUserService;

    /**
     * Registers a new user.
     *
     * @param userDto validated account data
     * @return the created user
     */
    @PostMapping("/signup")
    public ResponseEntity<MyUser> signup(@Valid @RequestBody MyUserDto userDto) {
        final MyUser created = myUserService.signup(userDto);
        return ResponseEntity.ok(created);
    }

    /**
     * Returns the profile of the currently authenticated user.
     */
    @GetMapping("/user")
    @PreAuthorize("hasAnyRole('USER','ADMIN')")
    public ResponseEntity<MyUser> getMyUserInfo() {
        final MyUser current = myUserService.getMyUserWithAuthorities();
        return ResponseEntity.ok(current);
    }

    /**
     * Returns the profile of the named user.
     *
     * @param username username taken from the path
     */
    @GetMapping("/user/{username}")
    @PreAuthorize("hasAnyRole('ADMIN')")
    public ResponseEntity<MyUser> getUserInfo(@PathVariable String username) {
        final MyUser found = myUserService.getUserWithAuthorities(username);
        return ResponseEntity.ok(found);
    }
}

테스트

POST http://localhost:8080/api/createAdminUser
{
  "username": "skyer9",
  "password": "abcd1234",
  "nickname": "skyer9"
}

# POST http://localhost:8080/api/signin
# {
#   "username": "admin",
#   "password": "admin"
# }

POST http://localhost:8080/api/signin
{
  "username": "skyer9",
  "password": "abcd1234"
}

POST http://localhost:8080/api/signup
{
  "username": "skyer9",
  "password": "abcd1234",
  "nickname": "skyer9"
}

POST http://localhost:8080/api/reissue
{
  "token": "refreshToken"
}

GET http://localhost:8080/api/user
Bearer accessToken

GET http://localhost:8080/api/user/{username}
Bearer accessToken

답글 남기기