#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Test script to demonstrate wakeable producer poll() interruptibility with real Kafka broker.
#
# This script tests the wakeable poll pattern by:
# 1. Connecting to a real Kafka broker
# 2. Producing messages to a topic
# 3. Calling poll() with a long timeout (or infinite)
# 4. Allowing Ctrl+C to interrupt the poll operation
#
# Usage:
#   python test_wakeable_producer_poll_interrupt.py [bootstrap_servers] [topic]
#
# Example:
#   python test_wakeable_producer_poll_interrupt.py localhost:9092 test-topic
#
# Press Ctrl+C to test interruptibility
#
# NOTE: This script prefers the local development version of confluent_kafka: if a
#       'src' directory exists next to this script, it is prepended to sys.path.

import os
import sys
import time

# Prefer the in-tree build: when a 'src' directory sits next to this script,
# place it at the very front of the module search path (once).
script_dir = os.path.dirname(os.path.abspath(__file__))
src_dir = os.path.join(script_dir, 'src')
if src_dir not in sys.path and os.path.exists(src_dir):
    sys.path[:0] = [src_dir]

from confluent_kafka import Producer, KafkaException
from confluent_kafka.admin import AdminClient, NewTopic

# Verify we're using the local development version
import confluent_kafka

# Compare resolved path components rather than raw substrings of __file__:
#   * a substring test against script_dir also matches an *installed* copy that
#     merely lives below the script directory (e.g. ./venv/.../site-packages);
#   * a hard-coded 'src/confluent_kafka' never matches Windows path separators.
# The package counts as local when its parent directory is named 'src' or is
# the directory containing this script.
_package_dir = os.path.dirname(os.path.realpath(confluent_kafka.__file__))
_package_parent = os.path.dirname(_package_dir)
_is_local_build = (
    os.path.basename(_package_parent) == 'src'
    or os.path.normcase(_package_parent) == os.path.normcase(os.path.realpath(script_dir))
)
if _is_local_build:
    print(f"✓ Using local development version: {confluent_kafka.__file__}")
else:
    print(f"⚠ WARNING: Using installed version: {confluent_kafka.__file__}")
    print("  The wakeable poll changes may not be active!")
    print("  Make sure to build the local version with: python setup.py build_ext --inplace")
    print()

# Fallback settings; each can be overridden through the environment
# (BOOTSTRAP_SERVERS / TEST_TOPIC) or, in main(), on the command line.
DEFAULT_BOOTSTRAP_SERVERS = os.getenv('BOOTSTRAP_SERVERS', 'localhost:9092')
DEFAULT_TOPIC = os.getenv('TEST_TOPIC', 'test-wakeable-producer-poll-topic')


