Web Sockets with API Gateway and AWS Lambda

AWS offers an easy way to work with WebSockets through WebSocket APIs in API Gateway. This article is a short walkthrough of how to get started with them.

Setup

This article is based on a Serverless Framework project, but you can implement the concepts discussed here without the Serverless Framework.

You can find the code used in this article in this GitHub repository.

Serverless provider config

The Serverless provider configuration looks the same as in a regular API project, with one additional property: websocketsApiRouteSelectionExpression.

This property tells API Gateway which property of the request body to use to decide where to route the request. Here, the action property of the JSON payload determines which handler gets called.

We'll see this in action when we define the handlers.

provider:
  name: aws
  runtime: nodejs14.x
  profile: ${env:PROFILE, "localstack"}
  stage: ${opt:stage, "local"}
  region: ${opt:region, "us-east-1"}
  websocketsApiRouteSelectionExpression: $request.body.action
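To build some intuition for `$request.body.action`, here is a simplified model of the selection step. It's a hypothetical sketch, not an AWS API: API Gateway parses the incoming message as JSON, reads `action`, and uses the matching route; if the body isn't JSON or no route matches, it falls back to a `$default` route when one is defined. `selectRoute` and its parameters are illustrative names, and `routeKeys` holds custom route keys such as `sendMessage`.

```typescript
// Hypothetical model of how API Gateway applies the route selection
// expression $request.body.action. API Gateway does this for you;
// selectRoute is illustrative only and not part of any AWS SDK.
function selectRoute(messageBody: string, routeKeys: string[]): string | undefined {
  let action: unknown = undefined;
  try {
    const parsed: unknown = JSON.parse(messageBody);
    if (typeof parsed === "object" && parsed !== null) {
      action = (parsed as Record<string, unknown>).action;
    }
  } catch {
    // Not valid JSON: the expression can't be evaluated, so action stays undefined.
  }

  if (typeof action === "string" && routeKeys.includes(action)) {
    return action;
  }
  // No match: use $default if it exists, otherwise no route handles the message.
  return routeKeys.includes("$default") ? "$default" : undefined;
}

console.log(selectRoute('{"action":"sendMessage","message":"hi"}', ["sendMessage"])); // sendMessage
console.log(selectRoute('{"action":"unknown"}', ["sendMessage"])); // undefined
console.log(selectRoute("not json", ["sendMessage", "$default"])); // $default
```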

Connection table

For our use case, we'll need to maintain a record of active connections in order to send targeted messages.

I've chosen DynamoDB for this demonstration, but you might get better mileage from a relational database such as PostgreSQL. The concept is the same; only the implementation differs slightly.

A table row will contain Username and ConnectionId fields for each active connection.

We'll create the DynamoDB table as a resource in the serverless.yaml file. The table uses Username as its primary (partition) key, so usernames must be unique: a PutItem for an existing username replaces that user's previous connection record.

Since the disconnect and send-message handlers need to look up records by ConnectionId, let's create a global secondary index that uses ConnectionId as its partition key.

You can read more about global secondary indexes here.

resources:  
  Resources:
    ConnectionsTable:
      Type: AWS::DynamoDB::Table
      Properties:
        AttributeDefinitions:
          - AttributeName: Username
            AttributeType: S
          - AttributeName: ConnectionId
            AttributeType: S
        KeySchema:
          - AttributeName: Username
            KeyType: HASH
        ProvisionedThroughput:
          ReadCapacityUnits: 1
          WriteCapacityUnits: 1
        GlobalSecondaryIndexes:
          - IndexName: ConnectionIdIndex
            KeySchema:
              - AttributeName: ConnectionId
                KeyType: HASH
            Projection:
              ProjectionType: ALL
            ProvisionedThroughput:
              ReadCapacityUnits: 1
              WriteCapacityUnits: 1
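The data model can be pictured with a hypothetical in-memory stand-in (this is not DynamoDB code): one Map keyed by Username plays the base table, and a second Map keyed by ConnectionId plays ConnectionIdIndex. One difference worth knowing: DynamoDB maintains a GSI for you and updates it asynchronously, so index reads are eventually consistent, whereas this sketch updates both maps at once.

```typescript
// Hypothetical in-memory stand-in for ConnectionsTable and its ConnectionIdIndex GSI.
interface ConnectionRecord {
  Username: string;     // partition key of the base table
  ConnectionId: string; // partition key of ConnectionIdIndex
}

class ConnectionStore {
  private byUsername = new Map<string, ConnectionRecord>();
  private byConnectionId = new Map<string, ConnectionRecord>();

  // Like PutItem: a second record for the same username replaces the first.
  put(record: ConnectionRecord): void {
    const previous = this.byUsername.get(record.Username);
    if (previous) this.byConnectionId.delete(previous.ConnectionId);
    this.byUsername.set(record.Username, record);
    this.byConnectionId.set(record.ConnectionId, record);
  }

  // Like a Query on the base table (key: Username).
  getByUsername(username: string): ConnectionRecord | undefined {
    return this.byUsername.get(username);
  }

  // Like a Query on ConnectionIdIndex (key: ConnectionId).
  getByConnectionId(connectionId: string): ConnectionRecord | undefined {
    return this.byConnectionId.get(connectionId);
  }

  // Like DeleteItem on the base table, which is keyed by Username.
  delete(username: string): void {
    const record = this.byUsername.get(username);
    if (!record) return;
    this.byUsername.delete(username);
    this.byConnectionId.delete(record.ConnectionId);
  }
}
```

Deletes go through Username because that is the table's key; the index exists only to find the Username when all you have is a connection ID.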

Connecting

Now let's create the connection handler. This Lambda function is invoked when a client tries to establish a connection.

The main aim of this handler is to save the connection details in the DynamoDB table so this connection can be targeted when sending a message.

First, let's define the function in the functions section of serverless.yaml:

wsConnectHandler:
    handler: ./src/connect.handler
    events:
      - websocket:
          route: $connect
    iamRoleStatements:
      - Effect: "Allow"
        Action: "dynamodb:PutItem"
        Resource: !GetAtt [ConnectionsTable, Arn]

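This function is triggered by the predefined $connect WebSocket route. We also give it permission to put an item into the DynamoDB table we defined earlier. (Function-level iamRoleStatements isn't part of the core Serverless Framework; it's provided by the serverless-iam-roles-per-function plugin.)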

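The handler logic looks as follows. It reads STAGE, REGION, CONNECTION_TABLE and LOCALSTACK_ENDPOINT from environment variables, so these need to be set for the function, for example under provider.environment in serverless.yaml: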

import {
  DynamoDBClient,
  PutItemCommand
} from "@aws-sdk/client-dynamodb";

const {
  STAGE,
  REGION,
  CONNECTION_TABLE,
  LOCALSTACK_ENDPOINT
} = process.env;

const getDynamoDbConfig = (): { region: string, endpoint?: string } => {
  if (STAGE === "local") return { region: REGION, endpoint: LOCALSTACK_ENDPOINT };
  return { region: REGION };
}

export const handler = async (event) => {
  const dynamoDBClient = new DynamoDBClient(getDynamoDbConfig());

  const putItemCommand = new PutItemCommand({
    TableName: CONNECTION_TABLE,
    Item: {
      Username: { S: event.queryStringParameters.username },
      ConnectionId: { S: event.requestContext.connectionId }
    }
  });
  await dynamoDBClient.send(putItemCommand);

  return { statusCode: 200 };
}

When the connection route is triggered, we create a record in the connection table that contains the username and connection id.

We retrieve the username from the queryStringParameters of the request. To establish a connection, the client will have to send a connection request to wss://<api-id>.execute-api.<region>.amazonaws.com/<stage>/?username=<username>.
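Because the username travels in the query string, it should be URL-encoded before the client connects. A hypothetical helper that fills in the placeholders from the URL above (the function name and example values are illustrative, not from the repository):

```typescript
// Hypothetical helper: builds the wss:// URL a client connects to.
// apiId, region and stage correspond to <api-id>, <region> and <stage>.
function buildConnectionUrl(apiId: string, region: string, stage: string, username: string): string {
  // encodeURIComponent keeps spaces, "&", "=" etc. from breaking the query string.
  return `wss://${apiId}.execute-api.${region}.amazonaws.com/${stage}/?username=${encodeURIComponent(username)}`;
}

console.log(buildConnectionUrl("abc123", "us-east-1", "dev", "jane doe"));
// wss://abc123.execute-api.us-east-1.amazonaws.com/dev/?username=jane%20doe
```

In a browser, `new WebSocket(buildConnectionUrl(...))` would then open the connection.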

If the username isn't provided, the handler throws an error and the connection isn't established. To accept a connection, the handler must return a 200 status code.

You can implement custom error handling here if you'd like; otherwise, the error will usually show up in the function's CloudWatch logs.
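If you'd rather reject a missing username explicitly than rely on a thrown error, one option is to check the query string first and return a non-200 response, which also makes API Gateway refuse the connection. The sketch below isolates that check in a hypothetical, dependency-free helper; `ConnectEvent` is an assumed minimal shape covering only the fields read here.

```typescript
// Minimal, assumed shape of the $connect event fields used here
// (real events carry many more fields).
interface ConnectEvent {
  queryStringParameters?: Record<string, string | undefined> | null;
  requestContext: { connectionId: string };
}

type ConnectCheck =
  | { ok: true; username: string; connectionId: string }
  | { ok: false; response: { statusCode: number; body: string } };

// Hypothetical helper: validates the username before anything is written to DynamoDB.
function checkConnectRequest(event: ConnectEvent): ConnectCheck {
  const username = event.queryStringParameters?.username;
  if (!username) {
    // A non-200 response from the $connect route rejects the connection.
    return {
      ok: false,
      response: {
        statusCode: 400,
        body: JSON.stringify({ message: "The username query string parameter is required." })
      }
    };
  }
  return { ok: true, username, connectionId: event.requestContext.connectionId };
}

// Inside the connect handler (sketch):
// const check = checkConnectRequest(event);
// if (!check.ok) return check.response;
// ...then build the PutItemCommand from check.username and check.connectionId.
```

The empty-string case is rejected too, which matters because DynamoDB doesn't accept an empty string as a key value.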

Disconnecting

Now let's handle disconnection. The disconnect handler is triggered whenever a connection ends, so we can use it for cleanup: in our case, deleting the connection's row from the DynamoDB table. Note that API Gateway delivers the $disconnect event on a best-effort basis, so occasionally a stale row may be left behind.

First, let's define the function in serverless.yaml:

wsDisconnectHandler:
    handler: ./src/disconnect.handler
    events:
      - websocket:
          route: $disconnect
    iamRoleStatements:
      - Effect: "Allow"
        Action: 
          - "dynamodb:DeleteItem"
          - "dynamodb:Query"
        Resource: 
          - !GetAtt [ConnectionsTable, Arn]
          - !Join ["/", [!GetAtt [ConnectionsTable, Arn], "index", "ConnectionIdIndex"]]

Once again, this is triggered by the disconnect event through the pre-defined $disconnect route.

We need to give this function permission to delete from the connection table and query the global secondary index we created earlier.
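Querying an index needs permission on the index's own ARN, which is the table ARN followed by /index/<index-name>; that's what the !Join in the second Resource entry builds. Here is a plain-string sketch of the same assembly. The helper name, account ID and table name are made up for illustration (CloudFormation generates the real table name because the resource doesn't set TableName).

```typescript
// Mirrors !Join ["/", [<table ARN>, "index", "ConnectionIdIndex"]] from the config above.
function indexArn(tableArn: string, indexName: string): string {
  return [tableArn, "index", indexName].join("/");
}

const exampleTableArn = "arn:aws:dynamodb:us-east-1:123456789012:table/ConnectionsTable";
console.log(indexArn(exampleTableArn, "ConnectionIdIndex"));
// arn:aws:dynamodb:us-east-1:123456789012:table/ConnectionsTable/index/ConnectionIdIndex
```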

Here's the handler implementation:

import {
  DynamoDBClient,
  DeleteItemCommand,
  QueryCommand,
} from "@aws-sdk/client-dynamodb";

const {
  STAGE,
  REGION,
  CONNECTION_TABLE,
  LOCALSTACK_ENDPOINT
} = process.env;

const getDynamoDbConfig = (): { region: string, endpoint?: string } => {
  if (STAGE === "local") return { region: REGION, endpoint: LOCALSTACK_ENDPOINT };
  return { region: REGION };
}

export const handler = async (event) => {
  console.log(event);

  const dynamoDBClient = new DynamoDBClient(getDynamoDbConfig());

  // Retrieve the connection record using the connectionId (via the GSI)
  const queryCommand = new QueryCommand({
    TableName: CONNECTION_TABLE,
    IndexName: "ConnectionIdIndex",
    KeyConditionExpression: "ConnectionId = :a",
    ExpressionAttributeValues: {
      ":a": { S: event.requestContext.connectionId }
    }
  });

  const queryResult = await dynamoDBClient.send(queryCommand);
  const username = queryResult.Items?.[0]?.Username?.S;

  // Only delete when a matching record exists: DynamoDB rejects an empty
  // string as a key value, so falling back to "" would make DeleteItem fail.
  if (username) {
    const deleteItemCommand = new DeleteItemCommand({
      TableName: CONNECTION_TABLE,
      Key: {
        Username: { S: username }
      }
    });
    await dynamoDBClient.send(deleteItemCommand);
  }

  return {
    statusCode: 200,
    body: JSON.stringify({ message: "Connection ended!" })
  };
};

We retrieve the connection id from the request context in the event object and query the index to find the connection record with this connection id.

Once we've retrieved the record, we delete it from the connection table using the username as the key.
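One refinement worth considering, sketched here as a hypothetical helper rather than the repository's code: because GSI reads are eventually consistent and a reconnecting user's PutItem overwrites their row with a new ConnectionId, an unconditional delete keyed on Username could, in a narrow window, remove the user's newer connection. A DeleteItem ConditionExpression that also checks ConnectionId guards against that. The helper returns a plain object shaped like the DeleteItemCommand input.

```typescript
// Plain-object input for DynamoDB DeleteItem, shaped like the DeleteItemCommand
// input from @aws-sdk/client-dynamodb. Hypothetical helper, not from the repository.
interface GuardedDeleteInput {
  TableName: string;
  Key: { Username: { S: string } };
  ConditionExpression: string;
  ExpressionAttributeValues: { ":c": { S: string } };
}

function buildGuardedDelete(tableName: string, username: string, connectionId: string): GuardedDeleteInput {
  return {
    TableName: tableName,
    Key: { Username: { S: username } },
    // Only delete if the row still belongs to the connection that just closed.
    ConditionExpression: "ConnectionId = :c",
    ExpressionAttributeValues: { ":c": { S: connectionId } }
  };
}
```

You would pass the result to `new DeleteItemCommand(...)` and treat an error whose `name` is `ConditionalCheckFailedException` as "the row was already replaced, nothing to delete".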

Send a message

Now that we've handled connecting and disconnecting, it's time to handle sending messages directly to another connected user.

This handler will work by receiving a payload that contains the recipient's username, finding the recipient's connection record, extracting the connection id, and sending the payload message to the target connection.

First, let's define the function in serverless.yaml:

wsSendMessage:
    handler: ./src/sendMessage.handler
    events:
      - websocket:
          route: "sendMessage"
    iamRoleStatements:
      - Effect: "Allow"
        Action: "dynamodb:Query"
        Resource:
          - !GetAtt [ConnectionsTable, Arn]
          - !Join ["/", [!GetAtt [ConnectionsTable, Arn], "index", "ConnectionIdIndex"]]
      - Effect: "Allow"
        Action: "execute-api:ManageConnections"
        Resource:
          - Fn::Join:
            - "/"
            -
              - Fn::Join:
                - ":"
                -
                  - "arn:aws:execute-api"
                  - !Ref AWS::Region
                  - !Ref AWS::AccountId
                  - !Ref WebsocketsApi
              - ${self:provider.stage}
              - "POST"
              - "@connections"
              - "*"

Notice that this function uses a different event: its route is sendMessage, a custom route we define ourselves rather than one of the predefined $connect and $disconnect routes.

If you recall, the provider section has this property defined: websocketsApiRouteSelectionExpression: $request.body.action. This tells API Gateway to look for the route in the action property of the request body.

So, to trigger the wsSendMessage function, the message we send once connected must be a JSON object whose action property is "sendMessage", like so:

{
  "action": "sendMessage",
  ...
}
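The handler receives the raw message string in event.body and reads recepient and message from it. If you want to validate that payload before querying DynamoDB, a hypothetical type guard might look like this (the field names, including the recepient spelling, are taken from the handler further down; the function and interface names are illustrative):

```typescript
// Shape of the sendMessage payload the handler expects.
interface SendMessagePayload {
  action: "sendMessage";
  recepient: string; // spelled as in the handler
  message: string;
}

// Hypothetical type guard: parses event.body and checks the fields the handler reads.
function parseSendMessagePayload(raw: string | null | undefined): SendMessagePayload | undefined {
  if (!raw) return undefined;
  let data: unknown;
  try {
    data = JSON.parse(raw);
  } catch {
    return undefined; // not valid JSON
  }
  if (typeof data !== "object" || data === null) return undefined;
  const candidate = data as Record<string, unknown>;
  if (
    candidate.action === "sendMessage" &&
    typeof candidate.recepient === "string" &&
    typeof candidate.message === "string"
  ) {
    return { action: "sendMessage", recepient: candidate.recepient, message: candidate.message };
  }
  return undefined;
}
```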

Next, we give this function permission to query both the connection table and its global secondary index. It also needs the execute-api:ManageConnections permission to send messages to existing WebSocket connections.
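For reference, the nested Fn::Join in the wsSendMessage configuration assembles an ARN of the form arn:aws:execute-api:<region>:<account-id>:<api-id>/<stage>/POST/@connections/<connection-id>. IAM matches resources literally apart from the * and ? wildcards, so the last segment should be * to cover every connection ID; a placeholder such as {connectionId} would only match that exact text. Here is a plain-string sketch of the same assembly, with a hypothetical helper name and made-up example values:

```typescript
// Hypothetical helper mirroring the nested Fn::Join in the wsSendMessage config.
function manageConnectionsArn(
  region: string,
  accountId: string,
  apiId: string,
  stage: string,
  connectionId = "*" // wildcard: any connection on this API and stage
): string {
  const prefix = ["arn:aws:execute-api", region, accountId, apiId].join(":");
  return [prefix, stage, "POST", "@connections", connectionId].join("/");
}

console.log(manageConnectionsArn("us-east-1", "123456789012", "abc123", "dev"));
// arn:aws:execute-api:us-east-1:123456789012:abc123/dev/POST/@connections/*
```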

This is the handler implementation:

import { DynamoDBClient, QueryCommand } from "@aws-sdk/client-dynamodb";

import {
  ApiGatewayManagementApiClient,
  PostToConnectionCommand
} from "@aws-sdk/client-apigatewaymanagementapi";

const {
  STAGE,
  REGION,
  CONNECTION_TABLE,
  LOCALSTACK_ENDPOINT,
  PORT
} = process.env;

const getDynamoDbConfig = (): { region: string, endpoint?: string } => {
  if (STAGE === "local") return { region: REGION, endpoint: LOCALSTACK_ENDPOINT };
  return { region: REGION };
}

const getConnectionEndpoint = (event) => {
  return STAGE === "local" ? 
  `http://${event.requestContext.domainName}:${PORT}` :
  `https://${event.requestContext.domainName}/${event.requestContext.stage}`
}

export const handler = async (event) => {
  console.log(event);

  const body = JSON.parse(event.body);

  const dynamoDBClient = new DynamoDBClient(getDynamoDbConfig());

  /**
   * Query the connection table for the TARGET connection record.
   * The recipient's username comes from the "recepient" field of the event body.
   */
  let queryResult = await dynamoDBClient.send(new QueryCommand({
    TableName: CONNECTION_TABLE,
    KeyConditionExpression: "Username = :a",
    ExpressionAttributeValues: {
      ":a": { S: body.recepient }
    }
  }));

  // Return early if the recipient is not found in the connections table.
  const recipientConnection = queryResult.Items?.[0];
  if (!recipientConnection) {
    return {
      statusCode: 404,
      body: JSON.stringify({ message: "Recipient not found", body })
    };
  }

  /**
   * Query ConnectionIdIndex using the sender's connection id,
   * found in event.requestContext.connectionId.
   * We retrieve this connection record in order to get the sender's username.
   */
  queryResult = await dynamoDBClient.send(new QueryCommand({
    TableName: CONNECTION_TABLE,
    IndexName: "ConnectionIdIndex",
    KeyConditionExpression: "ConnectionId = :b",
    ExpressionAttributeValues: {
      ":b": { S: event.requestContext.connectionId }
    }
  }));

  // Return early if no actively connected sender is found.
  const senderConnection = queryResult.Items?.[0];
  if (!senderConnection) {
    return {
      statusCode: 404,
      body: JSON.stringify({ message: "Sender not found", context: event.requestContext })
    };
  }

  const apiGatewayManagementApiClient = new ApiGatewayManagementApiClient({
    region: REGION,
    endpoint: getConnectionEndpoint(event)
  });

  // Post a message to the recipient's connection.
  const postToConnectionCommand = new PostToConnectionCommand({
    ConnectionId: recipientConnection.ConnectionId?.S,
    Data: Buffer.from(JSON.stringify({
      from: senderConnection.Username?.S,
      message: body.message
    }), "utf-8")
  });

  try {
    await apiGatewayManagementApiClient.send(postToConnectionCommand);
  } catch (error) {
    // For example, a GoneException (HTTP 410) if the recipient's connection is no longer open.
    console.log(error);
  }

  return {
    statusCode: 200,
    body: JSON.stringify({ message: "Message sent." })
  };
};
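The catch block above only logs failures. A common one is a stale connection: if the recipient went away without its row being cleaned up ($disconnect delivery is best-effort), PostToConnection fails with a GoneException (HTTP 410). Here's a hypothetical helper for recognizing that case so the handler could delete the stale record; doing so would also require adding dynamodb:DeleteItem to this function's iamRoleStatements.

```typescript
// Hypothetical helper: recognizes the "connection is gone" failure from
// PostToConnectionCommand. AWS SDK v3 errors carry a `name`, and service
// errors also carry `$metadata.httpStatusCode`.
function isGoneConnectionError(error: unknown): boolean {
  if (typeof error !== "object" || error === null) return false;
  const candidate = error as { name?: unknown; $metadata?: { httpStatusCode?: unknown } };
  return candidate.name === "GoneException" || candidate.$metadata?.httpStatusCode === 410;
}

// Inside the catch block (sketch):
// if (isGoneConnectionError(error)) {
//   // e.g. send a DeleteItemCommand keyed on the recipient's Username
// }
```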

To test this, connect as two different users and, from one of them, send a message with the following shape (note that the recipient field is spelled recepient, matching the handler):

{
  "action": "sendMessage",
  "recepient": "<username>",
  "message": "<string>"
}

Make sure the recipient is connected as well.
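If you're scripting the test, building the payload with JSON.stringify avoids quoting mistakes. Below is a hypothetical client-side helper; the field names match the handler, including the recepient spelling, and the commented lines show assumed browser usage with the standard WebSocket API.

```typescript
// Hypothetical client-side helper that serializes a sendMessage request.
function buildSendMessage(recipient: string, message: string): string {
  return JSON.stringify({ action: "sendMessage", recepient: recipient, message });
}

console.log(buildSendMessage("bob", "Hello, Bob!"));
// {"action":"sendMessage","recepient":"bob","message":"Hello, Bob!"}

// Browser usage (assumes url = wss://<api-id>.execute-api.<region>.amazonaws.com/<stage>/?username=alice):
// const socket = new WebSocket(url);
// socket.addEventListener("open", () => socket.send(buildSendMessage("bob", "Hello, Bob!")));
// socket.addEventListener("message", (event) => console.log(event.data));
```

Bob's client should then receive a message shaped like {"from":"alice","message":"Hello, Bob!"}, which is the object the handler posts.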

For more in-depth reading on WebSocket APIs, visit the official AWS docs here.