ssl - Optimizing Async Python Connection Pool with Transport Layer: Best Practices for TLS, Concurrency, and Dynamic Client Handling

I've implemented an async connection pool (ConnectionPool) and transport layer (Transport) in Python…

I've implemented an async connection pool (ConnectionPool) and transport layer (Transport) in Python for managing persistent connections to a backend service. The goal is to handle multiple clients efficiently with sticky sessions, timeouts, and connection reuse. However, I'm unsure about:

TLS/SSL Integration:

Currently, the code doesn't implement TLS. How should I properly add TLS encryption to the Transport class (e.g., using asyncio.open_connection with SSL contexts)?

Are there security pitfalls (e.g., certificate verification, protocol versions) to avoid?

Connection Pool Optimization:

The pool uses a Semaphore for max connections and sticky sessions via sticky_map. Is this approach thread-safe and scalable for high concurrency?

How can I improve connection reuse (e.g., health checks, idle timeout)?

Dynamic Client Handling:

The current design assumes a single backend address. How can I extend it to support dynamic endpoints (e.g., load balancing across multiple hosts)?

Should I pre-warm connections or implement lazy initialization?

Error Handling:

Are there critical edge cases (e.g., partial writes, zombie connections) that aren’t handled robustly?

Code Reference: ConnectionPool: Manages connections with sticky sessions and semaphore-based limits.

Transport: Handles low-level socket communication with retries and timeouts.

connection_pool.py

import asyncio
import logging
from transport import Transport

# Module-level logger, named after this module, used by ConnectionPool below.
logger = logging.getLogger(__name__)