def main():
    """Run the interactive wakeable-poll test against a real broker.

    Reads the bootstrap servers from ``sys.argv[1]`` and the topic from
    ``sys.argv[2]``, falling back to ``DEFAULT_BOOTSTRAP_SERVERS`` and
    ``DEFAULT_TOPIC``. The function then:

    1. makes a best-effort attempt to create the topic,
    2. produces a handful of test messages,
    3. calls ``poll()`` once with a short timeout and then with an infinite
       timeout, which is expected to be interrupted with Ctrl+C.

    Raises:
        KeyboardInterrupt: re-raised after the interruption summary has been
            printed, so that the process terminates.
        SystemExit: with status 1 on a Kafka error or any unexpected error.
    """
    # Parse command line arguments
    bootstrap_servers = sys.argv[1] if len(sys.argv) > 1 else DEFAULT_BOOTSTRAP_SERVERS
    topic = sys.argv[2] if len(sys.argv) > 2 else DEFAULT_TOPIC

    print("=" * 70)
    print("Wakeable Producer Poll Interruptibility Test")
    print("=" * 70)
    print(f"Bootstrap servers: {bootstrap_servers}")
    print(f"Topic: {topic}")
    print()
    print("This test demonstrates the wakeable poll pattern:")
    print("  - poll() will be called with infinite timeout (blocks until Ctrl+C)")
    print("  - The operation can be interrupted with Ctrl+C")
    print("  - With the wakeable pattern, Ctrl+C should interrupt within ~200ms")
    print()
    print("Press Ctrl+C to test interruptibility...")
    print("=" * 70)
    print()

    # Create topic if it doesn't exist
    _ensure_topic_exists(bootstrap_servers, topic)

    # Producer configuration
    # Use acks=all to ensure we wait for acknowledgments, making poll() block longer
    conf = {
        'bootstrap.servers': bootstrap_servers,
        'socket.timeout.ms': 100,
        'acks': 'all',  # Wait for all acknowledgments - ensures poll() blocks waiting for callbacks
    }

    producer = None
    try:
        # Create producer
        producer = Producer(conf)

        # Produce multiple messages to ensure there are callbacks to wait for
        print(f"Producing test messages to topic: {topic}")
        num_messages = 10
        produced = _produce_test_messages(producer, topic, num_messages)

        # Report what was actually accepted, not what was attempted.
        if produced == num_messages:
            print(f"✓ Produced {produced} messages")
        else:
            print(f"⚠ Produced only {produced} of {num_messages} messages")
        print("  (Using acks=all to ensure acknowledgments are waited for)")
        print()

        # Poll once to clear any immediate callbacks
        print("Polling once to clear immediate callbacks...")
        events = producer.poll(timeout=0.1)
        print(f"  Cleared {events} immediate events")
        print()

        print("Ready! Starting poll() with infinite timeout (will block until Ctrl+C)...")
        print("  (poll() will block waiting for delivery callbacks)")
        print()

        # Test poll() with infinite timeout - this should be interruptible.
        # A monotonic clock is used so the interval is immune to wall-clock jumps.
        poll_started = time.monotonic()
        try:
            print(f"[{time.strftime('%H:%M:%S')}] Calling poll() with infinite timeout...")
            print("    (This will block until delivery callbacks arrive or Ctrl+C)")
            print("    With acks=all, callbacks may take longer to arrive")
            print()

            events_processed = producer.poll(timeout=-1.0)  # Infinite timeout - will block

            elapsed = time.monotonic() - poll_started

            print(f"[{time.strftime('%H:%M:%S')}] poll() returned")
            print(f"    Events processed: {events_processed}")
            print(f"    Elapsed time: {elapsed:.2f}s")

        except KeyboardInterrupt:
            elapsed = time.monotonic() - poll_started
            print()
            print("=" * 70)
            print("✓ KeyboardInterrupt caught!")
            print(f"  Interrupted after {elapsed:.2f} seconds (measured from the start of poll())")
            print("  With wakeable pattern, interruption should occur within ~200ms of Ctrl+C")
            # The elapsed time includes however long the user waited before
            # pressing Ctrl+C, so it cannot be compared against a latency
            # threshold; the observer has to judge the responsiveness.
            print("  Judge responsiveness by how quickly this message appeared after Ctrl+C")
            print("=" * 70)
            raise  # Re-raise to exit

    except KafkaException as e:
        print(f"Kafka error: {e}")
        sys.exit(1)

    except Exception as e:
        print(f"Unexpected error: {e}")
        import traceback
        traceback.print_exc()
        sys.exit(1)

    finally:
        # Compare against None explicitly: Producer implements __len__ (the
        # number of messages still waiting to be delivered), so a bare
        # truthiness test is False for a producer whose queue is empty and
        # the cleanup below would be skipped.
        if producer is not None:
            print()
            print("Closing producer...")
            try:
                producer.flush(timeout=1.0)  # Flush any remaining messages
            finally:
                # Close even if flush() raised or was interrupted.
                producer.close()
            print("Producer closed.")


def _ensure_topic_exists(bootstrap_servers, topic):
    """Best-effort creation of ``topic`` (1 partition, replication factor 1).

    Failures are reported on stdout and otherwise ignored, because the topic
    may already exist; a KeyboardInterrupt still propagates to the caller.
    """
    try:
        print(f"Ensuring topic '{topic}' exists...")
        admin_client = AdminClient({'bootstrap.servers': bootstrap_servers})

        # Try to create the topic
        new_topic = NewTopic(topic, num_partitions=1, replication_factor=1)
        fs = admin_client.create_topics([new_topic], request_timeout=10.0)

        # Wait for topic creation
        for topic_name, f in fs.items():
            try:
                f.result(timeout=10.0)
                print(f"✓ Topic '{topic_name}' created successfully")
            except Exception as e:
                if "already exists" in str(e).lower() or "TopicExistsException" in type(e).__name__:
                    print(f"✓ Topic '{topic_name}' already exists")
                else:
                    print(f"⚠ Could not create topic '{topic_name}': {e}")
                    print("  Continuing anyway - topic may already exist...")
    except Exception as e:
        print(f"⚠ Could not create topic: {e}")
        print("  Continuing anyway - topic may already exist...")


def _produce_test_messages(producer, topic, num_messages):
    """Enqueue ``num_messages`` keyed test messages; return how many produce() calls succeeded."""
    accepted = 0
    for i in range(num_messages):
        try:
            producer.produce(topic,
                             value=f'test-message-{i}'.encode(),
                             key=f'test-key-{i}'.encode())
        except Exception as e:
            # Keep going: the test only needs some messages in flight.
            print(f"⚠ Could not produce message {i+1}: {e}")
        else:
            accepted += 1
    return accepted


if __name__ == '__main__':
    try:
        main()
    except KeyboardInterrupt:
        # Ctrl+C is the expected way to end this test and main() has already
        # printed its summary; exit with the conventional 128 + SIGINT status
        # instead of finishing the run with a traceback.
        sys.exit(130)

