-
Notifications
You must be signed in to change notification settings - Fork 808
Added watsonx.ai generator #1058
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
9 commits
Select commit
Hold shift + click to select a range
7c1120d
WatsonX Generator
iamnotcj 17703fc
Updated Format
iamnotcj 8149e90
Minor update
iamnotcj c0b3631
Update
iamnotcj 6c008b8
Update .gitignore
iamnotcj ecaa821
2nd Update
iamnotcj fefcd5e
Update garak/generators/watsonx.py
iamnotcj 4b252fe
Update garak/generators/watsonx.py
iamnotcj 826f2dc
Updated version variable
iamnotcj File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -168,3 +168,4 @@ hitlog.*.jsonl | |
| garak_runs/ | ||
| runs/ | ||
| logs/ | ||
| .DS_Store | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,7 @@ | ||
| garak.generators.watsonx | ||
| ======================== | ||
|
|
||
| .. automodule:: garak.generators.watsonx | ||
| :members: | ||
| :undoc-members: | ||
| :show-inheritance: |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,150 @@ | ||
| from garak import _config | ||
| from garak.generators.base import Generator | ||
| from typing import List, Union | ||
| import os | ||
| import requests | ||
|
|
||
|
|
||
class WatsonXGenerator(Generator):
    """
    This is a generator for watsonx.ai.

    Make sure that you initialize the environment variables:
        'WATSONX_TOKEN',
        'WATSONX_URL',
        'WATSONX_PROJECTID' OR 'WATSONX_DEPLOYID'.

    To use a model that is in the "project" stage initialize the WATSONX_PROJECTID variable with the Project ID of the model.
    To use a tuned model that is deployed, simply initialize the WATSONX_DEPLOYID variable with the Deployment ID of the model.

    If both a Deployment ID and a Project ID are available, the deployment
    endpoint is used.
    """

    # Names of the environment variables consulted by this generator.
    ENV_VAR = "WATSONX_TOKEN"
    URI_ENV_VAR = "WATSONX_URL"
    PID_ENV_VAR = "WATSONX_PROJECTID"
    DID_ENV_VAR = "WATSONX_DEPLOYID"
    DEFAULT_PARAMS = Generator.DEFAULT_PARAMS | {
        "uri": None,
        "version": "2023-05-29",
        "project_id": "",
        "deployment_id": "",
        "prompt_variable": "input",
        "bearer_token": "",
        "max_tokens": 900,
    }

    generator_family_name = "watsonx"

    def __init__(self, name="", config_root=_config):
        """Set up the generator and export the API key to the environment.

        Args:
            name: Model identifier; sent as ``model_id`` for project-based
                generation.
            config_root: Configuration root forwarded to the base class.
        """
        super().__init__(name, config_root=config_root)
        # Initialize and validate api_key
        # NOTE(review): api_key is presumably populated by the base class
        # from ENV_VAR or the config - confirm against Generator.
        if self.api_key is not None:
            os.environ[self.ENV_VAR] = self.api_key

    def _set_bearer_token(self, iam_url="https://iam.cloud.ibm.com/identity/token"):
        """Exchange the API key for an IAM access token.

        Stores ``"Bearer <access_token>"`` in ``self.bearer_token``.

        Args:
            iam_url: Token endpoint to POST the API key to.

        Raises:
            requests.HTTPError: If the endpoint answers with an error status.
            KeyError: If the response JSON has no ``access_token`` field.
        """
        headers = {
            "Content-Type": "application/x-www-form-urlencoded",
            "Accept": "application/json",
        }
        # Pass the form fields as a dict so that requests percent-encodes
        # them; concatenating the raw key would corrupt keys containing
        # characters such as '&', '=', '+' or '%'.
        form = {
            "grant_type": "urn:ibm:params:oauth:grant-type:apikey",
            "apikey": self.api_key,
        }
        response = requests.post(url=iam_url, headers=headers, data=form)
        # Surface authentication failures directly instead of as an opaque
        # KeyError on the missing 'access_token' field.
        response.raise_for_status()
        self.bearer_token = "Bearer " + response.json()["access_token"]

    def _post_generation(self, url, body):
        """POST a JSON generation request and return the decoded JSON reply."""
        headers = {
            "Accept": "application/json",
            "Content-Type": "application/json",
            "Authorization": self.bearer_token,
        }
        response = requests.post(url=url, headers=headers, json=body)
        return response.json()

    def _generate_with_project(self, payload):
        """Generate text via the Project ID endpoint.

        Args:
            payload: Prompt text, sent as the request ``input``.

        Returns:
            The decoded JSON response from the service.
        """
        url = self.uri + f"/ml/v1/text/generation?version={self.version}"
        body = {
            "input": payload,
            "parameters": {
                "decoding_method": "greedy",
                "max_new_tokens": self.max_tokens,
                "min_new_tokens": 0,
                "repetition_penalty": 1,
            },
            "model_id": self.name,
            "project_id": self.project_id,
        }
        return self._post_generation(url, body)

    def _generate_with_deployment(self, payload):
        """Generate text via the Deployment ID endpoint.

        Args:
            payload: Prompt text, sent as the value of the prompt variable
                named by ``self.prompt_variable``.

        Returns:
            The decoded JSON response from the service.
        """
        url = (
            self.uri
            + "/ml/v1/deployments/"
            + self.deployment_id
            + f"/text/generation?version={self.version}"
        )
        body = {"parameters": {"prompt_variables": {self.prompt_variable: payload}}}
        return self._post_generation(url, body)

    def _validate_env_var(self):
        """Fill in uri, project_id and deployment_id from the environment.

        Values that are already set on the instance take precedence over the
        environment variables.

        Raises:
            ValueError: If no URL is available, or if neither a Project ID
                nor a Deployment ID is available.
        """
        # Initialize and validate url.
        if self.uri is None:
            self.uri = os.getenv(self.URI_ENV_VAR, None)
            if self.uri is None:
                raise ValueError(
                    f"The {self.URI_ENV_VAR} environment variable is required. Please enter the URL corresponding to the region of your provisioned service instance. \n"
                )

        # Initialize project_id and deployment_id.
        if not self.project_id:
            self.project_id = os.getenv(self.PID_ENV_VAR, "")
        if not self.deployment_id:
            self.deployment_id = os.getenv(self.DID_ENV_VAR, "")

        # Check to ensure at least ONE of project_id or deployment_id is populated.
        if not self.project_id and not self.deployment_id:
            raise ValueError(
                f"Either {self.PID_ENV_VAR} or {self.DID_ENV_VAR} is required. Please supply either a Project ID or Deployment ID. \n"
            )
        return super()._validate_env_var()

    def _call_model(
        self, prompt: str, generations_this_call: int = 1
    ) -> List[Union[str, None]]:
        """Send one prompt to watsonx.ai and return the generated text.

        Args:
            prompt: Prompt text. An empty prompt is replaced by a single
                null byte before being sent.
            generations_this_call: Not used here; one generation is
                requested per call.

        Returns:
            A single-element list holding the generated text.
        """
        # Fetch a token lazily on first use.
        if not self.bearer_token:
            self._set_bearer_token()

        # Check if message is empty. If it is, substitute a null byte.
        if not prompt:
            prompt = "\x00"
            print(
                "WARNING: Empty prompt was found. Null byte character appended to prevent API failure."
            )

        # A deployment ID takes precedence over a project ID.
        if self.deployment_id:
            output = self._generate_with_deployment(prompt)
        else:
            output = self._generate_with_project(prompt)

        # Parse the output to only contain the output message from the model. Return a list containing that message.
        return ["".join(output["results"][0]["generated_text"])]


