Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .changeset/stupid-cows-grow.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"@langchain/community": patch
---

allow any chars in delimited identifiers in hanavector
138 changes: 74 additions & 64 deletions libs/langchain-community/src/vectorstores/hanavector.ts
Original file line number Diff line number Diff line change
Expand Up @@ -133,23 +133,15 @@ export class HanaDB extends VectorStore {
/**
 * Builds a HanaDB vector store from the given embeddings and options.
 * Each optional setting in `args` falls back to its `default*` constant
 * when the supplied value is falsy.
 * @param embeddings Embeddings implementation, forwarded to the base class.
 * @param args Store options: connection, table/column names, vector column
 * length, distance strategy and specific metadata columns.
 */
constructor(embeddings: EmbeddingsInterface, args: HanaDBArgs) {
super(embeddings, args);
this.distanceStrategy = args.distanceStrategy || defaultDistanceStrategy;
// NOTE(review): the four sanitizeName() assignments below are each
// overwritten by the raw assignments that follow them, so only the raw
// values survive. This looks like before/after lines of a diff shown
// together; confirm which variant is intended.
this.tableName = HanaDB.sanitizeName(args.tableName || defaultTableName);
this.contentColumn = HanaDB.sanitizeName(
args.contentColumn || defaultContentColumn
);
this.metadataColumn = HanaDB.sanitizeName(
args.metadataColumn || defaultMetadataColumn
);
this.vectorColumn = HanaDB.sanitizeName(
args.vectorColumn || defaultVectorColumn
);
// Names are stored exactly as supplied (no character stripping).
this.tableName = args.tableName || defaultTableName;
this.contentColumn = args.contentColumn || defaultContentColumn;
this.metadataColumn = args.metadataColumn || defaultMetadataColumn;
this.vectorColumn = args.vectorColumn || defaultVectorColumn;
// The second argument (-1) is presumably the lower bound accepted by
// sanitizeInt; verify against its definition.
this.vectorColumnLength = HanaDB.sanitizeInt(
args.vectorColumnLength || defaultVectorColumnLength,
-1
);
// NOTE(review): same pattern as above; the sanitized list is replaced by
// the raw list on the next statement.
this.specificMetadataColumns = HanaDB.sanitizeSpecificMetadataColumns(
args.specificMetadataColumns || []
);
this.specificMetadataColumns = args.specificMetadataColumns || [];
this.connection = args.connection;
}

Expand Down Expand Up @@ -234,6 +226,10 @@ export class HanaDB extends VectorStore {
return inputStr.replace(/[^a-zA-Z0-9_]/g, "");
}

/**
 * Quotes a name for use as a delimited SQL identifier.
 * Every double quote inside the name is doubled and the result is wrapped
 * in a pair of double quotes.
 * @param inputStr The raw identifier (e.g. a table or column name).
 * @returns The identifier wrapped in double quotes with inner quotes doubled.
 */
public static escapeSqlIdentifier(inputStr: string): string {
  // Double each embedded quote so it cannot terminate the identifier early.
  const doubledQuotes = inputStr.replace(/"/g, '""');
  return '"' + doubledQuotes + '"';
}

/**
* Sanitizes the input to integer. Throws an error if the value is less than lower bound.
* @param inputInt The input to be sanitized.
Expand Down Expand Up @@ -291,10 +287,6 @@ export class HanaDB extends VectorStore {
return metadata;
}

/**
 * Runs every column name through `sanitizeName`.
 * @param columns The column names to sanitize.
 * @returns A new array holding the sanitized names in the same order.
 */
static sanitizeSpecificMetadataColumns(columns: string[]): string[] {
  const sanitizeColumn = (name: string): string => this.sanitizeName(name);
  return columns.map(sanitizeColumn);
}

/**
* Parses a string representation of a float array and returns an array of numbers.
* @param {string} arrayAsString - The string representation of the array.
Expand All @@ -318,17 +310,20 @@ export class HanaDB extends VectorStore {
columnType: string | string[],
columnLength?: number
): Promise<void> {
const sqlStr = `
SELECT DATA_TYPE_NAME, LENGTH
FROM SYS.TABLE_COLUMNS
WHERE SCHEMA_NAME = CURRENT_SCHEMA
AND TABLE_NAME = ?
AND COLUMN_NAME = ?`;
const query = `
SELECT DATA_TYPE_NAME, LENGTH
FROM SYS.TABLE_COLUMNS
WHERE SCHEMA_NAME = CURRENT_SCHEMA
AND TABLE_NAME = ?
AND COLUMN_NAME = ?`;
const client = this.connection; // Get the connection object
// Prepare the statement with parameter placeholders
const stm = await this.prepareQuery(client, sqlStr);
const statement = await this.prepareQuery(client, query);
// Execute the query with actual parameters to avoid SQL injection
const resultSet = await this.executeStatement(stm, [tableName, columnName]);
const resultSet = await this.executeStatement(statement, [
tableName,
columnName,
]);
if (resultSet.length === 0) {
throw new Error(`Column ${columnName} does not exist`);
} else {
Expand Down Expand Up @@ -356,29 +351,28 @@ export class HanaDB extends VectorStore {
/**
 * Creates the table named `this.tableName` when `tableExists` reports that
 * it is missing. The table gets the content and metadata columns as NCLOB
 * and the vector column as REAL_VECTOR, with an explicit length only when
 * one is configured.
 */
private async createTableIfNotExists() {
const tableExists = await this.tableExists(this.tableName);
if (!tableExists) {
// NOTE(review): two ways of building the CREATE TABLE statement appear
// back to back here (`sqlStr` via concatenation, then `query` via a
// template literal) and both are passed to executeQuery below. This looks
// like before/after lines of a diff shown together; confirm which variant
// is intended.
let sqlStr =
`CREATE TABLE "${this.tableName}" (` +
`"${this.contentColumn}" NCLOB, ` +
`"${this.metadataColumn}" NCLOB, ` +
`"${this.vectorColumn}" REAL_VECTOR`;
// Length can either be -1 (QRC01+02-24) or 0 (QRC03-24 onwards)
if (this.vectorColumnLength === -1 || this.vectorColumnLength === 0) {
sqlStr += ");";
} else {
sqlStr += `(${this.vectorColumnLength}));`;
// Any non-positive length is treated as "no explicit length".
const vectorColumnLength =
this.vectorColumnLength <= 0 ? null : this.vectorColumnLength;
// Table and column names go through escapeSqlIdentifier before being
// placed in the statement.
const query = `
CREATE TABLE ${HanaDB.escapeSqlIdentifier(this.tableName)} (
${HanaDB.escapeSqlIdentifier(this.contentColumn)} NCLOB,
${HanaDB.escapeSqlIdentifier(this.metadataColumn)} NCLOB,
${HanaDB.escapeSqlIdentifier(this.vectorColumn)} REAL_VECTOR${
vectorColumnLength ? `(${vectorColumnLength})` : ""
}

)`;
const client = this.connection;
await this.executeQuery(client, sqlStr);
await this.executeQuery(client, query);
}
}

public async tableExists(tableName: string): Promise<boolean> {
const tableExistsSQL = `SELECT COUNT(*) AS COUNT FROM SYS.TABLES WHERE SCHEMA_NAME = CURRENT_SCHEMA AND TABLE_NAME = ?`;
const tableExistsQuery = `SELECT COUNT(*) AS COUNT FROM SYS.TABLES WHERE SCHEMA_NAME = CURRENT_SCHEMA AND TABLE_NAME = ?`;
const client = this.connection; // Get the connection object

const stm = await this.prepareQuery(client, tableExistsSQL);
const resultSet = await this.executeStatement(stm, [tableName]);
const statement = await this.prepareQuery(client, tableExistsQuery);
const resultSet = await this.executeStatement(statement, [tableName]);
if (resultSet[0].COUNT === 1) {
// Table does exist
return true;
Expand Down Expand Up @@ -529,8 +523,10 @@ export class HanaDB extends VectorStore {

// Metadata column handling
const selector = this.specificMetadataColumns.includes(key)
? `"${key}"`
: `JSON_VALUE(${this.metadataColumn}, '$.${key}')`;
? HanaDB.escapeSqlIdentifier(key)
: `JSON_VALUE(${HanaDB.escapeSqlIdentifier(
this.metadataColumn
)}, '$.${key}')`;
whereStr += `${selector} ${operator} ${sqlParam}`;
});
return [whereStr, queryTuple];
Expand Down Expand Up @@ -567,7 +563,7 @@ export class HanaDB extends VectorStore {
const defaultIndexName = `${this.tableName}_${distanceFuncName}_idx`;

// Use provided indexName or fallback to default
const finalIndexName = HanaDB.sanitizeName(indexName || defaultIndexName);
const finalIndexName = indexName || defaultIndexName;
// Initialize buildConfig and searchConfig objects
const buildConfig: Record<string, number> = {};
const searchConfig: Record<string, number> = {};
Expand Down Expand Up @@ -623,24 +619,28 @@ export class HanaDB extends VectorStore {
: "";

// Create the base SQL string for index creation
let sqlStr = `CREATE HNSW VECTOR INDEX ${finalIndexName} ON "${this.tableName}" ("${this.vectorColumn}")
SIMILARITY FUNCTION ${distanceFuncName} `;
let query = `
CREATE HNSW VECTOR INDEX ${HanaDB.escapeSqlIdentifier(finalIndexName)}
ON ${HanaDB.escapeSqlIdentifier(
this.tableName
)} (${HanaDB.escapeSqlIdentifier(this.vectorColumn)})
SIMILARITY FUNCTION ${distanceFuncName}`;

// Append buildConfig to the SQL string if provided
if (buildConfigStr) {
sqlStr += `BUILD CONFIGURATION '${buildConfigStr}' `;
query += ` BUILD CONFIGURATION '${buildConfigStr}'`;
}

// Append searchConfig to the SQL string if provided
if (searchConfigStr) {
sqlStr += `SEARCH CONFIGURATION '${searchConfigStr}' `;
query += ` SEARCH CONFIGURATION '${searchConfigStr}'`;
}

// Add the ONLINE option
sqlStr += "ONLINE;";
query += " ONLINE;";

const client = this.connection;
await this.executeQuery(client, sqlStr);
await this.executeQuery(client, query);
}

/**
Expand All @@ -664,10 +664,12 @@ export class HanaDB extends VectorStore {
}

const [whereStr, queryTuple] = this.createWhereByFilter(filter);
const sqlStr = `DELETE FROM "${this.tableName}" ${whereStr}`;
const query = `DELETE FROM ${HanaDB.escapeSqlIdentifier(
this.tableName
)} ${whereStr}`;
const client = this.connection;
const stm = await this.prepareQuery(client, sqlStr);
await this.executeStatement(stm, queryTuple);
const statement = await this.prepareQuery(client, query);
await this.executeStatement(statement, queryTuple);
}

/**
Expand Down Expand Up @@ -757,10 +759,14 @@ export class HanaDB extends VectorStore {
];
});
// Insert data into the table, bulk insert.
const sqlStr = `INSERT INTO "${this.tableName}" ("${this.contentColumn}", "${this.metadataColumn}", "${this.vectorColumn}")
VALUES (?, ?, TO_REAL_VECTOR(?));`;
const stm = await this.prepareQuery(client, sqlStr);
await this.executeStatement(stm, sqlParams);
const query = `
INSERT INTO ${HanaDB.escapeSqlIdentifier(this.tableName)} (
${HanaDB.escapeSqlIdentifier(this.contentColumn)},
${HanaDB.escapeSqlIdentifier(this.metadataColumn)},
${HanaDB.escapeSqlIdentifier(this.vectorColumn)}
) VALUES (?, ?, TO_REAL_VECTOR(?));`;
const statement = await this.prepareQuery(client, query);
await this.executeStatement(statement, sqlParams);
// stm.execBatch(sqlParams);
}

Expand Down Expand Up @@ -839,12 +845,16 @@ export class HanaDB extends VectorStore {
const distanceFuncName = HANA_DISTANCE_FUNCTION[this.distanceStrategy][0];
// Convert the embedding vector to a string for SQL query
const embeddingAsString = sanitizedEmbedding.join(",");
let sqlStr = `SELECT TOP ${sanitizedK}
"${this.contentColumn}",
"${this.metadataColumn}",
TO_NVARCHAR("${this.vectorColumn}") AS VECTOR,
${distanceFuncName}("${this.vectorColumn}", TO_REAL_VECTOR('[${embeddingAsString}]')) AS CS
FROM "${this.tableName}"`;
let query = `
SELECT TOP ${sanitizedK}
${HanaDB.escapeSqlIdentifier(this.contentColumn)},
${HanaDB.escapeSqlIdentifier(this.metadataColumn)},
TO_NVARCHAR(${HanaDB.escapeSqlIdentifier(this.vectorColumn)}) AS VECTOR,
${distanceFuncName}(
${HanaDB.escapeSqlIdentifier(this.vectorColumn)},
TO_REAL_VECTOR('[${embeddingAsString}]')
) AS CS
FROM ${HanaDB.escapeSqlIdentifier(this.tableName)}`;
// Add order by clause to sort by similarity
const orderStr = ` ORDER BY CS ${
HANA_DISTANCE_FUNCTION[this.distanceStrategy][1]
Expand All @@ -853,10 +863,10 @@ export class HanaDB extends VectorStore {
// Prepare and execute the SQL query
const [whereStr, queryTuple] = this.createWhereByFilter(filter);

sqlStr += whereStr + orderStr;
query += whereStr + orderStr;
const client = this.connection;
const stm = await this.prepareQuery(client, sqlStr);
const resultSet = await this.executeStatement(stm, queryTuple);
const statement = await this.prepareQuery(client, query);
const resultSet = await this.executeStatement(statement, queryTuple);
const result: Array<[Document, number, number[]]> = resultSet.map(
// eslint-disable-next-line @typescript-eslint/no-explicit-any
(row: any) => {
Expand Down
24 changes: 24 additions & 0 deletions libs/langchain-community/src/vectorstores/tests/hanavector.test.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import { expect } from "@jest/globals";
import { FakeEmbeddings } from "@langchain/core/utils/testing";
import { HanaDB } from "../hanavector.js";

describe("Sanity check tests", () => {
Expand Down Expand Up @@ -37,3 +38,26 @@ describe("Sanity check tests", () => {
]);
});
});

describe("HanaDB tests", () => {
  it("should create a new HanaDB instance with unsanitized params", () => {
    // Only a placeholder connection is supplied, so every other field is
    // expected to hold its default value.
    const store = new HanaDB(new FakeEmbeddings(), { connection: {} });

    expect(store).toBeDefined();
    // toMatchObject reads the private fields at runtime, so no
    // per-property type-check suppression is required.
    expect(store).toMatchObject({
      distanceStrategy: "cosine",
      tableName: "EMBEDDINGS",
      contentColumn: "VEC_TEXT",
      metadataColumn: "VEC_META",
      vectorColumn: "VEC_VECTOR",
      vectorColumnLength: -1,
      specificMetadataColumns: [],
    });
  });
});
Loading