Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions internal/bootstrap/data/setting.go
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,9 @@ func InitialSettings() []model.SettingItem {
{Key: conf.ForwardDirectLinkParams, Value: "false", Type: conf.TypeBool, Group: model.GLOBAL},
{Key: conf.IgnoreDirectLinkParams, Value: "sign,alist_ts", Type: conf.TypeString, Group: model.GLOBAL},
{Key: conf.WebauthnLoginEnabled, Value: "false", Type: conf.TypeBool, Group: model.GLOBAL, Flag: model.PUBLIC},
{Key: conf.MaxDevices, Value: "0", Type: conf.TypeNumber, Group: model.GLOBAL},
{Key: conf.DeviceEvictPolicy, Value: "deny", Type: conf.TypeSelect, Options: "deny,evict_oldest", Group: model.GLOBAL},
{Key: conf.DeviceSessionTTL, Value: "86400", Type: conf.TypeNumber, Group: model.GLOBAL},

// single settings
{Key: conf.Token, Value: token, Type: conf.TypeString, Group: model.SINGLE, Flag: model.PRIVATE},
Expand Down
3 changes: 3 additions & 0 deletions internal/conf/const.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,9 @@ const (
ForwardDirectLinkParams = "forward_direct_link_params"
IgnoreDirectLinkParams = "ignore_direct_link_params"
WebauthnLoginEnabled = "webauthn_login_enabled"
MaxDevices = "max_devices"
DeviceEvictPolicy = "device_evict_policy"
DeviceSessionTTL = "device_session_ttl"

// index
SearchIndex = "search_index"
Expand Down
2 changes: 1 addition & 1 deletion internal/db/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ var db *gorm.DB

func Init(d *gorm.DB) {
db = d
err := AutoMigrate(new(model.Storage), new(model.User), new(model.Meta), new(model.SettingItem), new(model.SearchNode), new(model.TaskItem), new(model.SSHPublicKey), new(model.Role), new(model.Label), new(model.LabelFileBinding), new(model.ObjFile))
err := AutoMigrate(new(model.Storage), new(model.User), new(model.Meta), new(model.SettingItem), new(model.SearchNode), new(model.TaskItem), new(model.SSHPublicKey), new(model.Role), new(model.Label), new(model.LabelFileBinding), new(model.ObjFile), new(model.Session))
if err != nil {
log.Fatalf("failed migrate database: %s", err.Error())
}
Expand Down
65 changes: 65 additions & 0 deletions internal/db/session.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
package db

import (
"github.com/alist-org/alist/v3/internal/model"
"github.com/pkg/errors"
"gorm.io/gorm/clause"
)

func GetSession(userID uint, deviceKey string) (*model.Session, error) {
s := model.Session{UserID: userID, DeviceKey: deviceKey}
if err := db.Select("user_id, device_key, last_active, status, user_agent, ip").Where(&s).First(&s).Error; err != nil {
return nil, errors.Wrap(err, "failed find session")
}
return &s, nil
}

func CreateSession(s *model.Session) error {
return errors.WithStack(db.Create(s).Error)
}

func UpsertSession(s *model.Session) error {
return errors.WithStack(db.Clauses(clause.OnConflict{UpdateAll: true}).Create(s).Error)
}

func DeleteSession(userID uint, deviceKey string) error {
return errors.WithStack(db.Where("user_id = ? AND device_key = ?", userID, deviceKey).Delete(&model.Session{}).Error)
}

func CountSessionsByUser(userID uint) (int64, error) {
var count int64
err := db.Model(&model.Session{}).Where("user_id = ?", userID).Count(&count).Error
return count, errors.WithStack(err)
}

func DeleteSessionsBefore(ts int64) error {
return errors.WithStack(db.Where("last_active < ?", ts).Delete(&model.Session{}).Error)
}

func GetOldestSession(userID uint) (*model.Session, error) {
var s model.Session
if err := db.Where("user_id = ?", userID).Order("last_active ASC").First(&s).Error; err != nil {
return nil, errors.Wrap(err, "failed get oldest session")
}
return &s, nil
}

func UpdateSessionLastActive(userID uint, deviceKey string, lastActive int64) error {
return errors.WithStack(db.Model(&model.Session{}).Where("user_id = ? AND device_key = ?", userID, deviceKey).Update("last_active", lastActive).Error)
}

func ListSessionsByUser(userID uint) ([]model.Session, error) {
var sessions []model.Session
err := db.Select("user_id, device_key, last_active, status, user_agent, ip").Where("user_id = ? AND status = ?", userID, model.SessionActive).Find(&sessions).Error
return sessions, errors.WithStack(err)
}

func ListSessions() ([]model.Session, error) {
var sessions []model.Session
err := db.Select("user_id, device_key, last_active, status, user_agent, ip").Where("status = ?", model.SessionActive).Find(&sessions).Error
return sessions, errors.WithStack(err)
}

func MarkInactive(sessionID string) error {
return errors.WithStack(db.Model(&model.Session{}).Where("device_key = ?", sessionID).Update("status", model.SessionInactive).Error)
}
67 changes: 67 additions & 0 deletions internal/device/session.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
package device

import (
"time"

"github.com/alist-org/alist/v3/internal/conf"
"github.com/alist-org/alist/v3/internal/db"
"github.com/alist-org/alist/v3/internal/errs"
"github.com/alist-org/alist/v3/internal/model"
"github.com/alist-org/alist/v3/internal/setting"
"github.com/alist-org/alist/v3/pkg/utils"
"github.com/pkg/errors"
"gorm.io/gorm"
)

// Handle verifies device sessions for a user and upserts current session.
func Handle(userID uint, deviceKey, ua, ip string) error {
ttl := setting.GetInt(conf.DeviceSessionTTL, 86400)
if ttl > 0 {
_ = db.DeleteSessionsBefore(time.Now().Unix() - int64(ttl))
}

ip = utils.MaskIP(ip)

now := time.Now().Unix()
sess, err := db.GetSession(userID, deviceKey)
if err == nil {
if sess.Status == model.SessionInactive {
return errors.WithStack(errs.SessionInactive)
}
sess.LastActive = now
sess.Status = model.SessionActive
sess.UserAgent = ua
sess.IP = ip
return db.UpsertSession(sess)
}
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
return err
}

max := setting.GetInt(conf.MaxDevices, 0)
if max > 0 {
count, err := db.CountSessionsByUser(userID)
if err != nil {
return err
}
if count >= int64(max) {
policy := setting.GetStr(conf.DeviceEvictPolicy, "deny")
if policy == "evict_oldest" {
oldest, err := db.GetOldestSession(userID)
if err == nil {
_ = db.DeleteSession(userID, oldest.DeviceKey)
}
} else {
return errors.WithStack(errs.TooManyDevices)
}
}
}

s := &model.Session{UserID: userID, DeviceKey: deviceKey, UserAgent: ua, IP: ip, LastActive: now, Status: model.SessionActive}
return db.CreateSession(s)
}

// Refresh updates last_active for the session.
func Refresh(userID uint, deviceKey string) {
_ = db.UpdateSessionLastActive(userID, deviceKey, time.Now().Unix())
}
8 changes: 8 additions & 0 deletions internal/errs/device.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
package errs

import "errors"

var (
TooManyDevices = errors.New("too many active devices")
SessionInactive = errors.New("session inactive")
)
16 changes: 16 additions & 0 deletions internal/model/session.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
package model

// Session represents a device session of a user.
type Session struct {
UserID uint `json:"user_id" gorm:"index"`
DeviceKey string `json:"device_key" gorm:"primaryKey;size:64"`
UserAgent string `json:"user_agent" gorm:"size:255"`
IP string `json:"ip" gorm:"size:64"`
LastActive int64 `json:"last_active"`
Status int `json:"status"`
}

const (
SessionActive = iota
SessionInactive
)
8 changes: 8 additions & 0 deletions internal/session/session.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
package session

import "github.com/alist-org/alist/v3/internal/db"

// MarkInactive marks the session with the given ID as inactive.
func MarkInactive(sessionID string) error {
return db.MarkInactive(sessionID)
}
30 changes: 30 additions & 0 deletions pkg/utils/mask.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
package utils

import "strings"

// MaskIP anonymizes middle segments of an IP address.
func MaskIP(ip string) string {
if ip == "" {
return ""
}
if strings.Contains(ip, ":") {
parts := strings.Split(ip, ":")
if len(parts) > 2 {
for i := 1; i < len(parts)-1; i++ {
if parts[i] != "" {
parts[i] = "*"
}
}
return strings.Join(parts, ":")
}
return ip
}
parts := strings.Split(ip, ".")
if len(parts) == 4 {
for i := 1; i < len(parts)-1; i++ {
parts[i] = "*"
}
return strings.Join(parts, ".")
}
return ip
}
8 changes: 8 additions & 0 deletions server/handles/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"github.com/alist-org/alist/v3/internal/conf"
"github.com/alist-org/alist/v3/internal/model"
"github.com/alist-org/alist/v3/internal/op"
"github.com/alist-org/alist/v3/internal/session"
"github.com/alist-org/alist/v3/internal/setting"
"github.com/alist-org/alist/v3/server/common"
"github.com/gin-gonic/gin"
Expand Down Expand Up @@ -247,6 +248,13 @@ func Verify2FA(c *gin.Context) {
}

func LogOut(c *gin.Context) {
if keyVal, ok := c.Get("device_key"); ok {
if err := session.MarkInactive(keyVal.(string)); err != nil {
common.ErrorResp(c, err, 500)
return
}
c.Set("session_inactive", true)
}
err := common.InvalidateToken(c.GetHeader("Authorization"))
if err != nil {
common.ErrorResp(c, err, 500)
Expand Down
92 changes: 92 additions & 0 deletions server/handles/session.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
package handles

import (
"github.com/alist-org/alist/v3/internal/db"
"github.com/alist-org/alist/v3/internal/model"
"github.com/alist-org/alist/v3/server/common"
"github.com/gin-gonic/gin"
)

type SessionResp struct {
SessionID string `json:"session_id"`
UserID uint `json:"user_id,omitempty"`
LastActive int64 `json:"last_active"`
Status int `json:"status"`
UA string `json:"ua"`
IP string `json:"ip"`
}

func ListMySessions(c *gin.Context) {
user := c.MustGet("user").(*model.User)
sessions, err := db.ListSessionsByUser(user.ID)
if err != nil {
common.ErrorResp(c, err, 500)
return
}
resp := make([]SessionResp, len(sessions))
for i, s := range sessions {
resp[i] = SessionResp{
SessionID: s.DeviceKey,
LastActive: s.LastActive,
Status: s.Status,
UA: s.UserAgent,
IP: s.IP,
}
}
common.SuccessResp(c, resp)
}

type EvictSessionReq struct {
SessionID string `json:"session_id"`
}

func EvictMySession(c *gin.Context) {
var req EvictSessionReq
if err := c.ShouldBindJSON(&req); err != nil {
common.ErrorResp(c, err, 400)
return
}
user := c.MustGet("user").(*model.User)
if _, err := db.GetSession(user.ID, req.SessionID); err != nil {
common.ErrorResp(c, err, 400)
return
}
if err := db.MarkInactive(req.SessionID); err != nil {
common.ErrorResp(c, err, 500)
return
}
common.SuccessResp(c)
}

func ListSessions(c *gin.Context) {
sessions, err := db.ListSessions()
if err != nil {
common.ErrorResp(c, err, 500)
return
}
resp := make([]SessionResp, len(sessions))
for i, s := range sessions {
resp[i] = SessionResp{
SessionID: s.DeviceKey,
UserID: s.UserID,
LastActive: s.LastActive,
Status: s.Status,
UA: s.UserAgent,
IP: s.IP,
}
}
common.SuccessResp(c, resp)
}

func EvictSession(c *gin.Context) {
var req EvictSessionReq
if err := c.ShouldBindJSON(&req); err != nil {
common.ErrorResp(c, err, 400)
return
}
if err := db.MarkInactive(req.SessionID); err != nil {
common.ErrorResp(c, err, 500)
return
}
common.SuccessResp(c)
}
Loading
Loading