DEFAULT_CLASS = "WatsonXGenerator"
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,81 @@ | ||
| from garak.generators.watsonx import WatsonXGenerator | ||
| import os | ||
| import pytest | ||
| import requests_mock | ||
|
|
||
|
|
||
# Model name passed to WatsonXGenerator in the tests below.
DEFAULT_DEPLOYMENT_NAME = "ibm/granite-3-8b-instruct"


@pytest.fixture
def set_fake_env(request) -> None:
    """Point the watsonx environment variables at fake values for one test.

    The previous values are captured first and put back by a finalizer once
    the test finishes.
    """
    env_names = (
        WatsonXGenerator.ENV_VAR,
        WatsonXGenerator.PID_ENV_VAR,
        WatsonXGenerator.URI_ENV_VAR,
        WatsonXGenerator.DID_ENV_VAR,
    )
    stored_env = {env_name: os.getenv(env_name, None) for env_name in env_names}

    def restore_env():
        for k, v in stored_env.items():
            if v is not None:
                os.environ[k] = v
            else:
                # pop() rather than del: the variable may already have been
                # removed by the code under test, and teardown must not fail.
                os.environ.pop(k, None)

    os.environ[WatsonXGenerator.ENV_VAR] = "XXXXXXXXXXXXX"
    os.environ[WatsonXGenerator.PID_ENV_VAR] = "XXXXXXXX-XXXX-XXXX-XXXX-XXXXXXXXXXXX"
    os.environ[WatsonXGenerator.DID_ENV_VAR] = "XXXXXXXX-XXXX-XXXX-XXXX-XXXXXXXXXXXX"
    os.environ[WatsonXGenerator.URI_ENV_VAR] = "https://garak.example.com/"
    request.addfinalizer(restore_env)
