diff --git a/cmd/api.go b/cmd/api.go index e2e8c4d..0fb1245 100644 --- a/cmd/api.go +++ b/cmd/api.go @@ -86,6 +86,9 @@ func RunApi(cmd *cobra.Command, args []string) { // token balance queries root.GET("/balances/:owner/:type", handlers.GetTokenBalancesByType) + + // token holder queries + root.GET("/holders/:address", handlers.GetTokenHoldersByType) } r.GET("/health", func(c *gin.Context) { diff --git a/internal/handlers/token_handlers.go b/internal/handlers/token_handlers.go index 3ebe86f..ea735f3 100644 --- a/internal/handlers/token_handlers.go +++ b/internal/handlers/token_handlers.go @@ -19,6 +19,12 @@ type BalanceModel struct { Balance *big.Int `json:"balance" ch:"balance"` } +type HolderModel struct { + HolderAddress string `json:"holder_address" ch:"owner"` + TokenId string `json:"token_id" ch:"token_id"` + Balance *big.Int `json:"balance" ch:"balance"` +} + // @Summary Get token balances of an address by type // @Description Retrieve token balances of an address by type // @Tags balances @@ -125,3 +131,106 @@ func serializeBalance(balance common.TokenBalance) BalanceModel { }(), } } + +// @Summary Get holders of a token +// @Description Retrieve holders of a token +// @Tags holders +// @Accept json +// @Produce json +// @Security BasicAuth +// @Param chainId path string true "Chain ID" +// @Param address path string true "Address of the token" +// @Param token_type path string false "Type of token" +// @Param hide_zero_balances query bool true "Hide zero balances" +// @Param page query int false "Page number for pagination" +// @Param limit query int false "Number of items per page" default(5) +// @Success 200 {object} api.QueryResponse{data=[]LogModel} +// @Failure 400 {object} api.Error +// @Failure 401 {object} api.Error +// @Failure 500 {object} api.Error +// @Router /{chainId}/holders/{address} [get] +func GetTokenHoldersByType(c *gin.Context) { + chainId, err := api.GetChainId(c) + if err != nil { + api.BadRequestErrorHandler(c, err) + return + } + + address := strings.ToLower(c.Param("address")) + if !strings.HasPrefix(address, "0x") { + api.BadRequestErrorHandler(c, fmt.Errorf("invalid address '%s'", address)) + return + } + + tokenType := c.Query("token_type") + if tokenType != "" && tokenType != "erc20" && tokenType != "erc1155" && tokenType != "erc721" { + api.BadRequestErrorHandler(c, fmt.Errorf("invalid token type '%s'", tokenType)) + return + } + hideZeroBalances := c.Query("hide_zero_balances") != "false" + + columns := []string{"owner", "sum(balance) as balance"} + groupBy := []string{"owner"} + if tokenType != "erc20" { + columns = []string{"owner", "token_id", "sum(balance) as balance"} + groupBy = []string{"owner", "token_id"} + } + + qf := storage.BalancesQueryFilter{ + ChainId: chainId, + TokenType: tokenType, + TokenAddress: address, + ZeroBalance: hideZeroBalances, + GroupBy: groupBy, + SortBy: c.Query("sort_by"), + SortOrder: c.Query("sort_order"), + Page: api.ParseIntQueryParam(c.Query("page"), 0), + Limit: api.ParseIntQueryParam(c.Query("limit"), 0), + } + + queryResult := api.QueryResponse{ + Meta: api.Meta{ + ChainId: chainId.Uint64(), + Page: qf.Page, + Limit: qf.Limit, + }, + } + + mainStorage, err = getMainStorage() + if err != nil { + log.Error().Err(err).Msg("Error getting main storage") + api.InternalErrorHandler(c) + return + } + + balancesResult, err := mainStorage.GetTokenBalances(qf, columns...) + if err != nil { + log.Error().Err(err).Msg("Error querying balances") + // TODO: might want to choose BadRequestError if it's due to not-allowed functions + api.InternalErrorHandler(c) + return + } + queryResult.Data = serializeHolders(balancesResult.Data) + sendJSONResponse(c, queryResult) +} + +func serializeHolders(holders []common.TokenBalance) []HolderModel { + holderModels := make([]HolderModel, len(holders)) + for i, holder := range holders { + holderModels[i] = serializeHolder(holder) + } + return holderModels +} + +func serializeHolder(holder common.TokenBalance) HolderModel { + return HolderModel{ + HolderAddress: holder.Owner, + Balance: holder.Balance, + TokenId: func() string { + if holder.TokenId != nil { + return holder.TokenId.String() + } + return "" + }(), + } +} diff --git a/internal/storage/clickhouse.go b/internal/storage/clickhouse.go index 647891b..561ba2c 100644 --- a/internal/storage/clickhouse.go +++ b/internal/storage/clickhouse.go @@ -1378,8 +1378,14 @@ func (c *ClickHouseConnector) GetTokenBalances(qf BalancesQueryFilter, fields .. if len(fields) > 0 { columns = strings.Join(fields, ", ") } - query := fmt.Sprintf("SELECT %s FROM %s.token_balances WHERE chain_id = ? AND token_type = ? AND owner = ?", columns, c.cfg.Database) + query := fmt.Sprintf("SELECT %s FROM %s.token_balances WHERE chain_id = ?", columns, c.cfg.Database) + if qf.TokenType != "" { + query += fmt.Sprintf(" AND token_type = '%s'", qf.TokenType) + } + if qf.Owner != "" { + query += fmt.Sprintf(" AND owner = '%s'", qf.Owner) + } if qf.TokenAddress != "" { query += fmt.Sprintf(" AND address = '%s'", qf.TokenAddress) } @@ -1420,7 +1426,7 @@ func (c *ClickHouseConnector) GetTokenBalances(qf BalancesQueryFilter, fields .. query += fmt.Sprintf(" LIMIT %d", qf.Limit) } - rows, err := c.conn.Query(context.Background(), query, qf.ChainId, qf.TokenType, qf.Owner) + rows, err := c.conn.Query(context.Background(), query, qf.ChainId) if err != nil { return QueryResult[common.TokenBalance]{}, err }