diff --git a/api/alembic/env.py b/api/alembic/env.py
index 60871bf8c..240141af5 100644
--- a/api/alembic/env.py
+++ b/api/alembic/env.py
@@ -35,7 +35,17 @@ def include_object(object, name, type_, reflected, compare_to):
 
 # Remove the sqlite+aiosqlite:// prefix and use sqlite:// for Alembic
 # Alembic needs a synchronous connection URL (uses sqlite3, not aiosqlite)
-sync_url = DATABASE_URL.replace("sqlite+aiosqlite:///", "sqlite:///")
+# For PostgreSQL, replace postgresql+asyncpg:// with postgresql://
+if DATABASE_URL.startswith("sqlite+aiosqlite:///"):
+    sync_url = DATABASE_URL.replace("sqlite+aiosqlite:///", "sqlite:///")
+elif DATABASE_URL.startswith("postgresql+asyncpg://"):
+    sync_url = DATABASE_URL.replace("postgresql+asyncpg://", "postgresql://")
+else:
+    # Fallback: use as-is if format is unexpected
+    sync_url = DATABASE_URL
+
-config.set_main_option("sqlalchemy.url", sync_url)
+# set_main_option() hands the value to ConfigParser, which treats "%" as the start of an
+# interpolation; double it so percent-encoded credentials in the URL are stored literally
+config.set_main_option("sqlalchemy.url", sync_url.replace("%", "%%"))
 
 
diff --git a/api/alembic/versions/a1b2c3d4e5f6_convert_uuid_columns_for_postgres.py b/api/alembic/versions/a1b2c3d4e5f6_convert_uuid_columns_for_postgres.py
new file mode 100644
index 000000000..648bdf81d
--- /dev/null
+++ b/api/alembic/versions/a1b2c3d4e5f6_convert_uuid_columns_for_postgres.py
@@ -0,0 +1,62 @@
+"""Convert user and oauth_account UUID columns to native UUID type for PostgreSQL
+
+This migration fixes the CHAR(36) vs native UUID type mismatch that causes
+PostgreSQL to reject queries with the error:
+    "operator does not exist: character = uuid"
+
+FastAPI-users expects native UUID columns in PostgreSQL. The initial migration
+created CHAR(36) columns (an SQLite compatibility workaround). This migration
+converts those columns to the native UUID type on PostgreSQL. It is a no-op
+for SQLite, which has no native UUID type.
+
+Revision ID: a1b2c3d4e5f6
+Revises: 4937b0e0647c
+Create Date: 2026-02-28 00:00:00.000000
+
+"""
+
+from typing import Sequence, Union
+
+from alembic import op
+import sqlalchemy as sa
+
+# revision identifiers, used by Alembic.
+revision: str = "a1b2c3d4e5f6"
+down_revision: Union[str, Sequence[str], None] = "4937b0e0647c"
+branch_labels: Union[str, Sequence[str], None] = None
+depends_on: Union[str, Sequence[str], None] = None
+
+
+def upgrade() -> None:
+    """Convert CHAR(36) UUID columns to native UUID type on PostgreSQL."""
+    connection = op.get_bind()
+
+    if connection.dialect.name != "postgresql":
+        # SQLite has no native UUID type — nothing to do
+        return
+
+    # NOTE(review): if oauth_account.user_id carries a foreign key to "user".id, PostgreSQL
+    # may refuse to retype either side on its own (mismatched key types) — confirm whether the
+    # constraint has to be dropped before and recreated after these statements.
+
+    # Convert user.id from CHAR(36) to native UUID
+    op.execute(sa.text('ALTER TABLE "user" ALTER COLUMN id TYPE uuid USING id::uuid'))
+
+    # Convert oauth_account.id and oauth_account.user_id from CHAR(36) to native UUID
+    op.execute(sa.text("ALTER TABLE oauth_account ALTER COLUMN id TYPE uuid USING id::uuid"))
+    op.execute(sa.text("ALTER TABLE oauth_account ALTER COLUMN user_id TYPE uuid USING user_id::uuid"))
+
+
+def downgrade() -> None:
+    """Revert native UUID columns to CHAR(36) on PostgreSQL.
+
+    CHAR(36) is the type the initial migration created these columns with.
+    """
+    connection = op.get_bind()
+
+    if connection.dialect.name != "postgresql":
+        return
+
+    op.execute(sa.text('ALTER TABLE "user" ALTER COLUMN id TYPE char(36) USING id::text'))
+    op.execute(sa.text("ALTER TABLE oauth_account ALTER COLUMN id TYPE char(36) USING id::text"))
+    op.execute(sa.text("ALTER TABLE oauth_account ALTER COLUMN user_id TYPE char(36) USING user_id::text"))
diff --git a/api/alembic/versions/c175b784119c_create_oauth_account_table.py b/api/alembic/versions/c175b784119c_create_oauth_account_table.py
index 4e8ff059c..1d7ef571f 100644
--- a/api/alembic/versions/c175b784119c_create_oauth_account_table.py
+++ b/api/alembic/versions/c175b784119c_create_oauth_account_table.py
@@ -11,6 +11,8 @@ from alembic import op
 import sqlalchemy as sa
 
+from transformerlab.db.migration_utils import table_exists
+
 
 # revision identifiers, used by Alembic.
revision: str = "c175b784119c"
@@ -23,14 +25,7 @@ def upgrade() -> None:
     """Create oauth_account table."""
     connection = op.get_bind()
 
