diff --git a/cmd/api.go b/cmd/api.go index b054644..cdd24db 100644 --- a/cmd/api.go +++ b/cmd/api.go @@ -87,6 +87,8 @@ func RunApi(cmd *cobra.Command, args []string) { // token balance queries root.GET("/balances/:owner/:type", handlers.GetTokenBalancesByType) + root.GET("/balances/:owner", handlers.GetTokenBalancesByType) + // token holder queries root.GET("/holders/:address", handlers.GetTokenHoldersByType) diff --git a/internal/handlers/token_handlers.go b/internal/handlers/token_handlers.go index 1fe38ac..51c075d 100644 --- a/internal/handlers/token_handlers.go +++ b/internal/handlers/token_handlers.go @@ -2,6 +2,7 @@ package handlers import ( "fmt" + "math/big" "strings" "github.com/gin-gonic/gin" @@ -47,11 +48,13 @@ func GetTokenBalancesByType(c *gin.Context) { api.BadRequestErrorHandler(c, err) return } - tokenType := c.Param("type") - if tokenType != "erc20" && tokenType != "erc1155" && tokenType != "erc721" { - api.BadRequestErrorHandler(c, fmt.Errorf("invalid token type '%s'", tokenType)) + + tokenTypes, err := getTokenTypesFromReq(c) + if err != nil { + api.BadRequestErrorHandler(c, err) return } + owner := strings.ToLower(c.Param("owner")) if !strings.HasPrefix(owner, "0x") { api.BadRequestErrorHandler(c, fmt.Errorf("invalid owner address '%s'", owner)) @@ -62,11 +65,18 @@ func GetTokenBalancesByType(c *gin.Context) { api.BadRequestErrorHandler(c, fmt.Errorf("invalid token address '%s'", tokenAddress)) return } + + tokenIds, err := getTokenIdsFromReq(c) + if err != nil { + api.BadRequestErrorHandler(c, fmt.Errorf("invalid token ids '%s'", err)) + return + } + hideZeroBalances := c.Query("hide_zero_balances") != "false" columns := []string{"address", "sum(balance) as balance"} groupBy := []string{"address"} - if tokenType != "erc20" { + if !strings.Contains(strings.Join(tokenTypes, ","), "erc20") { columns = []string{"address", "token_id", "sum(balance) as balance"} groupBy = []string{"address", "token_id"} } @@ -74,9 +84,10 @@ func GetTokenBalancesByType(c *gin.Context) { qf := storage.BalancesQueryFilter{ ChainId: chainId, Owner: owner, - TokenType: tokenType, + TokenTypes: tokenTypes, TokenAddress: tokenAddress, ZeroBalance: hideZeroBalances, + TokenIds: tokenIds, GroupBy: groupBy, SortBy: c.Query("sort_by"), SortOrder: c.Query("sort_order"), @@ -131,6 +142,43 @@ func serializeBalance(balance common.TokenBalance) BalanceModel { } } +func getTokenTypesFromReq(c *gin.Context) ([]string, error) { + tokenTypeParam := c.Param("type") + var tokenTypes []string + if tokenTypeParam != "" { + tokenTypes = []string{tokenTypeParam} + } else { + tokenTypes = c.QueryArray("token_type") + } + + for i, tokenType := range tokenTypes { + tokenType = strings.ToLower(tokenType) + if tokenType != "erc721" && tokenType != "erc1155" && tokenType != "erc20" { + return []string{}, fmt.Errorf("invalid token type: %s", tokenType) + } + tokenTypes[i] = tokenType + } + return tokenTypes, nil +} + +func getTokenIdsFromReq(c *gin.Context) ([]*big.Int, error) { + tokenIds := c.QueryArray("token_id") + tokenIdsBn := make([]*big.Int, len(tokenIds)) + for i, tokenId := range tokenIds { + tokenId = strings.TrimSpace(tokenId) // Remove potential whitespace + if tokenId == "" { + return nil, fmt.Errorf("invalid token id: %s", tokenId) + } + num := new(big.Int) + _, ok := num.SetString(tokenId, 10) // Base 10 + if !ok { + return nil, fmt.Errorf("invalid token id: %s", tokenId) + } + tokenIdsBn[i] = num + } + return tokenIdsBn, nil +} + // @Summary Get holders of a token // @Description Retrieve holders of a token // @Tags holders @@ -161,25 +209,32 @@ func GetTokenHoldersByType(c *gin.Context) { return } - tokenType := c.Query("token_type") - if tokenType != "" && tokenType != "erc20" && tokenType != "erc1155" && tokenType != "erc721" { - api.BadRequestErrorHandler(c, fmt.Errorf("invalid token type '%s'", tokenType)) + tokenTypes, err := getTokenTypesFromReq(c) + if err != nil { + api.BadRequestErrorHandler(c, err) return } hideZeroBalances := c.Query("hide_zero_balances") != "false" columns := []string{"owner", "sum(balance) as balance"} groupBy := []string{"owner"} - if tokenType != "erc20" { + + if !strings.Contains(strings.Join(tokenTypes, ","), "erc20") { columns = []string{"owner", "token_id", "sum(balance) as balance"} groupBy = []string{"owner", "token_id"} } + tokenIds, err := getTokenIdsFromReq(c) + if err != nil { + api.BadRequestErrorHandler(c, fmt.Errorf("invalid token ids '%s'", err)) + return + } qf := storage.BalancesQueryFilter{ ChainId: chainId, - TokenType: tokenType, + TokenTypes: tokenTypes, TokenAddress: address, ZeroBalance: hideZeroBalances, + TokenIds: tokenIds, GroupBy: groupBy, SortBy: c.Query("sort_by"), SortOrder: c.Query("sort_order"), diff --git a/internal/storage/clickhouse.go b/internal/storage/clickhouse.go index dccc38e..a48f5c6 100644 --- a/internal/storage/clickhouse.go +++ b/internal/storage/clickhouse.go @@ -1399,9 +1399,16 @@ func (c *ClickHouseConnector) GetTokenBalances(qf BalancesQueryFilter, fields .. } query := fmt.Sprintf("SELECT %s FROM %s.token_balances WHERE chain_id = ?", columns, c.cfg.Database) - if qf.TokenType != "" { - query += fmt.Sprintf(" AND token_type = '%s'", qf.TokenType) + if len(qf.TokenTypes) > 0 { + tokenTypesStr := "" + tokenTypesLen := len(qf.TokenTypes) + for i := 0; i < tokenTypesLen-1; i++ { + tokenTypesStr += fmt.Sprintf("'%s',", qf.TokenTypes[i]) + } + tokenTypesStr += fmt.Sprintf("'%s'", qf.TokenTypes[tokenTypesLen-1]) + query += fmt.Sprintf(" AND token_type in (%s)", tokenTypesStr) } + if qf.Owner != "" { query += fmt.Sprintf(" AND owner = '%s'", qf.Owner) } @@ -1409,6 +1416,16 @@ func (c *ClickHouseConnector) GetTokenBalances(qf BalancesQueryFilter, fields .. query += fmt.Sprintf(" AND address = '%s'", qf.TokenAddress) } + if len(qf.TokenIds) > 0 { + tokenIdsStr := "" + tokenIdsLen := len(qf.TokenIds) + for i := 0; i < tokenIdsLen-1; i++ { + tokenIdsStr += fmt.Sprintf("%s,", qf.TokenIds[i].String()) + } + tokenIdsStr += qf.TokenIds[tokenIdsLen-1].String() + query += fmt.Sprintf(" AND token_id in (%s)", tokenIdsStr) + } + isBalanceAggregated := false for _, field := range fields { if strings.Contains(field, "balance") && strings.TrimSpace(field) != "balance" { diff --git a/internal/storage/connector.go b/internal/storage/connector.go index dae4330..9a564ac 100644 --- a/internal/storage/connector.go +++ b/internal/storage/connector.go @@ -27,9 +27,10 @@ type QueryFilter struct { type BalancesQueryFilter struct { ChainId *big.Int - TokenType string + TokenTypes []string TokenAddress string Owner string + TokenIds []*big.Int ZeroBalance bool GroupBy []string SortBy string