|
|
||
|
|
||
@pytest.mark.usefixtures("set_fake_env")
def test_bearer_token(watsonx_compat_mocks):
    """_set_bearer_token stores 'Bearer <access_token>' from the IAM reply."""
    with requests_mock.Mocker() as m:
        mock_response = watsonx_compat_mocks["watsonx_bearer_token"]

        # One URL for both the mock and the call, so they cannot drift apart.
        iam_url = "https://garak.example.com/" + "identity/token"
        m.post(iam_url, json=mock_response["json"])

        granite_llm = WatsonXGenerator(DEFAULT_DEPLOYMENT_NAME)
        # The method returns nothing; its effect is the bearer_token attribute.
        granite_llm._set_bearer_token(iam_url=iam_url)

        assert granite_llm.bearer_token == (
            "Bearer " + mock_response["json"]["access_token"]
        )
|
|
||
|
|
||
@pytest.mark.usefixtures("set_fake_env")
def test_project(watsonx_compat_mocks):
    """Project-based generation returns the mocked payload for the model."""
    canned = watsonx_compat_mocks["watsonx_generation"]
    # The fake base URL ends in '/' and the generator appends a path that
    # starts with '/', so the mocked URL carries both.
    endpoint = "https://garak.example.com/" + "/ml/v1/text/generation?version=2023-05-29"

    with requests_mock.Mocker() as mocker:
        mocker.post(endpoint, json=canned["json"])

        generator = WatsonXGenerator(DEFAULT_DEPLOYMENT_NAME)
        reply = generator._generate_with_project("What is this?")

        assert generator.name == reply["model_id"]
|
|
||
|
|
||
@pytest.mark.usefixtures("set_fake_env")
def test_deployment(watsonx_compat_mocks):
    """Deployment-based generation returns the mocked payload for the model."""
    canned = watsonx_compat_mocks["watsonx_generation"]
    # Same fake deployment ID that the set_fake_env fixture exports.
    deployment_id = "XXXXXXXX-XXXX-XXXX-XXXX-XXXXXXXXXXXX"
    path = "".join(
        [
            "/ml/v1/deployments/",
            deployment_id,
            "/text/generation?version=2023-05-29",
        ]
    )

    with requests_mock.Mocker() as mocker:
        mocker.post("https://garak.example.com/" + path, json=canned["json"])

        generator = WatsonXGenerator(DEFAULT_DEPLOYMENT_NAME)
        reply = generator._generate_with_deployment("What is this?")

        assert generator.name == reply["model_id"]
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,29 @@ | ||
| { | ||
| "watsonx_bearer_token": { | ||
| "code": 200, | ||
| "json": { | ||
| "access_token": "fake_token1231231231", | ||
| "refresh_token": "not_supported", | ||
| "token_type": "Bearer", | ||
| "expires_in": 3600, | ||
| "expiration": 1737754747, | ||
| "scope": "ibm openid" | ||
| } | ||
| }, | ||
| "watsonx_generation": { | ||
| "code": 200, | ||
| "json" : { | ||
| "model_id": "ibm/granite-3-8b-instruct", | ||
| "model_version": "1.1.0", | ||
| "created_at": "2025-01-24T20:51:59.520Z", | ||
| "results": [ | ||
| { | ||
| "generated_text": "This is a test generation. :)", | ||
| "generated_token_count": 32, | ||
| "input_token_count": 6, | ||
| "stop_reason": "eos_token" | ||
| } | ||
| ] | ||
| } | ||
| } | ||
| } |
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.