diff --git a/cmd/args.go b/cmd/args.go index d309b825..9855fa6d 100644 --- a/cmd/args.go +++ b/cmd/args.go @@ -28,6 +28,7 @@ type Args struct { PoCType string ReportFormat string HarFilePath string + CustomBlindXSSPayloadFile string Timeout int Delay int Concurrence int diff --git a/cmd/root.go b/cmd/root.go index ca84fc16..cd543ee4 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -69,6 +69,7 @@ func init() { rootCmd.PersistentFlags().StringVarP(&args.Cookie, "cookie", "C", "", "Add custom cookies to the request. Example: -C 'sessionid=abc123'") rootCmd.PersistentFlags().StringVarP(&args.Data, "data", "d", "", "Use POST method and add body data. Example: -d 'username=admin&password=admin'") rootCmd.PersistentFlags().StringVar(&args.CustomPayload, "custom-payload", "", "Load custom payloads from a file. Example: --custom-payload 'payloads.txt'") + rootCmd.PersistentFlags().StringVar(&args.CustomBlindXSSPayloadFile, "custom-blind-xss-payload", "", "Load custom blind XSS payloads from a file. Example: --custom-blind-xss-payload 'payloads.txt'") rootCmd.PersistentFlags().StringVar(&args.CustomAlertValue, "custom-alert-value", "1", "Set a custom alert value. Example: --custom-alert-value 'document.cookie'") rootCmd.PersistentFlags().StringVar(&args.CustomAlertType, "custom-alert-type", "none", "Set a custom alert type. Example: --custom-alert-type 'str,none'") rootCmd.PersistentFlags().StringVar(&args.UserAgent, "user-agent", "", "Set a custom User-Agent header. Example: --user-agent 'Mozilla/5.0'") @@ -152,11 +153,12 @@ func initConfig() { options = model.Options{ Header: args.Header, Cookie: args.Cookie, - UniqParam: args.P, - BlindURL: args.Blind, - CustomPayloadFile: args.CustomPayload, - CustomAlertValue: args.CustomAlertValue, - CustomAlertType: args.CustomAlertType, + UniqParam: args.P, + BlindURL: args.Blind, + CustomPayloadFile: args.CustomPayload, + CustomBlindXSSPayloadFile: args.CustomBlindXSSPayloadFile, + CustomAlertValue: args.CustomAlertValue, + CustomAlertType: args.CustomAlertType, Data: args.Data, UserAgent: args.UserAgent, OutputFile: args.Output, @@ -225,6 +227,9 @@ func initConfig() { if args.CustomPayload == "" && cfgOptions.CustomPayloadFile != "" { options.CustomPayloadFile = cfgOptions.CustomPayloadFile } + if args.CustomBlindXSSPayloadFile == "" && cfgOptions.CustomBlindXSSPayloadFile != "" { + options.CustomBlindXSSPayloadFile = cfgOptions.CustomBlindXSSPayloadFile + } if args.CustomAlertValue == DefaultCustomAlertValue && cfgOptions.CustomAlertValue != "" { options.CustomAlertValue = cfgOptions.CustomAlertValue } diff --git a/pkg/model/options.go b/pkg/model/options.go index d5044401..039453cd 100644 --- a/pkg/model/options.go +++ b/pkg/model/options.go @@ -30,6 +30,7 @@ type Options struct { // Feature Options BlindURL string `json:"blind,omitempty"` CustomPayloadFile string `json:"custom-payload-file,omitempty"` + CustomBlindXSSPayloadFile string `json:"custom-blind-xss-payload-file,omitempty"` CustomAlertValue string `json:"custom-alert-value,omitempty"` CustomAlertType string `json:"custom-alert-type,omitempty"` OnlyDiscovery bool `json:"only-discovery,omitempty"` diff --git a/pkg/scanning/scan.go b/pkg/scanning/scan.go index 067edeb2..4feb1346 100644 --- a/pkg/scanning/scan.go +++ b/pkg/scanning/scan.go @@ -161,6 +161,15 @@ func Scan(target string, options model.Options, sid string) (model.Result, error } // generatePayloads generates XSS payloads based on discovery results. +// getBlindCallbackURL determines the correct format for the blind callback URL. +// It assumes blindURL is not empty. +func getBlindCallbackURL(blindURL string) string { + if strings.HasPrefix(blindURL, "https://") || strings.HasPrefix(blindURL, "http://") { + return blindURL + } + return "//" + blindURL +} + func generatePayloads(target string, options model.Options, policy map[string]string, pathReflection map[int]string, params map[string]model.ParamResult) (map[*http.Request]map[string]string, []string) { query := make(map[*http.Request]map[string]string) var durls []string @@ -398,12 +407,7 @@ func generatePayloads(target string, options model.Options, policy map[string]st // Blind Payload if options.BlindURL != "" { bpayloads := payload.GetBlindPayload() - var bcallback string - if strings.HasPrefix(options.BlindURL, "https://") || strings.HasPrefix(options.BlindURL, "http://") { - bcallback = options.BlindURL - } else { - bcallback = "//" + options.BlindURL - } + bcallback := getBlindCallbackURL(options.BlindURL) for _, bpayload := range bpayloads { bp := strings.Replace(bpayload, "CALLBACKURL", bcallback, 10) tq, tm := optimization.MakeHeaderQuery(target, "Referer", bp, options) @@ -432,6 +436,56 @@ func generatePayloads(target string, options model.Options, policy map[string]st printing.DalLog("SYSTEM", "Added blind XSS payloads with callback URL: "+options.BlindURL, options) } + // Custom Blind XSS Payloads from file + if options.CustomBlindXSSPayloadFile != "" { + fileInfo, statErr := os.Stat(options.CustomBlindXSSPayloadFile) + if os.IsNotExist(statErr) { + printing.DalLog("SYSTEM", "Failed to load custom blind XSS payload file: "+options.CustomBlindXSSPayloadFile+" (file not found)", options) + } else if statErr != nil { + printing.DalLog("SYSTEM", "Failed to load custom blind XSS payload file: "+options.CustomBlindXSSPayloadFile+" ("+statErr.Error()+")", options) + } else if fileInfo.IsDir() { + printing.DalLog("SYSTEM", "Failed to load custom blind XSS payload file: "+options.CustomBlindXSSPayloadFile+" (path is a directory)", options) + } else { + // File exists and is not a directory, proceed to read it + payloadLines, readErr := voltFile.ReadLinesOrLiteral(options.CustomBlindXSSPayloadFile) + if readErr != nil { + printing.DalLog("SYSTEM", "Failed to read custom blind XSS payload file: "+options.CustomBlindXSSPayloadFile+" ("+readErr.Error()+")", options) + } else { + var bcallback string + if options.BlindURL != "" { + bcallback = getBlindCallbackURL(options.BlindURL) + } + + addedPayloadCount := 0 + for _, customPayload := range payloadLines { + if customPayload != "" { + addedPayloadCount++ + actualPayload := customPayload + if options.BlindURL != "" { // Only replace if BlindURL is set + actualPayload = strings.Replace(customPayload, "CALLBACKURL", bcallback, -1) + } + + for k, v := range params { + if optimization.CheckInspectionParam(options, k) { + ptype := "" + for _, av := range v.Chars { + if strings.Contains(av, "PTYPE:") { + ptype = GetPType(av) + } + } + // Use only NaN encoder to avoid encoding issues with custom payloads + tq, tm := optimization.MakeRequestQuery(target, k, actualPayload, "toBlind"+ptype, "toBlind", NaN, options) + tm["payload"] = "toBlind" + query[tq] = tm + } + } + } + } + printing.DalLog("SYSTEM", "Added "+strconv.Itoa(addedPayloadCount)+" custom blind XSS payloads from file: "+options.CustomBlindXSSPayloadFile, options) + } + } + } + // Remote Payloads if options.RemotePayloads != "" { rp := strings.Split(options.RemotePayloads, ",") diff --git a/pkg/scanning/scan_test.go b/pkg/scanning/scan_test.go index 1e4133a3..19986c3c 100644 --- a/pkg/scanning/scan_test.go +++ b/pkg/scanning/scan_test.go @@ -2,13 +2,18 @@ package scanning import ( "fmt" + "io" "net/http" "net/http/httptest" + "os" "strings" + "sync" "testing" "time" "github.com/hahwul/dalfox/v2/pkg/model" + "github.com/logrusorgru/aurora" + "github.com/stretchr/testify/assert" ) // mockServer creates a test server that reflects query parameters and path in its response @@ -79,6 +84,198 @@ func Test_shouldIgnoreReturn(t *testing.T) { } } +// createTempPayloadFile creates a temporary file with the given content. +// It returns the path to the temporary file and a cleanup function. +func createTempPayloadFile(t *testing.T, content string) (string, func()) { + t.Helper() + tmpFile, err := os.CreateTemp("", "test-payloads-*.txt") + if err != nil { + t.Fatalf("Failed to create temp file: %v", err) + } + if _, err := tmpFile.WriteString(content); err != nil { + tmpFile.Close() + os.Remove(tmpFile.Name()) + t.Fatalf("Failed to write to temp file: %v", err) + } + if err := tmpFile.Close(); err != nil { + os.Remove(tmpFile.Name()) + t.Fatalf("Failed to close temp file: %v", err) + } + return tmpFile.Name(), func() { os.Remove(tmpFile.Name()) } +} + +// captureOutput captures stdout and stderr during the execution of a function. +func captureOutput(f func()) (string, string) { + oldStdout := os.Stdout + oldStderr := os.Stderr + rOut, wOut, _ := os.Pipe() + rErr, wErr, _ := os.Pipe() + os.Stdout = wOut + os.Stderr = wErr + + f() + + wOut.Close() + wErr.Close() + os.Stdout = oldStdout + os.Stderr = oldStderr + + var outBuf, errBuf strings.Builder + // Use a WaitGroup to wait for copying to finish + var wg sync.WaitGroup + wg.Add(2) + + go func() { + defer wg.Done() + io.Copy(&outBuf, rOut) + }() + go func() { + defer wg.Done() + io.Copy(&errBuf, rErr) + }() + + wg.Wait() + return outBuf.String(), errBuf.String() +} + +func TestGeneratePayloads_CustomBlindXSS(t *testing.T) { + server := mockServerForScanTest() + defer server.Close() + + baseOptions := model.Options{ + Concurrence: 1, + Format: "plain", + Silence: false, // Set to false to capture logs + NoSpinner: true, + CustomAlertType: "none", + AuroraObject: aurora.NewAurora(false), // Assuming NoColor is true for tests + Scan: make(map[string]model.Scan), + PathReflection: make(map[int]string), + Mutex: &sync.Mutex{}, + } + + params := map[string]model.ParamResult{ + "q": { + Name: "q", + Type: "URL", + Reflected: true, + Chars: []string{}, + }, + } + policy := map[string]string{"Content-Type": "text/html"} + pathReflection := make(map[int]string) + + t.Run("Valid custom blind payload file with --blind URL", func(t *testing.T) { + payloadContent := "blindy1\nblindy2" + payloadFile, cleanup := createTempPayloadFile(t, payloadContent) + defer cleanup() + + options := baseOptions + options.CustomBlindXSSPayloadFile = payloadFile + options.BlindURL = "test-callback.com" + options.UniqParam = []string{"q"} // Ensure params are processed + + var generatedQueries map[*http.Request]map[string]string + var logOutput string + + stdout, stderr := captureOutput(func() { + generatedQueries, _ = generatePayloads(server.URL+"/?q=test", options, policy, pathReflection, params) + }) + logOutput = stdout + stderr // Combine stdout and stderr + + assert.Contains(t, logOutput, "Added 2 custom blind XSS payloads from file: "+payloadFile) + + foundPayload1 := false + foundPayload2 := false + expectedPayload1 := strings.Replace("blindy1", "CALLBACKURL", "//"+options.BlindURL, -1) + expectedPayload2 := strings.Replace("blindy2", "CALLBACKURL", "//"+options.BlindURL, -1) + + for req, meta := range generatedQueries { + if meta["type"] == "toBlind" && meta["payload"] == "toBlind" { // Check our specific type for these payloads + // Check if the payload in the query matches one of our expected transformed payloads + // This requires knowing how MakeRequestQuery structures the request. + // Assuming payload is in query parameter 'q' for this test. + queryValues := req.URL.Query() + if queryValues.Get("q") == expectedPayload1 { + foundPayload1 = true + } + if queryValues.Get("q") == expectedPayload2 { + foundPayload2 = true + } + } + } + assert.True(t, foundPayload1, "Expected payload 1 not found or not correctly transformed") + assert.True(t, foundPayload2, "Expected payload 2 not found or not correctly transformed") + }) + + t.Run("Custom blind payload file with CALLBACKURL but no --blind flag", func(t *testing.T) { + payloadContent := "blindy3" + payloadFile, cleanup := createTempPayloadFile(t, payloadContent) + defer cleanup() + + options := baseOptions + options.CustomBlindXSSPayloadFile = payloadFile + options.BlindURL = "" // No blind URL + options.UniqParam = []string{"q"} + + var generatedQueries map[*http.Request]map[string]string + stdout, stderr := captureOutput(func() { + generatedQueries, _ = generatePayloads(server.URL+"/?q=test", options, policy, pathReflection, params) + }) + logOutput := stdout + stderr // Combine stdout and stderr + + assert.Contains(t, logOutput, "Added 1 custom blind XSS payloads from file: "+payloadFile) + foundPayload := false + expectedPayload := "blindy3" // CALLBACKURL should not be replaced + + for req, meta := range generatedQueries { + if meta["type"] == "toBlind" && meta["payload"] == "toBlind" { + if req.URL.Query().Get("q") == expectedPayload { + foundPayload = true + break + } + } + } + assert.True(t, foundPayload, "Expected payload with unreplaced CALLBACKURL not found") + }) + + t.Run("Invalid non-existent custom blind payload file", func(t *testing.T) { + options := baseOptions + options.CustomBlindXSSPayloadFile = "nonexistentfile.txt" + options.UniqParam = []string{"q"} + + stdout, stderr := captureOutput(func() { + _, _ = generatePayloads(server.URL+"/?q=test", options, policy, pathReflection, params) + }) + logOutput := stdout + stderr // Combine stdout and stderr + + assert.Contains(t, logOutput, "Failed to load custom blind XSS payload file: nonexistentfile.txt") + // Check that no payloads of type "toBlind" were added due to this specific file error + // (assuming other payload generation might still occur) + customBlindPayloadsFound := false + assert.False(t, customBlindPayloadsFound, "Queries should not include payloads from a non-existent file if logic prevents it after error") + }) + + t.Run("Empty custom blind payload file", func(t *testing.T) { + payloadFile, cleanup := createTempPayloadFile(t, "") + defer cleanup() + + options := baseOptions + options.CustomBlindXSSPayloadFile = payloadFile + options.UniqParam = []string{"q"} + + stdout, stderr := captureOutput(func() { + _, _ = generatePayloads(server.URL+"/?q=test", options, policy, pathReflection, params) + }) + logOutput := stdout + stderr // Combine stdout and stderr + + assert.Contains(t, logOutput, "Added 0 custom blind XSS payloads from file: "+payloadFile) + // Verify no queries were generated specifically from this empty file. + // Similar to the above, this assumes no other "toBlind" payloads would be generated, + // or relies on the specific log message for confirmation. + }) +} + func Test_generatePayloads(t *testing.T) { // Create a mock server server := mockServerForScanTest()