-    # Helper function to check if table exists
-    def table_exists(table_name: str) -> bool:
-        result = connection.execute(
-            sa.text("SELECT name FROM sqlite_master WHERE type='table' AND name=:name"), {"name": table_name}
-        )
-        return result.fetchone() is not None
-
-    if not table_exists("oauth_account"):
+    if not table_exists(connection, "oauth_account"):
         op.create_table(
             "oauth_account",
             sa.Column("id", sa.CHAR(length=36), nullable=False),
diff --git a/api/alembic/versions/c78d76a6d65c_add_team_id_to_config_table.py b/api/alembic/versions/c78d76a6d65c_add_team_id_to_config_table.py
index ae3462010..1a1b190a0 100644
--- a/api/alembic/versions/c78d76a6d65c_add_team_id_to_config_table.py
+++ b/api/alembic/versions/c78d76a6d65c_add_team_id_to_config_table.py
@@ -23,86 +23,89 @@ def upgrade() -> None:
-    """Upgrade schema."""
+    """Scope config entries to a user and team.
+
+    Adds nullable user_id/team_id columns, replaces the unique index on
+    key with a non-unique one, enforces uniqueness of (user_id, team_id,
+    key) and moves existing rows without a team to the admin user's team
+    (or to any team; they are deleted when no team exists). Each schema
+    step checks the reflected state first, so objects that already exist
+    are left alone.
+    """
     connection = op.get_bind()
+    is_sqlite = connection.dialect.name == "sqlite"
 
-    # Check existing columns
-    column_result = connection.execute(sa.text("PRAGMA table_info(config)"))
-    existing_columns = [row[1] for row in column_result.fetchall()]
-
-    # Get existing indexes by querying SQLite directly
-    # SQLite stores unique constraints as unique indexes
-    index_result = connection.execute(
-        sa.text("SELECT name FROM sqlite_master WHERE type='index' AND tbl_name='config'")
-    )
-    existing_index_names = [row[0] for row in index_result.fetchall()]
+    # Reflect the table once instead of wrapping each step in try/except: on PostgreSQL a
+    # failed statement aborts the whole migration transaction, so errors must not be swallowed
+    inspector = sa.inspect(connection)
+    existing_columns = {column["name"] for column in inspector.get_columns("config")}
+    existing_indexes = {index["name"]: index for index in inspector.get_indexes("config")}
+    existing_uniques = {constraint["name"] for constraint in inspector.get_unique_constraints("config")}
 
-    # Add columns first (outside batch mode to avoid circular dependency)
-    # Only add if they don't already exist
+    # Add columns (outside batch mode to avoid circular dependency) unless already present
     if "user_id" not in existing_columns:
         op.add_column("config", sa.Column("user_id", sa.String(), nullable=True))
     if "team_id" not in existing_columns:
         op.add_column("config", sa.Column("team_id", sa.String(), nullable=True))
 
-    # Handle indexes outside of batch mode to avoid type inference issues
-    # Drop existing unique index on key if it exists (to recreate as non-unique)
-    if "ix_config_key" in existing_index_names:
-        # Check if it's unique by querying the index definition
-        index_info = connection.execute(
-            sa.text("SELECT sql FROM sqlite_master WHERE type='index' AND name='ix_config_key'")
-        ).fetchone()
-        if index_info and index_info[0] and "UNIQUE" in index_info[0].upper():
-            # Drop the unique index using raw SQL to avoid batch mode issues
-            connection.execute(sa.text("DROP INDEX IF EXISTS ix_config_key"))
-            existing_index_names.remove("ix_config_key")  # Update our list
+    # Drop old unique index on key if it exists, so it can be recreated as non-unique
+    key_index = existing_indexes.get("ix_config_key")
+    if key_index is not None and key_index["unique"]:
+        op.drop_index("ix_config_key", table_name="config")
+        key_index = None
 
-    # Create new indexes (non-unique) - these can be done outside batch mode
-    if "ix_config_key" not in existing_index_names:
+    # Create the non-unique indexes that are still missing
+    if key_index is None:
         op.create_index("ix_config_key", "config", ["key"], unique=False)
-    if "ix_config_user_id" not in existing_index_names:
+    if "ix_config_user_id" not in existing_indexes:
         op.create_index("ix_config_user_id", "config", ["user_id"], unique=False)
-    if "ix_config_team_id" not in existing_index_names:
+    if "ix_config_team_id" not in existing_indexes:
         op.create_index("ix_config_team_id", "config", ["team_id"], unique=False)
 
-    # For SQLite, unique constraints are stored as unique indexes
-    # Create the unique constraint as a unique index using raw SQL to avoid batch mode issues
-    if "uq_config_user_team_key" not in existing_index_names:
-        connection.execute(
-            sa.text("CREATE UNIQUE INDEX IF NOT EXISTS uq_config_user_team_key ON config(user_id, team_id, key)")
-        )
+    # Enforce uniqueness of (user_id, team_id, key) unless it is already in place
+    unique_name = "uq_config_user_team_key"
+    if unique_name not in existing_indexes and unique_name not in existing_uniques:
+        if is_sqlite:
+            # SQLite cannot ALTER TABLE ... ADD CONSTRAINT; a unique index enforces the same rule
+            op.create_index(unique_name, "config", ["user_id", "team_id", "key"], unique=True)
+        else:
+            op.create_unique_constraint(unique_name, "config", ["user_id", "team_id", "key"])
 
     # Migrate existing configs to admin user's first team
     # Note: Don't call connection.commit() - Alembic manages transactions
-    connection = op.get_bind()
 
+    # Find admin user's first team
+    users_teams = sa.table("users_teams", sa.column("user_id"), sa.column("team_id"))
+    users = sa.table("user", sa.column("id"), sa.column("email"))
+
     admin_team_result = connection.execute(
-        sa.text("""
-            SELECT ut.team_id
-            FROM users_teams ut
-            JOIN user u ON ut.user_id = u.id
-            WHERE u.email = 'admin@example.com'
-            LIMIT 1
-        """)
+        sa.select(users_teams.c.team_id)
+        .select_from(users_teams.join(users, users_teams.c.user_id == users.c.id))
+        .where(users.c.email == "admin@example.com")
+        .limit(1)
     )
     admin_team_row = admin_team_result.fetchone()
 
     if admin_team_row:
         admin_team_id = admin_team_row[0]
         # Update all existing configs (where team_id is NULL) to use admin team
+        config_table = sa.table("config", sa.column("team_id"))
         connection.execute(
-            sa.text("UPDATE config SET team_id = :team_id WHERE team_id IS NULL"), {"team_id": admin_team_id}
+            sa.update(config_table).where(config_table.c.team_id.is_(None)).values(team_id=admin_team_id)
         )
         print(f"✅ Migrated existing configs to team {admin_team_id}")
     else:
         # If no admin team found, try to get any user's first team
-        any_team_result = connection.execute(sa.text("SELECT team_id FROM users_teams LIMIT 1"))
+        any_team_result = connection.execute(sa.select(users_teams.c.team_id).limit(1))
         any_team_row = any_team_result.fetchone()
         if any_team_row:
             any_team_id = any_team_row[0]
+            config_table = sa.table("config", sa.column("team_id"))
             connection.execute(
-                sa.text("UPDATE config SET team_id = :team_id WHERE team_id IS NULL"), {"team_id": any_team_id}
+                sa.update(config_table).where(config_table.c.team_id.is_(None)).values(team_id=any_team_id)
             )
             print(f"✅ Migrated existing configs to team {any_team_id}")
         else:
             # No teams found, delete existing configs
-            deleted_count = connection.execute(sa.text("DELETE FROM config WHERE team_id IS NULL")).rowcount
+            config_table = sa.table("config", sa.column("team_id"))
+            deleted_count = connection.execute(sa.delete(config_table).where(config_table.c.team_id.is_(None))).rowcount
             print(f"⚠️ No teams found, deleted {deleted_count} config entries")
 
     # ### end Alembic commands ###
@@ -111,33 +114,37 @@ def downgrade() -> None:
-    """Downgrade schema."""
+    """Restore the pre-team layout of the config table.
+
+    Drops the (user_id, team_id, key) uniqueness rule, the three indexes and
+    the user_id/team_id columns, then recreates the unique index on key.
+    """
     connection = op.get_bind()
+    is_sqlite = connection.dialect.name == "sqlite"
 
-    # Check existing indexes
-    index_result = connection.execute(
-        sa.text("SELECT name FROM sqlite_master WHERE type='index' AND tbl_name='config'")
-    )
-    existing_index_names = [row[0] for row in index_result.fetchall()]
-
-    # Check existing columns
-    column_result = connection.execute(sa.text("PRAGMA table_info(config)"))
-    existing_columns = [row[1] for row in column_result.fetchall()]
-
-    # Drop indexes and constraints outside of batch mode to avoid type inference issues
-    # Drop unique constraint (stored as unique index in SQLite)
-    if "uq_config_user_team_key" in existing_index_names:
-        connection.execute(sa.text("DROP INDEX IF EXISTS uq_config_user_team_key"))
+    # Reflect the current state up front so each drop only targets objects that exist; errors
+    # are not swallowed because a failed statement aborts the whole transaction on PostgreSQL
+    inspector = sa.inspect(connection)
+    existing_columns = {column["name"] for column in inspector.get_columns("config")}
+    existing_index_names = {index["name"] for index in inspector.get_indexes("config")}
+    existing_uniques = {constraint["name"] for constraint in inspector.get_unique_constraints("config")}
+
+    # Drop the (user_id, team_id, key) uniqueness rule: a real constraint where one exists,
+    # otherwise the unique index that stands in for it (always the case on SQLite)
+    unique_name = "uq_config_user_team_key"
+    if not is_sqlite and unique_name in existing_uniques:
+        op.drop_constraint(unique_name, "config", type_="unique")
+    elif unique_name in existing_index_names:
+        op.drop_index(unique_name, table_name="config")
 
     # Drop indexes
     if "ix_config_team_id" in existing_index_names:
         op.drop_index("ix_config_team_id", table_name="config")
     if "ix_config_user_id" in existing_index_names:
         op.drop_index("ix_config_user_id", table_name="config")
     if "ix_config_key" in existing_index_names:
         op.drop_index("ix_config_key", table_name="config")
 
-    # Drop columns using raw SQL to avoid batch mode type inference issues
-    # SQLite doesn't support DROP COLUMN directly, so we recreate the table
-    if "team_id" in existing_columns or "user_id" in existing_columns:
-        # Create new table without user_id and team_id columns
+    # Drop columns - SQLite < 3.35.0 doesn't support DROP COLUMN, so recreate table
+    if is_sqlite:
+        # Recreate table without user_id and team_id columns for SQLite compatibility
         connection.execute(
             sa.text("""
                 CREATE TABLE config_new (
@@ -147,15 +154,17 @@ def downgrade() -> None:
                 )
             """)
         )
-        # Copy data from old table to new table (only id, key, value columns)
         connection.execute(sa.text("INSERT INTO config_new (id, key, value) SELECT id, key, value FROM config"))
-        # Drop old table (this also drops all indexes)
         connection.execute(sa.text("DROP TABLE config"))
-        # Rename new table to original name
         connection.execute(sa.text("ALTER TABLE config_new RENAME TO config"))
-        # Recreate the original unique index on key (it was dropped with the old table)
-        op.create_index("ix_config_key", "config", ["key"], unique=True)
     else:
-        # If we're not dropping columns, just recreate the unique index on key
-        op.create_index("ix_config_key", "config", ["key"], unique=True)
+        # Other dialects (e.g. PostgreSQL) support ALTER TABLE ... DROP COLUMN directly
+        if "team_id" in existing_columns:
+            op.drop_column("config", "team_id")
+        if "user_id" in existing_columns:
+            op.drop_column("config", "user_id")
+
+    # Recreate the original unique index on key. Deliberately not guarded: if the same key now
+    # exists for several teams this fails, and the downgrade must stop rather than lose the index
+    op.create_index("ix_config_key", "config", ["key"], unique=True)
     # ### end Alembic commands ###
diff --git a/api/alembic/versions/f278bbaa6f67_create_api_keys_table.py b/api/alembic/versions/f278bbaa6f67_create_api_keys_table.py
index 8daae2a13..aa3edc72f 100644
--- a/api/alembic/versions/f278bbaa6f67_create_api_keys_table.py
+++ b/api/alembic/versions/f278bbaa6f67_create_api_keys_table.py
@@ -11,6 +11,7 @@ from alembic import op
 import sqlalchemy as sa
 
+from transformerlab.db.migration_utils import table_exists
 
 # revision identifiers, used by Alembic.
revision: str = "f278bbaa6f67" @@ -23,14 +24,7 @@ def upgrade() -> None: """Create api_keys table.""" connection = op.get_bind() - # Helper function to check if table exists - def table_exists(table_name: str) -> bool: - result = connection.execute( - sa.text("SELECT name FROM sqlite_master WHERE type='table' AND name=:name"), {"name": table_name} - ) - return result.fetchone() is not None - - if not table_exists("api_keys"): + if not table_exists(connection, "api_keys"): op.create_table( "api_keys", sa.Column("id", sa.String(), nullable=False), @@ -39,7 +33,7 @@ def table_exists(table_name: str) -> bool: sa.Column("user_id", sa.String(), nullable=False), sa.Column("team_id", sa.String(), nullable=True), sa.Column("name", sa.String(), nullable=True), - sa.Column("is_active", sa.Boolean(), nullable=False, server_default="1"), + sa.Column("is_active", sa.Boolean(), nullable=False, server_default=sa.text("true")), sa.Column("last_used_at", sa.DateTime(), nullable=True), sa.Column("expires_at", sa.DateTime(), nullable=True), sa.Column("created_at", sa.DateTime(), server_default=sa.text("(CURRENT_TIMESTAMP)"), nullable=False), diff --git a/api/alembic/versions/f7661070ec23_initial_migration_create_all_tables.py b/api/alembic/versions/f7661070ec23_initial_migration_create_all_tables.py index 9ea8c184d..ff9f4b7bc 100644 --- a/api/alembic/versions/f7661070ec23_initial_migration_create_all_tables.py +++ b/api/alembic/versions/f7661070ec23_initial_migration_create_all_tables.py @@ -11,6 +11,8 @@ from alembic import op import sqlalchemy as sa +from transformerlab.db.migration_utils import table_exists + # revision identifiers, used by Alembic. 
revision: str = "f7661070ec23" down_revision: Union[str, Sequence[str], None] = None @@ -22,15 +24,8 @@ def upgrade() -> None: """Create all initial tables.""" connection = op.get_bind() - # Helper function to check if table exists - def table_exists(table_name: str) -> bool: - result = connection.execute( - sa.text("SELECT name FROM sqlite_master WHERE type='table' AND name=:name"), {"name": table_name} - ) - return result.fetchone() is not None - # Config table - if not table_exists("config"): + if not table_exists(connection, "config"): op.create_table( "config", sa.Column("id", sa.Integer(), nullable=False), @@ -42,7 +37,7 @@ def table_exists(table_name: str) -> bool: op.create_index(op.f("ix_config_key"), "config", ["key"], unique=True) # Plugin table - if table_exists("plugins"): + if table_exists(connection, "plugins"): # Drop all indexes on the table op.drop_index(op.f("ix_plugins_name"), table_name="plugins", if_exists=True) op.drop_index(op.f("ix_plugins_type"), table_name="plugins", if_exists=True) @@ -61,7 +56,7 @@ def table_exists(table_name: str) -> bool: # op.create_index(op.f("ix_plugins_type"), "plugins", ["type"], unique=False) # TrainingTemplate table - if table_exists("training_template"): + if table_exists(connection, "training_template"): # Drop all indexes on the table op.drop_index(op.f("ix_training_template_name"), table_name="training_template", if_exists=True) op.drop_index(op.f("ix_training_template_created_at"), table_name="training_template", if_exists=True) @@ -89,7 +84,7 @@ def table_exists(table_name: str) -> bool: # op.create_index(op.f("ix_training_template_updated_at"), "training_template", ["updated_at"], unique=False) # Workflow table - if not table_exists("workflows"): + if not table_exists(connection, "workflows"): op.create_table( "workflows", sa.Column("id", sa.Integer(), nullable=False), @@ -105,7 +100,7 @@ def table_exists(table_name: str) -> bool: op.create_index("idx_workflow_id_experiment", "workflows", ["id", 
"experiment_id"], unique=False) # WorkflowRun table - if not table_exists("workflow_runs"): + if not table_exists(connection, "workflow_runs"): op.create_table( "workflow_runs", sa.Column("id", sa.Integer(), nullable=False), @@ -124,7 +119,7 @@ def table_exists(table_name: str) -> bool: op.create_index(op.f("ix_workflow_runs_status"), "workflow_runs", ["status"], unique=False) # Team table - if not table_exists("teams"): + if not table_exists(connection, "teams"): op.create_table( "teams", sa.Column("id", sa.String(), nullable=False), @@ -133,7 +128,7 @@ def table_exists(table_name: str) -> bool: ) # UserTeam table - if not table_exists("users_teams"): + if not table_exists(connection, "users_teams"): op.create_table( "users_teams", sa.Column("user_id", sa.String(), nullable=False), @@ -143,7 +138,7 @@ def table_exists(table_name: str) -> bool: ) # TeamInvitation table - if not table_exists("team_invitations"): + if not table_exists(connection, "team_invitations"): op.create_table( "team_invitations", sa.Column("id", sa.String(), nullable=False), @@ -166,16 +161,16 @@ def table_exists(table_name: str) -> bool: # User table (from fastapi-users) # Check if table exists first to avoid errors on existing databases - if not table_exists("user"): + if not table_exists(connection, "user"): # Create new user table with correct schema op.create_table( "user", sa.Column("id", sa.CHAR(length=36), nullable=False), sa.Column("email", sa.String(length=320), nullable=False), sa.Column("hashed_password", sa.String(length=1024), nullable=False), - sa.Column("is_active", sa.Boolean(), nullable=False, server_default=sa.text("1")), - sa.Column("is_superuser", sa.Boolean(), nullable=False, server_default=sa.text("0")), - sa.Column("is_verified", sa.Boolean(), nullable=False, server_default=sa.text("0")), + sa.Column("is_active", sa.Boolean(), nullable=False, server_default=sa.text("true")), + sa.Column("is_superuser", sa.Boolean(), nullable=False, server_default=sa.text("false")), + 
sa.Column("is_verified", sa.Boolean(), nullable=False, server_default=sa.text("false")), sa.Column("first_name", sa.String(length=100), nullable=True), sa.Column("last_name", sa.String(length=100), nullable=True), sa.PrimaryKeyConstraint("id"), @@ -203,9 +198,9 @@ def table_exists(table_name: str) -> bool: sa.Column("id", sa.CHAR(length=36), nullable=False), sa.Column("email", sa.String(length=320), nullable=False), sa.Column("hashed_password", sa.String(length=1024), nullable=False), - sa.Column("is_active", sa.Boolean(), nullable=False, server_default=sa.text("1")), - sa.Column("is_superuser", sa.Boolean(), nullable=False, server_default=sa.text("0")), - sa.Column("is_verified", sa.Boolean(), nullable=False, server_default=sa.text("0")), + sa.Column("is_active", sa.Boolean(), nullable=False, server_default=sa.text("true")), + sa.Column("is_superuser", sa.Boolean(), nullable=False, server_default=sa.text("false")), + sa.Column("is_verified", sa.Boolean(), nullable=False, server_default=sa.text("false")), sa.Column("first_name", sa.String(length=100), nullable=True), sa.Column("last_name", sa.String(length=100), nullable=True), sa.PrimaryKeyConstraint("id"), diff --git a/api/pyproject.toml b/api/pyproject.toml index 85c6c5a65..0a6a7e494 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -53,6 +53,7 @@ nvidia = [ # Packages with versions (using no-gpu/nvidia versions) "accelerate==1.3.0", "aiosqlite==0.20.0", + "asyncpg==0.30.0", "einops==0.8.0", "fastapi==0.115.7", "mcp[cli]==1.8.1", @@ -75,6 +76,7 @@ rocm = [ # Packages with ROCm-specific versions "accelerate==1.6.0", "aiosqlite==0.21.0", + "asyncpg==0.30.0", "einops==0.8.1", "fastapi==0.115.12", "mcp[cli]==1.9.1", @@ -96,6 +98,7 @@ cpu = [ # Packages with versions (using no-gpu/cpu versions) "accelerate==1.3.0", "aiosqlite==0.20.0", + "asyncpg==0.30.0", "einops==0.8.0", "fastapi==0.115.7", "mcp[cli]==1.8.1", diff --git a/api/transformerlab/db/constants.py b/api/transformerlab/db/constants.py index 
ee49ca60b..51fc1b6d1 100644
--- a/api/transformerlab/db/constants.py
+++ b/api/transformerlab/db/constants.py
@@ -1,8 +1,30 @@
 import os
+from urllib.parse import quote
 
 from lab import HOME_DIR
 
-db = None  # This will hold the aiosqlite connection
+db = None  # This will hold the aiosqlite connection (for SQLite) or None (for PostgreSQL)
 DATABASE_FILE_NAME = f"{HOME_DIR}/llmlab.sqlite3"
-# Allow DATABASE_URL to be overridden by environment variable (useful for testing)
-DATABASE_URL = os.getenv("DATABASE_URL", f"sqlite+aiosqlite:///{DATABASE_FILE_NAME}")
+
+# Check for PostgreSQL configuration via environment variables
+DATABASE_HOST = os.getenv("DATABASE_HOST")
+DATABASE_PORT = os.getenv("DATABASE_PORT", "5432")
+DATABASE_DB = os.getenv("DATABASE_DB")
+DATABASE_USER = os.getenv("DATABASE_USER")
+DATABASE_PASSWORD = os.getenv("DATABASE_PASSWORD")
+
+# Construct DATABASE_URL based on available configuration
+if DATABASE_HOST and DATABASE_DB and DATABASE_USER and DATABASE_PASSWORD:
+    # Use PostgreSQL if all required credentials are provided. The credentials are
+    # percent-encoded so characters such as "@", ":" or "/" are not read as URL delimiters.
+    DATABASE_URL = (
+        f"postgresql+asyncpg://{quote(DATABASE_USER, safe='')}:{quote(DATABASE_PASSWORD, safe='')}"
+        f"@{DATABASE_HOST}:{DATABASE_PORT}/{DATABASE_DB}"
+    )
+else:
+    # Otherwise honour DATABASE_URL (useful for testing), defaulting to the local SQLite file
+    DATABASE_URL = os.getenv("DATABASE_URL", f"sqlite+aiosqlite:///{DATABASE_FILE_NAME}")
+
+# Derive the backend from the URL so that a DATABASE_URL pointing at PostgreSQL is not
+# treated as SQLite by the code that branches on DATABASE_TYPE
+DATABASE_TYPE = "postgresql" if DATABASE_URL.startswith("postgresql") else "sqlite"
diff --git a/api/transformerlab/db/migration_utils.py b/api/transformerlab/db/migration_utils.py
new file mode 100644
index 000000000..1b16137be
--- /dev/null
+++ b/api/transformerlab/db/migration_utils.py
@@ -0,0 +1,27 @@
+"""
+Utility functions for Alembic migrations.
+
+Keep this clean and isolated. Do NOT import Transformer Lab stuff in here.
+"""
+
+import sqlalchemy as sa
+
+
+def table_exists(connection, table_name: str) -> bool:
+    """
+    Check if a table exists in the database.
+
+    Uses SQLAlchemy's runtime inspector, so the lookup is dialect-aware
+    (SQLite, PostgreSQL, ...) and searches the connection's default schema
+    rather than a hard-coded one.
+
+    Args:
+        connection: The database connection from op.get_bind()
+        table_name: The name of the table to check
+
+    Returns:
+        bool: True if table exists, False otherwise
+    """
+    # A fresh inspector per call: inspectors cache what they reflect, and migrations
+    # create and drop tables between calls
+    return sa.inspect(connection).has_table(table_name)
diff --git a/api/transformerlab/db/session.py b/api/transformerlab/db/session.py
index 39aa4b338..9f30ae79f 100644
--- a/api/transformerlab/db/session.py
+++ b/api/transformerlab/db/session.py
@@ -6,7 +6,7 @@ from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
 from sqlalchemy.orm import sessionmaker
 
 
-from transformerlab.db.constants import DATABASE_FILE_NAME, DATABASE_URL
+from transformerlab.db.constants import DATABASE_FILE_NAME, DATABASE_URL, DATABASE_TYPE
 from lab.dirs import get_workspace_dir
 
 
@@ -68,150 +68,78 @@ async def init():
     Create the database, tables, and workspace folder if they don't exist.
""" global db - # Migrate database from old location if necessary - old_db_base = os.path.join(await get_workspace_dir(), "llmlab.sqlite3") - if os.path.exists(old_db_base): - if not os.path.exists(DATABASE_FILE_NAME): - for ext in ["", "-wal", "-shm"]: - old_path = old_db_base + ext - new_path = DATABASE_FILE_NAME + ext - if os.path.exists(old_path): - shutil.copy2(old_path, new_path) - os.remove(old_path) - print("Migrated database from workspace to parent directory") - else: - for ext in ["", "-wal", "-shm"]: - old_path = old_db_base + ext - if os.path.exists(old_path): - os.remove(old_path) - print("Old database files removed (new database already exists)") - os.makedirs(os.path.dirname(DATABASE_FILE_NAME), exist_ok=True) - db = await aiosqlite.connect(DATABASE_FILE_NAME) - await db.execute("PRAGMA journal_mode=WAL") - await db.execute("PRAGMA synchronous=normal") - await db.execute("PRAGMA busy_timeout = 30000") + + if DATABASE_TYPE == "sqlite": + # SQLite-specific initialization + # Migrate database from old location if necessary + old_db_base = os.path.join(await get_workspace_dir(), "llmlab.sqlite3") + if os.path.exists(old_db_base): + if not os.path.exists(DATABASE_FILE_NAME): + for ext in ["", "-wal", "-shm"]: + old_path = old_db_base + ext + new_path = DATABASE_FILE_NAME + ext + if os.path.exists(old_path): + shutil.copy2(old_path, new_path) + os.remove(old_path) + print("Migrated database from workspace to parent directory") + else: + for ext in ["", "-wal", "-shm"]: + old_path = old_db_base + ext + if os.path.exists(old_path): + os.remove(old_path) + print("Old database files removed (new database already exists)") + os.makedirs(os.path.dirname(DATABASE_FILE_NAME), exist_ok=True) + db = await aiosqlite.connect(DATABASE_FILE_NAME) + await db.execute("PRAGMA journal_mode=WAL") + await db.execute("PRAGMA synchronous=normal") + await db.execute("PRAGMA busy_timeout = 30000") + else: + # PostgreSQL doesn't need aiosqlite connection or PRAGMA statements + db 
= None + print("Using PostgreSQL database") # Run Alembic migrations to create/update tables # This replaces the previous create_all() call await run_alembic_migrations() - # Check if workflow_runs table exists before checking/modifying columns - cursor = await db.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='workflow_runs'") - table_exists = await cursor.fetchone() - await cursor.close() - - if table_exists: - # Check if experiment_id column exists in workflow_runs table - cursor = await db.execute("PRAGMA table_info(workflow_runs)") - columns = await cursor.fetchall() + if DATABASE_TYPE == "sqlite": + # SQLite-specific: Check if workflow_runs table exists before checking/modifying columns + cursor = await db.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='workflow_runs'") + table_exists = await cursor.fetchone() await cursor.close() - has_experiment_id = any(column[1] == "experiment_id" for column in columns) - - if not has_experiment_id: - # Add experiment_id column - await db.execute("ALTER TABLE workflow_runs ADD COLUMN experiment_id INTEGER") - - # Update existing workflow runs with experiment_id from their workflows - await db.execute(""" - UPDATE workflow_runs - SET experiment_id = ( - SELECT experiment_id - FROM workflows - WHERE workflows.id = workflow_runs.workflow_id - ) - """) - await db.commit() + + if table_exists: + # Check if experiment_id column exists in workflow_runs table + cursor = await db.execute("PRAGMA table_info(workflow_runs)") + columns = await cursor.fetchall() + await cursor.close() + has_experiment_id = any(column[1] == "experiment_id" for column in columns) + + if not has_experiment_id: + # Add experiment_id column + await db.execute("ALTER TABLE workflow_runs ADD COLUMN experiment_id INTEGER") + + # Update existing workflow runs with experiment_id from their workflows + await db.execute(""" + UPDATE workflow_runs + SET experiment_id = ( + SELECT experiment_id + FROM workflows + WHERE 
workflows.id = workflow_runs.workflow_id + ) + """) + await db.commit() print("✅ Database initialized") - # Run migrations - await migrate_workflows_non_preserving() # await init_sql_model() return -async def migrate_workflows_non_preserving(): - """ - Migration function that renames workflows table as backup and creates new table - based on current schema definition if experiment_id is not INTEGER type or config is not JSON type - """ - - try: - # Check if workflows table exists - cursor = await db.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='workflows'") - table_exists = await cursor.fetchone() - await cursor.close() - - if not table_exists: - print("Workflows table does not exist. Skipping non-preserving migration.") - return - - # Check column types in the current workflows table - cursor = await db.execute("PRAGMA table_info(workflows)") - columns_info = await cursor.fetchall() - await cursor.close() - - experiment_id_type = None - config_type = None - - for column in columns_info: - column_name = column[1] - column_type = column[2].upper() - - if column_name == "experiment_id": - experiment_id_type = column_type - elif column_name == "config": - config_type = column_type - - # Check if migration is needed based on column types - needs_migration = False - migration_reasons = [] - - if experiment_id_type and experiment_id_type != "INTEGER": - needs_migration = True - migration_reasons.append(f"experiment_id column type is {experiment_id_type}, expected INTEGER") - - # SQLAlchemy JSON type maps to TEXT in SQLite, so we accept both - if config_type and config_type not in ["JSON", "TEXT"]: - needs_migration = True - migration_reasons.append( - f"config column type is {config_type}, expected JSON/TEXT (SQLAlchemy creates JSON as TEXT in SQLite)" - ) - - if not needs_migration: - # print("Column types are correct. 
No migration needed.") - return - - print("Migration needed due to:") - for reason in migration_reasons: - print(f" - {reason}") - - # Check if backup table already exists and drop it - cursor = await db.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='workflows_backup'") - backup_exists = await cursor.fetchone() - await cursor.close() - - if backup_exists: - await db.execute("DROP TABLE workflows_backup") - - # Rename current table as backup - await db.execute("ALTER TABLE workflows RENAME TO workflows_backup") - - # Note: Table creation is now handled by Alembic migrations - # If we need to recreate the workflows table, it should be done via a migration - pass - - await db.commit() - print("Successfully created new workflows table with correct schema. Old table saved as workflows_backup.") - - except Exception as e: - print(f"Failed to perform non-preserving migration: {e}") - raise e - - async def close(): - await db.close() + if DATABASE_TYPE == "sqlite" and db is not None: + await db.close() await async_engine.dispose() print("✅ Database closed") return diff --git a/api/transformerlab/routers/auth/api_key_auth.py b/api/transformerlab/routers/auth/api_key_auth.py index 6e8aae05b..4169cb2b6 100644 --- a/api/transformerlab/routers/auth/api_key_auth.py +++ b/api/transformerlab/routers/auth/api_key_auth.py @@ -1,5 +1,6 @@ """API Key authentication helpers.""" +import uuid from fastapi import Request, HTTPException from fastapi.security import HTTPBearer from sqlalchemy.ext.asyncio import AsyncSession @@ -84,7 +85,7 @@ async def validate_api_key_and_get_user( raise HTTPException(status_code=401, detail="API key has expired") # Get the user - stmt = select(User).where(User.id == api_key_obj.user_id) + stmt = select(User).where(User.id == uuid.UUID(api_key_obj.user_id)) result = await session.execute(stmt) user = result.unique().scalar_one_or_none() diff --git a/api/transformerlab/routers/quota.py b/api/transformerlab/routers/quota.py index 
0af894c2e..185fd8299 100644 --- a/api/transformerlab/routers/quota.py +++ b/api/transformerlab/routers/quota.py @@ -1,5 +1,6 @@ """Router for managing quota tracking and enforcement.""" +import uuid from fastapi import APIRouter, Depends from sqlalchemy.ext.asyncio import AsyncSession from pydantic import BaseModel, Field @@ -148,7 +149,7 @@ async def get_team_quota_usage_by_users( user_id_str = member.user_id # Get user details - user_stmt = select(User).where(User.id == user_id_str) + user_stmt = select(User).where(User.id == uuid.UUID(user_id_str)) user_result = await session.execute(user_stmt) user = user_result.scalar_one_or_none() if not user: diff --git a/api/transformerlab/routers/teams.py b/api/transformerlab/routers/teams.py index 5c4678631..c0283bbbb 100644 --- a/api/transformerlab/routers/teams.py +++ b/api/transformerlab/routers/teams.py @@ -1,3 +1,4 @@ +import uuid from fastapi import APIRouter, Depends, HTTPException, UploadFile, File, Form, Query from fastapi.responses import FileResponse, Response from sqlalchemy.ext.asyncio import AsyncSession @@ -623,7 +624,7 @@ async def get_my_invitations( team = result.scalar_one_or_none() # Get inviter info - stmt = select(User).where(User.id == invitation.invited_by_user_id) + stmt = select(User).where(User.id == uuid.UUID(invitation.invited_by_user_id)) result = await session.execute(stmt) inviter = result.scalar_one_or_none() @@ -843,7 +844,7 @@ async def get_team_invitations( invitation.status = InvitationStatus.EXPIRED.value # Get inviter info - stmt = select(User).where(User.id == invitation.invited_by_user_id) + stmt = select(User).where(User.id == uuid.UUID(invitation.invited_by_user_id)) result = await session.execute(stmt) inviter = result.scalar_one_or_none() diff --git a/api/transformerlab/shared/api_key_auth.py b/api/transformerlab/shared/api_key_auth.py index f3ff949e5..57877c240 100644 --- a/api/transformerlab/shared/api_key_auth.py +++ b/api/transformerlab/shared/api_key_auth.py @@ -1,5 +1,6 
@@ """API Key authentication helpers.""" +import uuid from fastapi import Request, HTTPException from fastapi.security import HTTPBearer from sqlalchemy.ext.asyncio import AsyncSession @@ -82,7 +83,7 @@ async def validate_api_key_and_get_user( raise HTTPException(status_code=401, detail="API key has expired") # Get the user - stmt = select(User).where(User.id == api_key_obj.user_id) + stmt = select(User).where(User.id == uuid.UUID(api_key_obj.user_id)) result = await session.execute(stmt) user = result.unique().scalar_one_or_none() diff --git a/api/transformerlab/shared/models/user_model.py b/api/transformerlab/shared/models/user_model.py index 88cc21933..3e78be3da 100644 --- a/api/transformerlab/shared/models/user_model.py +++ b/api/transformerlab/shared/models/user_model.py @@ -4,12 +4,11 @@ from sqlalchemy.orm import sessionmaker from sqlalchemy import select from fastapi_users.db import SQLAlchemyUserDatabase -from sqlalchemy.dialects.sqlite import insert from fastapi import Depends from os import getenv import uuid -from transformerlab.db.constants import DATABASE_URL +from transformerlab.db.constants import DATABASE_URL, DATABASE_TYPE from transformerlab.shared.models.models import Team, User, OAuthAccount from transformerlab.shared.remote_workspace import create_bucket_for_team @@ -72,15 +71,30 @@ async def add_oauth_account(self, user, create_dict: dict): if not user_exists: raise ValueError(f"User with id {user.id} does not exist") - # Perform an upsert: insert if not exists, update if conflict on unique constraint - stmt = ( - insert(OAuthAccount) - .values(user_id=user.id, **create_dict) - .on_conflict_do_update( - index_elements=["oauth_name", "account_id"], # Unique index on these columns - set_={k: v for k, v in create_dict.items() if k not in ["id"]}, # Update all fields except primary key + # Perform an upsert: insert if not exists, update if conflict on unique constraint. + # Must use dialect-specific insert for ON CONFLICT DO UPDATE support. 
+ if DATABASE_TYPE == "postgresql": + from sqlalchemy.dialects.postgresql import insert as pg_insert + + stmt = ( + pg_insert(OAuthAccount) + .values(user_id=user.id, **create_dict) + .on_conflict_do_update( + index_elements=["oauth_name", "account_id"], + set_={k: v for k, v in create_dict.items() if k not in ["id"]}, + ) + ) + else: + from sqlalchemy.dialects.sqlite import insert as sqlite_insert + + stmt = ( + sqlite_insert(OAuthAccount) + .values(user_id=user.id, **create_dict) + .on_conflict_do_update( + index_elements=["oauth_name", "account_id"], + set_={k: v for k, v in create_dict.items() if k not in ["id"]}, + ) ) - ) await self.session.execute(stmt) return user