Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 37 additions & 0 deletions internal/commands/paged_activity_request_command.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
package commands

import (
"encoding/json"
"net/http"
"receipt-wrangler/api/internal/structs"
"receipt-wrangler/api/internal/utils"
)

type PagedActivityRequestCommand struct {
PagedRequestCommand
GroupIds []uint `json:"groupIds"`
}

func (command *PagedActivityRequestCommand) LoadDataFromRequest(w http.ResponseWriter, r *http.Request) error {
bytes, err := utils.GetBodyData(w, r)
if err != nil {
return err
}

err = json.Unmarshal(bytes, &command)
if err != nil {
return err
}

return nil
}

func (command *PagedActivityRequestCommand) Validate() structs.ValidatorError {
vErrs := command.PagedRequestCommand.Validate()

if len(command.GroupIds) == 0 {
vErrs.Errors["groupIds"] = "Must provide at least one group id"
}

return vErrs
}
10 changes: 10 additions & 0 deletions internal/handlers/generic_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,11 @@ func HandleRequest(handler structs.Handler) {
}
}

if len(handler.GroupRole) > 0 && len(handler.GroupId) == 0 && len(handler.GroupIds) == 0 {
utils.WriteCustomErrorResponse(handler.Writer, "Group ID is required to validate group role", http.StatusForbidden)
return
}

if len(handler.GroupRole) > 0 && len(handler.GroupId) > 0 {
groupService := services.NewGroupService(nil)
token := structs.GetJWT(handler.Request)
Expand All @@ -60,6 +65,11 @@ func HandleRequest(handler structs.Handler) {
}
}

if len(handler.GroupRole) > 0 && len(handler.GroupIds) == 0 && len(handler.GroupId) == 0 {
utils.WriteCustomErrorResponse(handler.Writer, "Group IDs are required to validate group role", http.StatusForbidden)
return
}