class ConnectionPool:
    """Asyncio pool of reusable ``Transport`` connections to one backend.

    Every successful ``get_connection`` call holds one permit of
    ``semaphore`` until the matching ``release_connection`` call, so at most
    ``max_connections`` connections are checked out at any time.

    The bookkeeping containers are only mutated between ``await`` points,
    which is atomic on a single event loop, so no lock is needed. The pool
    must not be shared across threads or event loops.

    Args:
        address: Backend address as ``"host:port"``; IPv6 literals may be
            written as ``"[::1]:port"``.
        max_connections: Maximum number of simultaneously checked-out
            connections.
        timeout: Seconds allowed for connecting and for each request.
    """

    def __init__(self, address, max_connections=100, timeout=5):
        self.address = address
        self.max_connections = max_connections
        self.timeout = timeout
        self.connections = set()  # Every connection created by this pool
        self.available_connections = asyncio.Queue()  # Idle, ready for reuse
        self.sticky_map = {}  # request_id -> Transport (lazily cleaned)
        self.semaphore = asyncio.Semaphore(max_connections)  # Check-out limit
        # Connections currently sitting in ``available_connections``; guards
        # against queueing the same connection twice when two holders of one
        # sticky connection both release it.
        self._idle = set()

    @staticmethod
    def _parse_address(address):
        """Split ``"host:port"`` into ``(host, port)``.

        Splits on the last colon so IPv6 literals keep their inner colons;
        surrounding square brackets are stripped from the host.

        Raises:
            ValueError: If there is no colon or the port is not an integer.
        """
        host, port = address.rsplit(":", 1)
        if host.startswith("[") and host.endswith("]"):
            host = host[1:-1]
        return host, int(port)

    async def _create_connection(self):
        """Create, connect and register a new connection.

        Returns:
            The connected ``Transport``, or ``None`` if connecting failed.

        Raises:
            ValueError: If ``self.address`` is malformed.
        """
        host, port = self._parse_address(self.address)
        conn = Transport(host, port, self.timeout)
        try:
            if await conn.connect():
                self.connections.add(conn)
                return conn
        except Exception as e:
            logger.error(f"Connection failed: {e}")
        return None

    def _reuse_connection(self, request_id, sticky):
        """Return a live sticky or idle connection, or ``None`` if none exists.

        Dead connections encountered on the way are dropped from the pool.
        """
        pin = bool(sticky and request_id)
        if pin:
            conn = self.sticky_map.get(request_id)
            if conn is not None:
                if conn.is_alive():
                    return conn
                # Lazy cleanup of a dead sticky connection.
                del self.sticky_map[request_id]
                self.connections.discard(conn)

        while True:
            try:
                conn = self.available_connections.get_nowait()
            except asyncio.QueueEmpty:
                return None
            self._idle.discard(conn)
            if conn.is_alive():
                if pin:
                    self.sticky_map[request_id] = conn
                return conn
            self.connections.discard(conn)  # Remove dead

    async def _discard(self, conn):
        """Forget ``conn`` and close it, logging (not raising) close errors."""
        self.connections.discard(conn)
        try:
            await conn.close()
        except Exception as e:
            logger.warning(f"Error closing connection: {e}")

    async def get_connection(self, request_id=None, sticky=False):
        """Check out a connection, reusing a sticky/idle one or creating one.

        Waits until fewer than ``max_connections`` connections are checked
        out. Each non-``None`` result must be handed back exactly once via
        ``release_connection``.

        Args:
            request_id: Key for sticky sessions; ignored unless ``sticky``.
            sticky: If true, pin the connection to ``request_id`` until it
                is released.

        Returns:
            A live connection, or ``None`` if a new one could not be created.
        """
        await self.semaphore.acquire()
        try:
            conn = self._reuse_connection(request_id, sticky)
            if conn is None:
                conn = await self._create_connection()
                if conn is not None and sticky and request_id:
                    self.sticky_map[request_id] = conn
        except BaseException:
            # Never leak the permit, not even on cancellation.
            self.semaphore.release()
            raise
        if conn is None:
            # Nothing was checked out, so nobody will call
            # release_connection() to hand the permit back.
            self.semaphore.release()
        return conn

    async def release_connection(self, conn, request_id=None):
        """Return a checked-out connection to the pool and free its permit.

        Args:
            conn: Connection previously returned by ``get_connection``.
            request_id: Sticky key to unpin, if the connection was sticky.
        """
        try:
            if request_id:
                self.sticky_map.pop(request_id, None)  # Unstick if needed

            if conn.is_alive():
                if conn not in self._idle:
                    self._idle.add(conn)
                    # The queue is unbounded, so this never raises QueueFull.
                    self.available_connections.put_nowait(conn)
            else:
                self.connections.discard(conn)  # Remove dead
        finally:
            self.semaphore.release()  # Important: release after cleanup!

    async def send_request(self, request, sticky=False):
        """Send ``request`` over a pooled connection.

        Args:
            request: Object with an ``id`` attribute, accepted by the
                connection's ``send_request``.
            sticky: If true, pin the connection to ``request.id`` for the
                duration of the request.

        Returns:
            The connection's response, or ``{"error": ...}`` on timeout,
            failure, or when no connection could be obtained.
        """
        conn = None
        request_id = None
        completed = False
        try:
            request_id = request.id
            conn = await self.get_connection(request_id, sticky)
            if not conn:
                return {"error": "No connection available."}

            response = await asyncio.wait_for(
                conn.send_request(request),
                timeout=self.timeout
            )
            completed = True
            return response
        except asyncio.TimeoutError:
            logger.error("Request timed out.")
            return {"error": "Timeout"}
        except Exception as e:
            logger.error(f"Request failed: {e}")
            return {"error": str(e)}
        finally:
            if conn:
                try:
                    if not completed:
                        # The exchange was interrupted (timeout, error or
                        # cancellation); a late response may still arrive on
                        # this stream and would be read by the next request,
                        # so the connection must not be reused.
                        await self._discard(conn)
                finally:
                    await self.release_connection(
                        conn, request_id if sticky else None
                    )

    async def close_all(self):
        """Close every pooled connection and reset the pool's bookkeeping."""
        # Snapshot first: awaiting close() yields to other coroutines, which
        # may mutate the live containers while we iterate.
        to_close = set(self.connections)
        while True:
            try:
                to_close.add(self.available_connections.get_nowait())
            except asyncio.QueueEmpty:
                break
        self.connections.clear()
        self.sticky_map.clear()
        self._idle.clear()

        for conn in to_close:
            try:
                await conn.close()
            except Exception as e:
                # One failing close must not leave the rest open.
                logger.warning(f"Error closing connection: {e}")

transport.py

import asyncio
import struct
import logging
import response_pb2 as pb2_response

# Module-level logger, named after this module, used by Transport below.
logger = logging.getLogger(__name__)

