Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions backend/app/.env.example
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,4 @@ APS_REDIS_PASSWORD=''
APS_REDIS_DATABASE=1
# Token
TOKEN_SECRET_KEY='1VkVF75nsNABBjK_7-qz7GtzNy3AMvktc9TCPwKczCk'
TOKEN_WHITE_LIST=[1]
10 changes: 7 additions & 3 deletions backend/app/api/v1/auth/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
from fastapi import APIRouter, Depends
from fastapi.security import OAuth2PasswordRequestForm

from backend.app.common.jwt import DependsUser
from backend.app.common.jwt import DependsUser, JwtAuthentication
from backend.app.common.redis import redis_client
from backend.app.common.response.response_schema import response_base
from backend.app.schemas.token import Token
from backend.app.schemas.user import Auth
Expand All @@ -27,6 +28,9 @@ async def user_login(obj: Auth):


@router.post('/logout', summary='用户登出', dependencies=[DependsUser])
async def user_logout():
# TODO: 加入 token 黑名单
async def user_logout(jwt: JwtAuthentication):
user_id = jwt.get('payload').get('sub')
token = jwt.get('token')
key = f'token:{user_id}:{token}'
await redis_client.delete(key)
return response_base.success()
40 changes: 33 additions & 7 deletions backend/app/common/jwt.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from typing_extensions import Annotated

from backend.app.common.exception.errors import AuthorizationError, TokenError
from backend.app.common.redis import redis_client
from backend.app.core.conf import settings
from backend.app.crud.crud_user import UserDao
from backend.app.database.db_mysql import CurrentSession
Expand Down Expand Up @@ -42,28 +43,34 @@ def password_verify(plain_password: str, hashed_password: str) -> bool:
return pwd_context.verify(plain_password, hashed_password)


def create_access_token(data: int | Any, expires_delta: timedelta | None = None) -> str:
async def create_access_token(sub: int | Any, data: dict, expires_delta: timedelta | None = None) -> str:
"""
Generate encryption token

:param sub: The subject/userid of the JWT
:param data: Data transferred to the token
:param expires_delta: Increased expiry time
:return:
"""
if expires_delta:
expires = datetime.utcnow() + expires_delta
expire_seconds = expires_delta.total_seconds()
else:
expires = datetime.utcnow() + timedelta(settings.TOKEN_EXPIRE_MINUTES)
to_encode = {'exp': expires, 'sub': str(data[0]), 'role_ids': str(data[1])}
encoded_jwt = jwt.encode(to_encode, settings.TOKEN_SECRET_KEY, settings.TOKEN_ALGORITHM)
return encoded_jwt
expire_seconds = settings.TOKEN_EXPIRE_MINUTES * 60
to_encode = {'exp': expires, 'sub': str(sub), **data}
token = jwt.encode(to_encode, settings.TOKEN_SECRET_KEY, settings.TOKEN_ALGORITHM)
if sub not in settings.TOKEN_WHITE_LIST:
await redis_client.delete(f'token:{sub}:*')
key = f'token:{sub}:{token}'
await redis_client.setex(key, expire_seconds, token)
return token


async def get_current_user(db: CurrentSession, token: str = Depends(oauth2_schema)) -> User:
async def jwt_authentication(token: str = Depends(oauth2_schema)):
"""
Get the current user through tokens
JWT authentication

:param db:
:param token:
:return:
"""
Expand All @@ -73,8 +80,25 @@ async def get_current_user(db: CurrentSession, token: str = Depends(oauth2_schem
user_role = payload.get('role_ids')
if not user_id or not user_role:
raise TokenError
# 验证token是否有效
key = f'token:{user_id}:{token}'
valid_token = await redis_client.get(key)
if not valid_token:
raise TokenError
return {'payload': payload, 'token': token}
except (jwt.JWTError, ValidationError):
raise TokenError


async def get_current_user(db: CurrentSession, data: dict = Depends(jwt_authentication)) -> User:
"""
Get the current user through tokens

:param db:
:param data:
:return:
"""
user_id = data.get('payload').get('sub')
user = await UserDao.get_user_with_relation(db, user_id=user_id)
if not user:
raise TokenError
Expand All @@ -97,6 +121,8 @@ async def get_current_is_superuser(user: User = Depends(get_current_user)):
# User Annotated
CurrentUser = Annotated[User, Depends(get_current_user)]
CurrentSuperUser = Annotated[bool, Depends(get_current_is_superuser)]
# Token dependency injection
JwtAuthentication = Annotated[dict, Depends(jwt_authentication)]
# Permission dependency injection
DependsUser = Depends(get_current_user)
DependsSuperUser = Depends(get_current_is_superuser)
1 change: 1 addition & 0 deletions backend/app/core/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ class Settings(BaseSettings):

# Env Token
TOKEN_SECRET_KEY: str # 密钥 secrets.token_urlsafe(32))
TOKEN_WHITE_LIST: list[str] # 白名单用户ID,可多点登录

# FastAPI
TITLE: str = 'FastAPI'
Expand Down
4 changes: 2 additions & 2 deletions backend/app/services/user_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ async def swagger_login(form_data: OAuth2PasswordRequestForm):
# 获取最新用户信息
user = await UserDao.get_user_by_id(db, current_user.id)
# 创建token
access_token = jwt.create_access_token([user.id, user_role_ids])
access_token = await jwt.create_access_token(user.id, {'role_ids': user_role_ids})
return access_token, user

@staticmethod
Expand All @@ -48,7 +48,7 @@ async def login(obj: Auth):
await UserDao.update_user_login_time(db, obj.username)
user_role_ids = await UserDao.get_user_role_ids(db, current_user.id)
user = await UserDao.get_user_by_id(db, current_user.id)
access_token = jwt.create_access_token([user.id, user_role_ids])
access_token = await jwt.create_access_token(user.id, {'role_ids': user_role_ids})
return access_token, user

@staticmethod
Expand Down