if len(handler.GroupRole) > 0 && len(handler.GroupIds) > 0 {
groupService := services.NewGroupService(nil)
token := structs.GetJWT(handler.Request)
Expand Down
59 changes: 58 additions & 1 deletion internal/handlers/generic_handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,6 @@ func TestShouldAcceptReceiptAccessBasedOnGroup(t *testing.T) {
Writer: w,
Request: r,
ResponseType: constants.ApplicationJson,
GroupRole: models.OWNER,
ReceiptId: "1",
HandlerFunction: func(w http.ResponseWriter, r *http.Request) (int, error) {
return 0, nil
Expand Down Expand Up @@ -278,6 +277,64 @@ func TestShouldAcceptReceiptsAccessBasedOnGroup(t *testing.T) {
}
}

func TestShouldRejectAccessBasedOnEmptyGroupId(t *testing.T) {
defer tearDownGenericHandlerTest()
reader := strings.NewReader("")
w := httptest.NewRecorder()
r := httptest.NewRequest("GET", "/api", reader)

newContext := context.WithValue(r.Context(), jwtmiddleware.ContextKey{}, &validator.ValidatedClaims{CustomClaims: &structs.Claims{UserId: 1}})
r = r.WithContext(newContext)

repositories.CreateTestGroupWithUsers()

handler := structs.Handler{
Writer: w,
Request: r,
ResponseType: constants.ApplicationJson,
GroupRole: models.VIEWER,
GroupId: "",
HandlerFunction: func(w http.ResponseWriter, r *http.Request) (int, error) {
return 0, nil
},
}

HandleRequest(handler)

if w.Result().StatusCode != http.StatusForbidden {
utils.PrintTestError(t, w.Result().StatusCode, http.StatusOK)
}
}

func TestShouldRejectAccessBasedOnEmptyGroupIds(t *testing.T) {
defer tearDownGenericHandlerTest()
reader := strings.NewReader("")
w := httptest.NewRecorder()
r := httptest.NewRequest("GET", "/api", reader)

newContext := context.WithValue(r.Context(), jwtmiddleware.ContextKey{}, &validator.ValidatedClaims{CustomClaims: &structs.Claims{UserId: 1}})
r = r.WithContext(newContext)

repositories.CreateTestGroupWithUsers()

handler := structs.Handler{
Writer: w,
Request: r,
ResponseType: constants.ApplicationJson,
GroupRole: models.VIEWER,
GroupIds: []string{},
HandlerFunction: func(w http.ResponseWriter, r *http.Request) (int, error) {
return 0, nil
},
}

HandleRequest(handler)

if w.Result().StatusCode != http.StatusForbidden {
utils.PrintTestError(t, w.Result().StatusCode, http.StatusOK)
}
}

func TestShouldRejectReceiptAccessBasedOnWrongGroupRole(t *testing.T) {
defer tearDownGenericHandlerTest()
reader := strings.NewReader("")
Expand Down
2 changes: 1 addition & 1 deletion internal/handlers/receipts.go
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ func CreateReceipt(w http.ResponseWriter, r *http.Request) {
ResponseType: constants.ApplicationJson,
HandlerFunction: func(w http.ResponseWriter, r *http.Request) (int, error) {
receiptRepository := repositories.NewReceiptRepository(nil)
createdReceipt, err := receiptRepository.CreateReceipt(command, token.UserId)
createdReceipt, err := receiptRepository.CreateReceipt(command, token.UserId, true)
if err != nil {
return http.StatusInternalServerError, err
}
Expand Down
153 changes: 153 additions & 0 deletions internal/handlers/system_task.go
Original file line number Diff line number Diff line change
@@ -1,13 +1,17 @@
package handlers

import (
"encoding/json"
"github.com/go-chi/chi/v5"
"net/http"
"receipt-wrangler/api/internal/commands"
"receipt-wrangler/api/internal/constants"
"receipt-wrangler/api/internal/logging"
"receipt-wrangler/api/internal/models"
"receipt-wrangler/api/internal/repositories"
"receipt-wrangler/api/internal/structs"
"receipt-wrangler/api/internal/utils"
"receipt-wrangler/api/internal/wranglerasynq"
)

func GetSystemTasks(w http.ResponseWriter, r *http.Request) {
Expand Down Expand Up @@ -60,3 +64,152 @@ func GetSystemTasks(w http.ResponseWriter, r *http.Request) {

HandleRequest(handler)
}

func GetActivitiesForGroups(w http.ResponseWriter, r *http.Request) {
command := commands.PagedActivityRequestCommand{}
err := command.LoadDataFromRequest(w, r)
if err != nil {
w.WriteHeader(http.StatusInternalServerError)
return
}

stringGroupIds := make([]string, 0)
for _, groupId := range command.GroupIds {
stringGroupIds = append(stringGroupIds, utils.UintToString(groupId))
}

handler := structs.Handler{
ErrorMessage: "Error getting group activities",
Writer: w,
Request: r,
GroupIds: stringGroupIds,
GroupRole: models.VIEWER,
ResponseType: constants.ApplicationJson,
HandlerFunction: func(w http.ResponseWriter, r *http.Request) (int, error) {

vErr := command.Validate()
if len(vErr.Errors) > 0 {
structs.WriteValidatorErrorResponse(w, vErr, http.StatusBadRequest)
return 0, nil
}

systemTaskRepository := repositories.NewSystemTaskRepository(nil)
activities, count, err := systemTaskRepository.GetPagedActivities(command)
if err != nil {
return http.StatusInternalServerError, err
}

err = wranglerasynq.SetActivityCanBeRestarted(&activities)
if err != nil {
return http.StatusInternalServerError, err
}

pagedData := structs.PagedData{}
data := make([]any, 0)

for i := 0; i < len(activities); i++ {
data = append(data, activities[i])
}

pagedData.Data = data
pagedData.TotalCount = count

responseBytes, err := utils.MarshalResponseData(pagedData)
if err != nil {
return http.StatusInternalServerError, err
}

w.WriteHeader(http.StatusOK)
w.Write(responseBytes)

return 0, nil
},
}

HandleRequest(handler)
}

func RerunActivity(w http.ResponseWriter, r *http.Request) {
systemTaskRepository := repositories.NewSystemTaskRepository(nil)
inspector, err := wranglerasynq.GetAsynqInspector()
if err != nil {
logging.LogStd(logging.LOG_LEVEL_ERROR, err.Error())
w.WriteHeader(http.StatusInternalServerError)
return
}

systemTaskId := chi.URLParam(r, "id")
systemTaskUintId, err := utils.StringToUint(systemTaskId)
if err != nil {
logging.LogStd(logging.LOG_LEVEL_ERROR, err.Error())
w.WriteHeader(http.StatusBadRequest)
return
}

systemTask, err := systemTaskRepository.GetSystemTaskById(systemTaskUintId)
if err != nil {
logging.LogStd(logging.LOG_LEVEL_ERROR, err.Error())
w.WriteHeader(http.StatusInternalServerError)
return
}

if systemTask.Type != models.QUICK_SCAN {
logging.LogStd(logging.LOG_LEVEL_ERROR, "Only quick scan activities can be rerun")
w.WriteHeader(http.StatusBadRequest)
return
}

if systemTask.AssociatedSystemTaskId == nil {
logging.LogStd(logging.LOG_LEVEL_ERROR, "Associated system task id is required to rerun quick scan activity")
w.WriteHeader(http.StatusBadRequest)
return
}

parentSystemTask, err := systemTaskRepository.GetSystemTaskById(*systemTask.AssociatedSystemTaskId)
if err != nil {
logging.LogStd(logging.LOG_LEVEL_ERROR, err.Error())
w.WriteHeader(http.StatusInternalServerError)
return
}

if parentSystemTask.AsynqTaskId == "" {
logging.LogStd(logging.LOG_LEVEL_ERROR, "Parent system task does not have an asynq task id")
w.WriteHeader(http.StatusBadRequest)
return
}

taskInfo, err := inspector.GetTaskInfo(string(models.QuickScanQueue), parentSystemTask.AsynqTaskId)
if err != nil {
logging.LogStd(logging.LOG_LEVEL_ERROR, err.Error())
w.WriteHeader(http.StatusInternalServerError)
return
}

var payload wranglerasynq.QuickScanTaskPayload
err = json.Unmarshal(taskInfo.Payload, &payload)
if err != nil {
logging.LogStd(logging.LOG_LEVEL_ERROR, err.Error())
w.WriteHeader(http.StatusInternalServerError)
return
}

stringGroupId := utils.UintToString(payload.GroupId)

handler := structs.Handler{
ErrorMessage: "Error rerunning activity",
Writer: w,
Request: r,
GroupId: stringGroupId,
GroupRole: models.EDITOR,
HandlerFunction: func(w http.ResponseWriter, r *http.Request) (int, error) {
err = inspector.RunTask(string(models.QuickScanQueue), parentSystemTask.AsynqTaskId)
if err != nil {
return http.StatusInternalServerError, err
}

return 0, nil
},
}

HandleRequest(handler)
}
5 changes: 4 additions & 1 deletion internal/models/widget_type.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ type WidgetType string
const (
GROUP_SUMMARY WidgetType = "GROUP_SUMMARY"
FILTERED_RECEIPTS WidgetType = "FILTERED_RECEIPTS"
GROUP_ACTIVITY WidgetType = "GROUP_ACTIVITY"
)

func (widgetType *WidgetType) Scan(value string) error {
Expand All @@ -18,7 +19,9 @@ func (widgetType *WidgetType) Scan(value string) error {
}

func (widgetType WidgetType) Value() (driver.Value, error) {
if widgetType != GROUP_SUMMARY && widgetType != FILTERED_RECEIPTS {
if widgetType != GROUP_SUMMARY &&
widgetType != FILTERED_RECEIPTS &&
widgetType != GROUP_ACTIVITY {
return nil, errors.New("invalid widget type")
}
return string(widgetType), nil
Expand Down
4 changes: 1 addition & 3 deletions internal/repositories/dashboards.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,10 @@ func (repository *DashboardRepository) CreateDashboard(command commands.UpsertDa
groupId, _ = utils.StringToUint(command.GroupId)

for i, widget := range command.Widgets {
configuration := []byte("{}")

widgets[i] = models.Widget{
Name: widget.Name,
WidgetType: widget.WidgetType,
Configuration: configuration,
Configuration: widget.Configuration,
}
}

Expand Down
Loading