class Transport:
    """One length-prefixed protobuf request/response connection.

    Each message on the wire is a 4-byte big-endian length followed by the
    serialized payload. Only one request is in flight per connection at a
    time (guarded by ``lock``).

    Args:
        host: Backend host name or IP address.
        port: Backend TCP port.
        timeout: Seconds allowed for connecting and for each
            request/response exchange.
        ssl_context: Optional ``ssl.SSLContext`` enabling TLS. Prefer
            ``ssl.create_default_context()``, which verifies the server
            certificate and host name; ``None`` (the default) keeps the
            connection in plain text.
    """

    def __init__(self, host, port, timeout=5, ssl_context=None):
        self.host = host
        self.port = port
        self.timeout = timeout
        self.ssl_context = ssl_context
        self.reader = None
        self.writer = None
        self.lock = asyncio.Lock()  # Per-connection lock
        self.connected = False

    def _drop_stream(self):
        """Mark the connection dead and start closing its socket.

        Synchronous so it can run inside exception handlers; it does not
        wait for the socket to finish closing.
        """
        writer, self.reader, self.writer = self.writer, None, None
        self.connected = False
        if writer is not None:
            try:
                writer.close()
            except (OSError, RuntimeError) as e:
                logger.debug(f"Error while dropping connection: {e}")

    async def connect(self):
        """Open the connection, retrying unexpected errors with backoff.

        Timeouts and refused connections fail immediately; any other error
        is retried up to three attempts in total with a doubling delay.

        Returns:
            ``True`` once connected, ``False`` if every attempt failed.
        """
        # Close a stale socket first so reconnecting does not leak it.
        self._drop_stream()

        max_attempts = 3
        retry_delay = 1
        for attempt in range(1, max_attempts + 1):
            try:
                self.reader, self.writer = await asyncio.wait_for(
                    asyncio.open_connection(
                        self.host, self.port, ssl=self.ssl_context
                    ),
                    timeout=self.timeout
                )
            except (asyncio.TimeoutError, ConnectionRefusedError) as e:
                logger.error(f"Connection failed: {e}")
                break
            except Exception as e:
                if attempt == max_attempts:
                    # No attempt left: do not sleep for nothing.
                    logger.error(f"Connection failed: {e}")
                    break
                logger.warning(f"Retrying connection: {e}")
                await asyncio.sleep(retry_delay)
                retry_delay *= 2
            else:
                self.connected = True
                logger.info(f"Connected to {self.host}:{self.port}")
                return True

        self.connected = False
        return False

    def is_alive(self):
        """Return ``True`` if the connection is open and not closing."""
        return bool(
            self.connected
            and self.writer is not None
            and not self.writer.is_closing()
        )

    async def _exchange(self, request):
        """Write one framed request and read back its response."""
        req_data = request.SerializeToString()
        req_len = struct.pack(">I", len(req_data))
        self.writer.write(req_len + req_data)
        await self.writer.drain()
        return await self.receive_response()

    async def send_request(self, request):
        """Send ``request`` and wait for its response.

        Reconnects first if the connection is dead. Safe to call from
        several coroutines; calls are serialized per connection.

        Args:
            request: Object providing ``SerializeToString()``.

        Returns:
            Whatever ``receive_response`` returns, or ``{"error": ...}`` if
            connecting or the exchange failed.
        """
        async with self.lock:  # Ensure only one coroutine uses this connection
            # Checked under the lock so concurrent callers cannot both
            # reconnect and overwrite each other's streams.
            if not self.is_alive() and not await self.connect():
                return {"error": "Connection failed."}

            try:
                # The timeout covers the write as well as the read, so a
                # peer that stops reading cannot stall drain() forever.
                return await asyncio.wait_for(
                    self._exchange(request),
                    timeout=self.timeout
                )
            except asyncio.CancelledError:
                # A half-finished exchange leaves the stream out of sync;
                # never let this connection be reused.
                self._drop_stream()
                raise
            except Exception as e:
                # str() of a timeout is empty; fall back to the class name.
                reason = str(e) or type(e).__name__
                logger.error(f"Request failed: {reason}")
                self._drop_stream()
                return {"error": reason}

    async def receive_response(self):
        """Read one framed message and parse it as a protobuf ``Response``.

        Returns:
            The parsed response, or ``{"error": ...}`` if reading or parsing
            failed (the connection is then dropped).
        """
        try:
            len_buf = await self.reader.readexactly(4)
            resp_len = struct.unpack(">I", len_buf)[0]
            response_data = await self.reader.readexactly(resp_len)
            response = pb2_response.Response()
            response.ParseFromString(response_data)
            return response
        except Exception as e:
            self._drop_stream()
            return {"error": f"Receiver error: {e}"}

    async def close(self):
        """Close the connection; errors raised while closing are logged."""
        writer, self.reader, self.writer = self.writer, None, None
        self.connected = False
        if writer is None:
            return
        try:
            writer.close()
            await writer.wait_closed()
        except (OSError, RuntimeError) as e:
            # The peer may already have reset the connection; it is closed
            # either way, so there is nothing for the caller to handle.
            logger.debug(f"Error while closing connection: {e}")

Please review the code, implement the changes, and provide the updated code files.

发布者:admin,转载请注明出处:http://www.yc00.com/questions/1744203705a4563002.html

相关推荐

发表回复

评论列表(0条)

  • 暂无评论

联系我们

400-800-8888

在线咨询: QQ交谈

邮件:admin@example.com

工作时间:周一至周五,9:30-18:30,节假日休息

关注微信