diff --git a/.env b/.env index 24745922420..e15ba740c48 100644 --- a/.env +++ b/.env @@ -34,6 +34,14 @@ SHELLHUB_PROXY=false # Enable automatic HTTPS with Let's Encrypt. SHELLHUB_AUTO_SSL=false +SHELLHUB_DATABASE=mongo + +SHELLHUB_POSTGRES_HOST=postgres +SHELLHUB_POSTGRES_PORT=5432 +SHELLHUB_POSTGRES_USERNAME=admin +SHELLHUB_POSTGRES_PASSWORD=admin +SHELLHUB_POSTGRES_DATABASE=main + # The domain of the server. # NOTICE: Required only if automatic HTTPS is enabled. # VALUES: A valid domain name diff --git a/api/go.mod b/api/go.mod index a20ba1c1665..9ce2d304a15 100644 --- a/api/go.mod +++ b/api/go.mod @@ -8,9 +8,11 @@ require ( github.com/getsentry/sentry-go v0.36.2 github.com/golang-jwt/jwt/v4 v4.5.2 github.com/gorilla/websocket v1.5.3 + github.com/jackc/pgx/v5 v5.7.6 github.com/labstack/echo-contrib v0.17.4 github.com/labstack/echo/v4 v4.13.4 github.com/labstack/gommon v0.4.2 + github.com/oiime/logrusbun v0.1.2-0.20241011112815-4df3a0fb0e11 github.com/pkg/errors v0.9.1 github.com/shellhub-io/mongotest v0.0.0-20230928124937-e33b07010742 github.com/shellhub-io/shellhub v0.13.4 @@ -18,7 +20,11 @@ require ( github.com/spf13/cobra v1.10.1 github.com/square/mongo-lock v0.0.0-20230808145049-cfcf499f6bf0 github.com/stretchr/testify v1.11.1 + github.com/testcontainers/testcontainers-go v0.40.0 github.com/testcontainers/testcontainers-go/modules/mongodb v0.40.0 + github.com/testcontainers/testcontainers-go/modules/postgres v0.40.0 + github.com/uptrace/bun v1.2.15 + github.com/uptrace/bun/dialect/pgdialect v1.2.15 github.com/xakep666/mongo-migrate v0.3.2 go.mongodb.org/mongo-driver v1.17.6 golang.org/x/crypto v0.43.0 @@ -73,6 +79,10 @@ require ( github.com/hashicorp/go-multierror v1.1.1 // indirect github.com/hibiken/asynq v0.24.1 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect + github.com/jackc/pgpassfile v1.0.0 // indirect + github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect + github.com/jackc/puddle/v2 v2.2.2 // indirect + github.com/jinzhu/inflection v1.0.0 // indirect github.com/josharian/intern v1.0.0 // indirect github.com/klauspost/compress v1.18.0 // indirect github.com/klauspost/pgzip v1.2.5 // indirect @@ -109,6 +119,7 @@ require ( github.com/prometheus/client_model v0.6.2 // indirect github.com/prometheus/common v0.63.0 // indirect github.com/prometheus/procfs v0.16.1 // indirect + github.com/puzpuzpuz/xsync/v3 v3.5.1 // indirect github.com/redis/go-redis/v9 v9.0.3 // indirect github.com/robfig/cron/v3 v3.0.1 // indirect github.com/sethvargo/go-envconfig v0.9.0 // indirect @@ -116,17 +127,17 @@ require ( github.com/spf13/cast v1.3.1 // indirect github.com/spf13/pflag v1.0.9 // indirect github.com/stretchr/objx v0.5.2 // indirect - github.com/testcontainers/testcontainers-go v0.40.0 // indirect github.com/therootcompany/xz v1.0.1 // indirect github.com/tklauser/go-sysconf v0.3.13 // indirect github.com/tklauser/numcpus v0.7.0 // indirect github.com/tkuchiki/go-timezone v0.2.2 // indirect github.com/tkuchiki/parsetime v0.3.0 // indirect + github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc // indirect github.com/ulikunitz/xz v0.5.14 // indirect github.com/valyala/bytebufferpool v1.0.0 // indirect github.com/valyala/fasttemplate v1.2.2 // indirect github.com/vmihailenco/go-tinylfu v0.2.2 // indirect - github.com/vmihailenco/msgpack/v5 v5.3.5 // indirect + github.com/vmihailenco/msgpack/v5 v5.4.1 // indirect github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect github.com/woodsbury/decimal128 v1.3.0 // indirect github.com/xdg-go/pbkdf2 v1.0.0 // indirect @@ -142,7 +153,7 @@ require ( go4.org v0.0.0-20200411211856-f5505b9728dd // indirect golang.org/x/net v0.45.0 // indirect golang.org/x/sync v0.17.0 // indirect - golang.org/x/sys v0.37.0 // indirect + golang.org/x/sys v0.38.0 // indirect golang.org/x/text v0.30.0 // indirect golang.org/x/time v0.12.0 // indirect google.golang.org/protobuf v1.36.6 // indirect diff --git a/api/go.sum b/api/go.sum index e2569d83273..4c1ff55291b 100644 --- a/api/go.sum +++ b/api/go.sum @@ -210,6 +210,16 @@ github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpO github.com/ianlancetaylor/demangle v0.0.0-20181102032728-5e5cf60278f6/go.mod h1:aSSvb/t6k1mPoxDqO4vJh6VOCGPwU4O0C2/Eqndh1Sc= github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= +github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= +github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= +github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo= +github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= +github.com/jackc/pgx/v5 v5.7.6 h1:rWQc5FwZSPX58r1OQmkuaNicxdmExaEz5A2DO2hUuTk= +github.com/jackc/pgx/v5 v5.7.6/go.mod h1:aruU7o91Tc2q2cFp5h4uP3f6ztExVpyVv88Xl/8Vl8M= +github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo= +github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= +github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= +github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= github.com/josharian/intern v1.0.0 h1:vlS4z54oSdjm0bgjRigI+G1HpF+tI+9rE5LLzOg8HmY= github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y= github.com/jstemmer/go-junit-report v0.0.0-20190106144839-af01ea7f8024/go.mod h1:6v2b51hI/fHJwM22ozAgKL4VKDeJcHhJFhtBdhmNjmU= @@ -240,6 +250,8 @@ github.com/labstack/gommon v0.4.2 h1:F8qTUNXgG1+6WQmqoUWnz8WiEU60mXVVw0P4ht1WRA0 github.com/labstack/gommon v0.4.2/go.mod h1:QlUFxVM+SNXhDL/Z7YhocGIBYOiwB0mXm1+1bAPHPyU= github.com/leodido/go-urn v1.2.2 h1:7z68G0FCGvDk646jz1AelTYNYWrTNm0bEcFAo147wt4= github.com/leodido/go-urn v1.2.2/go.mod h1:kUaIbLZWttglzwNuG0pgsh5vuV6u2YcGBYz1hIPjtOQ= +github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= +github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= github.com/lufia/plan9stats v0.0.0-20240408141607-282e7b5d6b74 h1:1KuuSOy4ZNgW0KA2oYIngXVFhQcXxhLqCVK7cBcldkk= github.com/lufia/plan9stats v0.0.0-20240408141607-282e7b5d6b74/go.mod h1:ilwx/Dta8jXAgpFYFvSWEMwxmbWXyiUHkd5FwyKhb5k= github.com/magiconair/properties v1.8.10 h1:s31yESBquKXCV9a/ScB3ESkOjUYYv+X0rg8SYxI99mE= @@ -252,6 +264,8 @@ github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWE github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-runewidth v0.0.13/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w= github.com/mattn/goveralls v0.0.9/go.mod h1:FRbM1PS8oVsOe9JtdzAAXM+DsvDMMHcM1C7drGJD8HY= +github.com/mdelapenya/tlscert v0.2.0 h1:7H81W6Z/4weDvZBNOfQte5GpIMo0lGYEeWbkGp5LJHI= +github.com/mdelapenya/tlscert v0.2.0/go.mod h1:O4njj3ELLnJjGdkN7M/vIVCpZ+Cf0L6muqOG4tLSl8o= github.com/mholt/archiver/v4 v4.0.0-alpha.8 h1:tRGQuDVPh66WCOelqe6LIGh0gwmfwxUrSSDunscGsRM= github.com/mholt/archiver/v4 v4.0.0-alpha.8/go.mod h1:5f7FUYGXdJWUjESffJaYR4R60VhnHxb2X3T1teMyv5A= github.com/moby/docker-image-spec v1.3.1 h1:jMKff3w6PgbfSa69GfNg+zN/XLhfXJGnEx3Nl2EsFP0= @@ -288,6 +302,10 @@ github.com/oasdiff/yaml v0.0.0-20250309154309-f31be36b4037 h1:G7ERwszslrBzRxj//J github.com/oasdiff/yaml v0.0.0-20250309154309-f31be36b4037/go.mod h1:2bpvgLBZEtENV5scfDFEtB/5+1M4hkQhDQrccEJ/qGw= github.com/oasdiff/yaml3 v0.0.0-20250309153720-d2182401db90 h1:bQx3WeLcUWy+RletIKwUIt4x3t8n2SxavmoclizMb8c= github.com/oasdiff/yaml3 v0.0.0-20250309153720-d2182401db90/go.mod h1:y5+oSEHCPT/DGrS++Wc/479ERge0zTFxaF8PbGKcg2o= +github.com/oiime/logrusbun v0.1.1 h1:o3aK0PGErb1G0JC43yAIhoGxSbgtYRHhlyTtq6o1rag= +github.com/oiime/logrusbun v0.1.1/go.mod h1:HH9akx9teKgQPX41TYpLLRNxaL8q9R+ltzABnwUHfBM= +github.com/oiime/logrusbun v0.1.2-0.20241011112815-4df3a0fb0e11 h1:rAqW9sGcM0VsfBwgeBzHk0yebrRwfeSJFy9Egqi0fmM= +github.com/oiime/logrusbun v0.1.2-0.20241011112815-4df3a0fb0e11/go.mod h1:HH9akx9teKgQPX41TYpLLRNxaL8q9R+ltzABnwUHfBM= github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= github.com/onsi/ginkgo v1.12.1/go.mod h1:zj2OWP4+oCPe1qIXoGWkgMRwljMUYCdkwsT2108oapk= github.com/onsi/ginkgo v1.16.4/go.mod h1:dX+/inL/fNMqNlz0e9LfyB9TswhZpCVdJM/Z6Vvnwo0= @@ -327,6 +345,8 @@ github.com/prometheus/common v0.63.0 h1:YR/EIY1o3mEFP/kZCD7iDMnLPlGyuU2Gb3HIcXnA github.com/prometheus/common v0.63.0/go.mod h1:VVFF/fBIoToEnWRVkYoXEkq3R3paCoxG9PXP74SnV18= github.com/prometheus/procfs v0.16.1 h1:hZ15bTNuirocR6u0JZ6BAHHmwS1p8B4P6MRqxtzMyRg= github.com/prometheus/procfs v0.16.1/go.mod h1:teAbpZRB1iIAJYREa1LsoWUXykVXA1KlTmWl8x/U+Is= +github.com/puzpuzpuz/xsync/v3 v3.5.1 h1:GJYJZwO6IdxN/IKbneznS6yPkVC+c3zyY/j19c++5Fg= +github.com/puzpuzpuz/xsync/v3 v3.5.1/go.mod h1:VjzYrABPabuM4KyBh1Ftq6u8nhwY5tBPKP9jpmh0nnA= github.com/redis/go-redis/v9 v9.0.3 h1:+7mmR26M0IvyLxGZUHxu4GiBkJkVDid0Un+j4ScYu4k= github.com/redis/go-redis/v9 v9.0.3/go.mod h1:WqMKv5vnQbRuZstUwxQI195wHy+t4PuXDOjzMvcuQHk= github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= @@ -345,6 +365,7 @@ github.com/shellhub-io/mongotest v0.0.0-20230928124937-e33b07010742 h1:sIFW1zdZv github.com/shellhub-io/mongotest v0.0.0-20230928124937-e33b07010742/go.mod h1:6J6yfW5oIvAZ6VjxmV9KyFZyPFVM3B4V3Epbb+1c0oo= github.com/shirou/gopsutil/v4 v4.25.6 h1:kLysI2JsKorfaFPcYmcJqbzROzsBWEOAtw6A7dIfqXs= github.com/shirou/gopsutil/v4 v4.25.6/go.mod h1:PfybzyydfZcN+JMMjkF6Zb8Mq1A/VcogFFg7hj50W9c= +github.com/sirupsen/logrus v1.8.1/go.mod h1:yWOB1SBYBC5VeMP7gHvWumXLIWorT60ONWic61uBYv0= github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ= github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= github.com/spf13/cast v1.3.1 h1:nFm6S0SMdyzrzcmThSipiEubIDy8WEXKNZ0UOgiRpng= @@ -361,6 +382,7 @@ github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpE github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY= github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= @@ -374,6 +396,8 @@ github.com/testcontainers/testcontainers-go v0.40.0 h1:pSdJYLOVgLE8YdUY2FHQ1Fxu+ github.com/testcontainers/testcontainers-go v0.40.0/go.mod h1:FSXV5KQtX2HAMlm7U3APNyLkkap35zNLxukw9oBi/MY= github.com/testcontainers/testcontainers-go/modules/mongodb v0.40.0 h1:z/1qHeliTLDKNaJ7uOHOx1FjwghbcbYfga4dTFkF0hU= github.com/testcontainers/testcontainers-go/modules/mongodb v0.40.0/go.mod h1:GaunAWwMXLtsMKG3xn2HYIBDbKddGArfcGsF2Aog81E= +github.com/testcontainers/testcontainers-go/modules/postgres v0.40.0 h1:s2bIayFXlbDFexo96y+htn7FzuhpXLYJNnIuglNKqOk= +github.com/testcontainers/testcontainers-go/modules/postgres v0.40.0/go.mod h1:h+u/2KoREGTnTl9UwrQ/g+XhasAT8E6dClclAADeXoQ= github.com/testcontainers/testcontainers-go/modules/redis v0.32.0 h1:HW5Qo9qfLi5iwfS7cbXwG6qe8ybXGePcgGPEmVlVDlo= github.com/testcontainers/testcontainers-go/modules/redis v0.32.0/go.mod h1:5kltdxVKZG0aP1iegeqKz4K8HHyP0wbkW5o84qLyMjY= github.com/therootcompany/xz v1.0.1 h1:CmOtsn1CbtmyYiusbfmhmkpAAETj0wBIH6kCYaX+xzw= @@ -387,11 +411,18 @@ github.com/tkuchiki/go-timezone v0.2.2 h1:MdHR65KwgVTwWFQrota4SKzc4L5EfuH5SdZZGt github.com/tkuchiki/go-timezone v0.2.2/go.mod h1:oFweWxYl35C/s7HMVZXiA19Jr9Y0qJHMaG/J2TES4LY= github.com/tkuchiki/parsetime v0.3.0 h1:cvblFQlPeAPJL8g6MgIGCHnnmHSZvluuY+hexoZCNqc= github.com/tkuchiki/parsetime v0.3.0/go.mod h1:OJkQmIrf5Ao7R+WYIdITPOfDVj8LmnHGCfQ8DTs3LCA= +github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc h1:9lRDQMhESg+zvGYmW5DyG0UqvY96Bu5QYsTLvCHdrgo= +github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc/go.mod h1:bciPuU6GHm1iF1pBvUfxfsH0Wmnc2VbpgvbI9ZWuIRs= github.com/ugorji/go/codec v1.2.7 h1:YPXUKf7fYbp/y8xloBqZOw2qaVggbfwMlI8WM3wZUJ0= github.com/ugorji/go/codec v1.2.7/go.mod h1:WGN1fab3R1fzQlVQTkfxVtIBhWDRqOviHU95kRgeqEY= github.com/ulikunitz/xz v0.5.8/go.mod h1:nbz6k7qbPmH4IRqmfOplQw/tblSgqTqBwxkY0oWt/14= github.com/ulikunitz/xz v0.5.14 h1:uv/0Bq533iFdnMHZdRBTOlaNMdb1+ZxXIlHDZHIHcvg= github.com/ulikunitz/xz v0.5.14/go.mod h1:nbz6k7qbPmH4IRqmfOplQw/tblSgqTqBwxkY0oWt/14= +github.com/uptrace/bun v0.3.9/go.mod h1:aL6D9vPw8DXaTQTwGrEPtUderBYXx7ShUmPfnxnqscw= +github.com/uptrace/bun v1.2.15 h1:Ut68XRBLDgp9qG9QBMa9ELWaZOmzHNdczHQdrOZbEFE= +github.com/uptrace/bun v1.2.15/go.mod h1:Eghz7NonZMiTX/Z6oKYytJ0oaMEJ/eq3kEV4vSqG038= +github.com/uptrace/bun/dialect/pgdialect v1.2.15 h1:er+/3giAIqpfrXJw+KP9B7ujyQIi5XkPnFmgjAVL6bA= +github.com/uptrace/bun/dialect/pgdialect v1.2.15/go.mod h1:QSiz6Qpy9wlGFsfpf7UMSL6mXAL1jDJhFwuOVacCnOQ= github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw= github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= github.com/valyala/fasttemplate v1.2.2 h1:lxLXG0uE3Qnshl9QyaK6XJxMXlQZELvChBOCmQD0Loo= @@ -399,8 +430,8 @@ github.com/valyala/fasttemplate v1.2.2/go.mod h1:KHLXt3tVN2HBp8eijSv/kGJopbvo7S+ github.com/vmihailenco/go-tinylfu v0.2.2 h1:H1eiG6HM36iniK6+21n9LLpzx1G9R3DJa2UjUjbynsI= github.com/vmihailenco/go-tinylfu v0.2.2/go.mod h1:CutYi2Q9puTxfcolkliPq4npPuofg9N9t8JVrjzwa3Q= github.com/vmihailenco/msgpack/v5 v5.3.4/go.mod h1:7xyJ9e+0+9SaZT0Wt1RGleJXzli6Q/V5KbhBonMG9jc= -github.com/vmihailenco/msgpack/v5 v5.3.5 h1:5gO0H1iULLWGhs2H5tbAHIZTV8/cYafcFOr9znI5mJU= -github.com/vmihailenco/msgpack/v5 v5.3.5/go.mod h1:7xyJ9e+0+9SaZT0Wt1RGleJXzli6Q/V5KbhBonMG9jc= +github.com/vmihailenco/msgpack/v5 v5.4.1 h1:cQriyiUvjTwOHg8QZaPihLWeRAAVoCpE00IUPn0Bjt8= +github.com/vmihailenco/msgpack/v5 v5.4.1/go.mod h1:GaZTsDaehaPpQVyxrf5mtQlH+pc21PIudVV/E3rRQok= github.com/vmihailenco/tagparser/v2 v2.0.0 h1:y09buUbR+b5aycVFQs/g70pqKVZNBmxwAhO7/IwNM9g= github.com/vmihailenco/tagparser/v2 v2.0.0/go.mod h1:Wri+At7QHww0WTrCBeu4J6bNtoV6mEfg5OIWRZA9qds= github.com/woodsbury/decimal128 v1.3.0 h1:8pffMNWIlC0O5vbyHWFZAt5yWvWcrHA+3ovIIjVWss0= @@ -572,6 +603,8 @@ golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.37.0 h1:fdNQudmxPjkdUTPnLn5mdQv7Zwvbvpaxqs831goi9kQ= golang.org/x/sys v0.37.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= +golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc= +golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= diff --git a/api/server.go b/api/server.go index 193184551db..4a11ef7cc27 100644 --- a/api/server.go +++ b/api/server.go @@ -2,6 +2,7 @@ package main import ( "context" + "errors" "os" "strings" @@ -10,8 +11,11 @@ import ( "github.com/shellhub-io/shellhub/api/routes" "github.com/shellhub-io/shellhub/api/routes/middleware" "github.com/shellhub-io/shellhub/api/services" + "github.com/shellhub-io/shellhub/api/store" "github.com/shellhub-io/shellhub/api/store/mongo" - "github.com/shellhub-io/shellhub/api/store/mongo/options" + mongooptions "github.com/shellhub-io/shellhub/api/store/mongo/options" + "github.com/shellhub-io/shellhub/api/store/pg" + pgoptions "github.com/shellhub-io/shellhub/api/store/pg/options" "github.com/shellhub-io/shellhub/pkg/api/internalclient" "github.com/shellhub-io/shellhub/pkg/cache" "github.com/shellhub-io/shellhub/pkg/envs" @@ -22,6 +26,19 @@ import ( ) type env struct { + Database string `env:"DATABASE,default=mongo"` + + // PostgresHost specifies the host for PostgreSQL. + PostgresHost string `env:"POSTGRES_HOST,default=postgres"` + // PostgresPort specifies the port for PostgreSQL. + PostgresPort string `env:"POSTGRES_PORT,default=5432"` + // PostgresUsername specifies the username for authenticate PostgreSQL. + PostgresUsername string `env:"POSTGRES_USERNAME,default=admin"` + // PostgresUser specifies the password for authenticate PostgreSQL. + PostgresPassword string `env:"POSTGRES_PASSWORD,default=admin"` + // PostgresDatabase especifica o nome do banco de dados PostgreSQL a ser utilizado. + PostgresDatabase string `env:"POSTGRES_DATABASE,default=main"` + // MongoURI specifies the connection string for MongoDB. MongoURI string `env:"MONGO_URI,default=mongodb://mongo:27017/main"` @@ -78,7 +95,18 @@ func (s *Server) Setup(ctx context.Context) error { log.Debug("Redis cache initialized successfully") - store, err := mongo.NewStore(ctx, s.env.MongoURI, cache, options.RunMigatrions) + var store store.Store + switch s.env.Database { + case "mongo": + store, err = mongo.NewStore(ctx, s.env.MongoURI, cache, mongooptions.RunMigatrions) + case "postgres": + uri := pg.URI(s.env.PostgresHost, s.env.PostgresPort, s.env.PostgresUsername, s.env.PostgresPassword, s.env.PostgresDatabase) + store, err = pg.New(ctx, uri, pgoptions.Log("INFO", true), pgoptions.Migrate()) // TODO: Log envs + default: + log.WithField("database", s.env.Database).Error("invalid database") + + return errors.New("invalid database") + } if err != nil { log. WithError(err). diff --git a/api/store/errors.go b/api/store/errors.go index a2c82597277..dfcfa8d79ed 100644 --- a/api/store/errors.go +++ b/api/store/errors.go @@ -13,9 +13,10 @@ const ( ) var ( - ErrDuplicate = errors.New("document duplicate", ErrLayer, ErrCodeDuplicated) - ErrNoDocuments = errors.New("no documents", ErrLayer, ErrCodeNoDocument) - ErrInvalidHex = errors.New("the provided hex string is not a valid ObjectID", ErrLayer, ErrCodeInvalid) + ErrDuplicate = errors.New("document duplicate", ErrLayer, ErrCodeDuplicated) + ErrNoDocuments = errors.New("no documents", ErrLayer, ErrCodeNoDocument) + ErrInvalidHex = errors.New("the provided hex string is not a valid ObjectID", ErrLayer, ErrCodeInvalid) + ErrResolverNotFound = errors.New("resolver not found", ErrLayer, ErrCodeInvalid) ) // Errors used by Cloud. diff --git a/api/store/pg/api-key.go b/api/store/pg/api-key.go new file mode 100644 index 00000000000..d5d17dda422 --- /dev/null +++ b/api/store/pg/api-key.go @@ -0,0 +1,115 @@ +package pg + +import ( + "context" + + "github.com/shellhub-io/shellhub/api/store" + "github.com/shellhub-io/shellhub/api/store/pg/entity" + "github.com/shellhub-io/shellhub/pkg/clock" + "github.com/shellhub-io/shellhub/pkg/models" + "github.com/uptrace/bun" +) + +func (pg *Pg) APIKeyCreate(ctx context.Context, apiKey *models.APIKey) (string, error) { + db := pg.getConnection(ctx) + + apiKey.CreatedAt = clock.Now() + apiKey.UpdatedAt = clock.Now() + if _, err := db.NewInsert().Model(entity.APIKeyFromModel(apiKey)).Exec(ctx); err != nil { + return "", fromSQLError(err) + } + + return apiKey.ID, nil +} + +func (pg *Pg) APIKeyConflicts(ctx context.Context, tenantID string, target *models.APIKeyConflicts) ([]string, bool, error) { + db := pg.getConnection(ctx) + + apiKeys := make([]map[string]any, 0) + if err := db.NewSelect().Model((*entity.Namespace)(nil)).Column("name").Where("name = ?", target.Name).Scan(ctx, &apiKeys); err != nil { + return nil, false, fromSQLError(err) + } + + conflicts := make([]string, 0) + for _, apiKey := range apiKeys { + if apiKey["name"] == target.Name { + conflicts = append(conflicts, "name") + } + } + + return conflicts, len(conflicts) > 0, nil +} + +func (pg *Pg) APIKeyList(ctx context.Context, opts ...store.QueryOption) ([]models.APIKey, int, error) { + db := pg.getConnection(ctx) + + entities := make([]entity.APIKey, 0) + + query := db.NewSelect().Model(&entities) + if err := applyOptions(ctx, query, opts...); err != nil { + return nil, 0, fromSQLError(err) + } + + count, err := query.ScanAndCount(ctx) + if err != nil { + return nil, 0, fromSQLError(err) + } + + apiKeys := make([]models.APIKey, len(entities)) + for i, e := range entities { + apiKeys[i] = *entity.APIKeyToModel(&e) + } + + return apiKeys, count, nil +} + +func (pg *Pg) APIKeyResolve(ctx context.Context, resolver store.APIKeyResolver, val string, opts ...store.QueryOption) (*models.APIKey, error) { + db := pg.getConnection(ctx) + + column, err := APIKeyResolverToString(resolver) + if err != nil { + return nil, err + } + + apKey := new(entity.APIKey) + query := db.NewSelect().Model(apKey).Where("? = ?", bun.Ident(column), val) + if err := applyOptions(ctx, query, opts...); err != nil { + return nil, fromSQLError(err) + } + + if err := query.Scan(ctx); err != nil { + return nil, fromSQLError(err) + } + + return entity.APIKeyToModel(apKey), nil +} + +func (pg *Pg) APIKeyUpdate(ctx context.Context, apiKey *models.APIKey) error { + db := pg.getConnection(ctx) + + a := entity.APIKeyFromModel(apiKey) + a.UpdatedAt = clock.Now() + _, err := db.NewUpdate().Model(a).WherePK().Exec(ctx) + + return fromSQLError(err) +} + +func (pg *Pg) APIKeyDelete(ctx context.Context, apiKey *models.APIKey) error { + db := pg.getConnection(ctx) + + a := entity.APIKeyFromModel(apiKey) + _, err := db.NewDelete().Model(a).WherePK().Exec(ctx) + + return fromSQLError(err) +} + +func APIKeyResolverToString(resolver store.APIKeyResolver) (string, error) { + switch resolver { + case store.APIKeyIDResolver: + return "id", nil + case store.APIKeyNameResolver: + return "name", nil + default: + return "", store.ErrResolverNotFound + } +} diff --git a/api/store/pg/api-key_test.go b/api/store/pg/api-key_test.go new file mode 100644 index 00000000000..e1e51799ddd --- /dev/null +++ b/api/store/pg/api-key_test.go @@ -0,0 +1 @@ +package pg_test diff --git a/api/store/pg/dbtest/dbtest.go b/api/store/pg/dbtest/dbtest.go new file mode 100644 index 00000000000..4afb5a57035 --- /dev/null +++ b/api/store/pg/dbtest/dbtest.go @@ -0,0 +1,53 @@ +package dbtest + +import ( + "context" + "time" + + "github.com/testcontainers/testcontainers-go" + "github.com/testcontainers/testcontainers-go/modules/postgres" + "github.com/testcontainers/testcontainers-go/wait" +) + +// Server represents a Postgres test server instance. +type Server struct { + container *postgres.PostgresContainer +} + +// Up starts a new Postgres container. Use [Server.ConnectionString] to access the connection string. +func (srv *Server) Up(ctx context.Context) error { + opts := []testcontainers.ContainerCustomizer{ + postgres.WithDatabase("test"), + postgres.WithUsername("postgres"), + postgres.WithPassword("postgres"), + testcontainers.WithWaitStrategy(wait.ForLog("database system is ready to accept connections").WithOccurrence(2).WithStartupTimeout(60 * time.Second)), + } + + container, err := postgres.Run(ctx, "postgres:18.0", opts...) + if err != nil { + return err + } + + srv.container = container + + return nil +} + +// Down gracefully terminates the Postgres container. +func (srv *Server) Down(ctx context.Context) error { + return srv.container.Terminate(ctx) +} + +func (srv *Server) ConnectionString(ctx context.Context) (string, error) { + host, err := srv.container.Host(ctx) + if err != nil { + return "", err + } + + port, err := srv.container.MappedPort(ctx, "5432") + if err != nil { + return "", err + } + + return "postgres://postgres:postgres@" + host + ":" + port.Port() + "/test?sslmode=disable", nil +} diff --git a/api/store/pg/dbtest/fixtures.go b/api/store/pg/dbtest/fixtures.go new file mode 100644 index 00000000000..801d04b2141 --- /dev/null +++ b/api/store/pg/dbtest/fixtures.go @@ -0,0 +1,12 @@ +package dbtest + +import ( + "path/filepath" + "runtime" +) + +func FixturesPath() string { + _, file, _, _ := runtime.Caller(0) + + return filepath.Join(filepath.Dir(file), "fixtures") +} diff --git a/api/store/pg/dbtest/fixtures/api-keys.yml b/api/store/pg/dbtest/fixtures/api-keys.yml new file mode 100644 index 00000000000..7b1ee31ca85 --- /dev/null +++ b/api/store/pg/dbtest/fixtures/api-keys.yml @@ -0,0 +1,18 @@ +- model: APIKey + rows: + - id: f23a2e56cd3fcfba002c72675c870e1e7813292adc40bbf14cea479a2e07976a + name: dev + created_by: 507f1f77bcf86cd799439011 + tenant_id: 00000000-0000-4000-0000-000000000000 + role: admin + created_at: '2023-01-01T12:00:00.000Z' + updated_at: '2023-01-01T12:00:00.000Z' + expires_in: 0 + - id: a1b2c73ea41f70870c035283336d72228118213ed03ec78043ffee48d827af11 + name: prod + created_by: 507f1f77bcf86cd799439011 + tenant_id: 00000000-0000-4000-0000-000000000000 + role: operator + created_at: '2023-01-02T12:00:00.000Z' + updated_at: '2023-01-02T12:00:00.000Z' + expires_in: 10 diff --git a/api/store/pg/dbtest/fixtures/users.yml b/api/store/pg/dbtest/fixtures/users.yml new file mode 100644 index 00000000000..c69ff3bc5b6 --- /dev/null +++ b/api/store/pg/dbtest/fixtures/users.yml @@ -0,0 +1,18 @@ +- model: User + rows: + - id: 0195cefa-aa01-7efb-8098-c9c173056250 + created_at: 2025-01-15T10:30:00+00:00 + updated_at: 2025-01-15T10:30:00+00:00 + last_login: null + status: confirmed + origin: local + external_id: "" + name: Jonh Doe + username: john_doe + email: john.doe@test.com + security_email: jane.smith@test.com + password_digest: "$2y$12$VVm2ETx7AvaGlfMYqNYK9uzU2M45YZ70YnT..O.s1o2zdE1pekhq6" + auth_methods: [ local ] + namespace_ownership_limit: -1 + email_marketing: true + preferred_namespace_id: null diff --git a/api/store/pg/device.go b/api/store/pg/device.go new file mode 100644 index 00000000000..4c3b60eaa5d --- /dev/null +++ b/api/store/pg/device.go @@ -0,0 +1,238 @@ +package pg + +import ( + "context" //nolint:gosec + "time" + + "github.com/shellhub-io/shellhub/api/store" + "github.com/shellhub-io/shellhub/api/store/pg/entity" + "github.com/shellhub-io/shellhub/pkg/clock" + "github.com/shellhub-io/shellhub/pkg/models" + "github.com/uptrace/bun" + "github.com/uptrace/bun/dialect/pgdialect" +) + +func (pg *Pg) DeviceCreate(ctx context.Context, device *models.Device) (string, error) { + db := pg.getConnection(ctx) + + device.CreatedAt = clock.Now() + + e := entity.DeviceFromModel(device) + if _, err := db.NewInsert().Model(e).Exec(ctx); err != nil { + return "", fromSQLError(err) + } + + return e.ID, nil +} + +func (pg *Pg) DeviceConflicts(ctx context.Context, target *models.DeviceConflicts) ([]string, bool, error) { + db := pg.getConnection(ctx) + + devices := make([]map[string]any, 0) + if err := db.NewSelect().Model((*entity.Device)(nil)).Column("name").Where("name = ?", target.Name).Scan(ctx, &devices); err != nil { + return nil, false, fromSQLError(err) + } + + conflicts := make([]string, 0) + for _, device := range devices { + if device["name"] == target.Name { + conflicts = append(conflicts, "name") + } + } + + return conflicts, len(conflicts) > 0, nil +} + +func (pg *Pg) DeviceList(ctx context.Context, acceptable store.DeviceAcceptable, opts ...store.QueryOption) ([]models.Device, int, error) { + db := pg.getConnection(ctx) + + entities := make([]entity.Device, 0) + + query := db. + NewSelect(). + Model(&entities). + Column("device.*"). + Relation("Namespace"). + ColumnExpr(string(deviceExprOnline), time.Now().Add(-2*time.Minute)). + ColumnExpr(deviceExprAcepptable(acceptable)) + + if err := applyOptions(ctx, query, opts...); err != nil { + return nil, 0, fromSQLError(err) + } + + count, err := query.ScanAndCount(ctx) + if err != nil { + return nil, 0, fromSQLError(err) + } + + devices := make([]models.Device, len(entities)) + for i, e := range entities { + devices[i] = *entity.DeviceToModel(&e) + } + + return devices, count, nil +} + +func (pg *Pg) DeviceResolve(ctx context.Context, resolver store.DeviceResolver, val string, opts ...store.QueryOption) (*models.Device, error) { + db := pg.getConnection(ctx) + + column, err := DeviceResolverToString(resolver) + if err != nil { + return nil, err + } + + d := new(entity.Device) + + query := db. + NewSelect(). + Model(d). + Where("? = ?", bun.Ident("device."+column), val). + Column("device.*"). + Relation("Namespace"). + ColumnExpr(string(deviceExprOnline), time.Now().Add(-2*time.Minute)) + + if err := query.Scan(ctx); err != nil { + return nil, fromSQLError(err) + } + + return entity.DeviceToModel(d), nil +} + +func (pg *Pg) DeviceUpdate(ctx context.Context, device *models.Device) error { + db := pg.getConnection(ctx) + + d := entity.DeviceFromModel(device) + d.UpdatedAt = clock.Now() + _, err := db.NewUpdate().Model(d).WherePK().Exec(ctx) + + return fromSQLError(err) +} + +func (pg *Pg) DeviceHeartbeat(ctx context.Context, ids []string, lastSeen time.Time) (int64, error) { + db := pg.getConnection(ctx) + + r, err := db.NewUpdate(). + Model((*entity.Device)(nil)). + Set("seen_at = ?", lastSeen). + Set("disconnected_at = NULL"). + TableExpr("(SELECT unnest(?::varchar[]) as id) as _data", pgdialect.Array(ids)). + Where("device.id = _data.id"). + Exec(ctx) + if err != nil { + return 0, fromSQLError(err) + } + + return r.RowsAffected() +} + +func (pg *Pg) DeviceDelete(ctx context.Context, device *models.Device) error { + deletedCount, err := pg.DeviceDeleteMany(ctx, []string{device.UID}) + switch { + case err != nil: + return err + case deletedCount < 1: + return store.ErrNoDocuments + default: + return nil + } +} + +func (pg *Pg) DeviceDeleteMany(ctx context.Context, uids []string) (int64, error) { + db := pg.getConnection(ctx) + fn := pg.deviceDeleteManyFn(ctx, uids) + + if tx, ok := db.(bun.Tx); ok { + return fn(tx) + } else { // nolint:revive + tx, err := pg.driver.BeginTx(ctx, nil) + if err != nil { + return 0, fromSQLError(err) + } + + defer func() { + if p := recover(); p != nil { + _ = tx.Rollback() + panic(p) + } + }() + + count, err := fn(tx) + if err != nil { + _ = tx.Rollback() + + return 0, err + } + + if err := tx.Commit(); err != nil { + return 0, fromSQLError(err) + } + + return count, nil + } +} + +func (pg *Pg) deviceDeleteManyFn(ctx context.Context, uids []string) func(tx bun.Tx) (int64, error) { + return func(tx bun.Tx) (int64, error) { + r, err := tx.NewDelete().Model((*entity.Device)(nil)).Where("id IN (?)", bun.In(uids)).Exec(ctx) + if err != nil { + return 0, fromSQLError(err) + } + + count, _ := r.RowsAffected() + + // if _, err := tx.NewDelete(). + // Model((*entity.Session)(nil)). + // Where("device_uid IN (?)", bun.In(uids)). + // Exec(ctx); err != nil { + // return 0, fromSQLError(err) + // } + // + // if _, err := tx.NewDelete(). + // Model((*entity.Tunnel)(nil)). + // Where("device IN (?)", bun.In(uids)). + // Exec(ctx); err != nil { + // return 0, fromSQLError(err) + // } + + return count, nil + } +} + +type deviceExpr string + +const ( + deviceExprOnline deviceExpr = ` + CASE + WHEN "device"."disconnected_at" IS NULL AND "device"."seen_at" > ? + THEN true + ELSE false + END AS "online"` +) + +// deviceExprAcepptable generates the SQL expression for the "acceptable" field +// based on the provided store.DeviceAcceptable mode. +func deviceExprAcepptable(mode store.DeviceAcceptable) string { + switch mode { + case store.DeviceAcceptableFromRemoved: + return `"device"."status" = 'removed' AS "acceptable"` + case store.DeviceAcceptableAsFalse: + return `false AS "acceptable"` + case store.DeviceAcceptableIfNotAccepted: + return `CASE WHEN "device"."status" <> 'accepted' THEN true ELSE false END AS "acceptable"` + default: + return `true AS "acceptable"` + } +} + +func DeviceResolverToString(resolver store.DeviceResolver) (string, error) { + switch resolver { + case store.DeviceUIDResolver: + return "id", nil + case store.DeviceHostnameResolver: + return "name", nil + case store.DeviceMACResolver: + return "mac", nil + default: + return "", store.ErrResolverNotFound + } +} diff --git a/api/store/pg/device_test.go b/api/store/pg/device_test.go new file mode 100644 index 00000000000..e1e51799ddd --- /dev/null +++ b/api/store/pg/device_test.go @@ -0,0 +1 @@ +package pg_test diff --git a/api/store/pg/entity/api-key.go b/api/store/pg/entity/api-key.go new file mode 100644 index 00000000000..5ff3f1677bf --- /dev/null +++ b/api/store/pg/entity/api-key.go @@ -0,0 +1,48 @@ +package entity + +import ( + "time" + + "github.com/shellhub-io/shellhub/pkg/api/authorizer" + "github.com/shellhub-io/shellhub/pkg/models" + "github.com/uptrace/bun" +) + +type APIKey struct { + bun.BaseModel `bun:"table:api_keys"` + + KeyDigest string `bun:"key_digest,pk"` + NamespaceID string `bun:"namespace_id,pk"` + Name string `bun:"name"` + Role string `bun:"role"` + UserID string `bun:"user_id"` + CreatedAt time.Time `bun:"created_at"` + UpdatedAt time.Time `bun:"updated_at"` + ExpiresIn int64 `bun:"expires_in,nullzero"` +} + +func APIKeyFromModel(model *models.APIKey) *APIKey { + return &APIKey{ + Name: model.Name, + NamespaceID: model.TenantID, + KeyDigest: model.ID, + Role: model.Role.String(), + UserID: model.CreatedBy, + CreatedAt: model.CreatedAt, + UpdatedAt: model.UpdatedAt, + ExpiresIn: model.ExpiresIn, + } +} + +func APIKeyToModel(entity *APIKey) *models.APIKey { + return &models.APIKey{ + ID: entity.KeyDigest, + Name: entity.Name, + TenantID: entity.NamespaceID, + Role: authorizer.Role(entity.Role), + CreatedBy: entity.UserID, + CreatedAt: entity.CreatedAt, + UpdatedAt: entity.UpdatedAt, + ExpiresIn: entity.ExpiresIn, + } +} diff --git a/api/store/pg/entity/device.go b/api/store/pg/entity/device.go new file mode 100644 index 00000000000..02620987751 --- /dev/null +++ b/api/store/pg/entity/device.go @@ -0,0 +1,128 @@ +package entity + +import ( + "time" + + "github.com/shellhub-io/shellhub/pkg/models" + "github.com/uptrace/bun" +) + +type Device struct { + bun.BaseModel `bun:"table:devices"` + + ID string `bun:"id,pk"` + NamespaceID string `bun:"namespace_id,pk,type:uuid"` + CreatedAt time.Time `bun:"created_at"` + UpdatedAt time.Time `bun:"updated_at"` + RemovedAt *time.Time `bun:"removed_at"` + SeenAt time.Time `bun:"seen_at"` + DisconnectedAt time.Time `bun:"disconnected_at,nullzero"` + Online bool `bun:",scanonly"` + Acceptable bool `bun:",scanonly"` + Status string `bun:"status"` + Name string `bun:"name"` + MAC string `bun:"mac"` + PublicKey string `bun:"public_key"` + Identifier string `bun:"identifier"` + PrettyName string `bun:"pretty_name"` + Version string `bun:"version"` + Arch string `bun:"arch"` + Platform string `bun:"platform"` + Longitude float64 `bun:"longitude,type:numeric"` + Latitude float64 `bun:"latitude,type:numeric"` + + Namespace *Namespace `bun:"rel:belongs-to,join:namespace_id=id"` + Tags []*Tag `bun:"m2m:device_tags,join:Device=Tag"` +} + +func DeviceFromModel(model *models.Device) *Device { + device := &Device{ + ID: model.UID, + NamespaceID: model.TenantID, + CreatedAt: model.CreatedAt, + UpdatedAt: time.Time{}, + SeenAt: model.LastSeen, + Status: string(model.Status), + Name: model.Name, + PublicKey: model.PublicKey, + Tags: []*Tag{}, + } + + if model.DisconnectedAt != nil { + device.DisconnectedAt = *model.DisconnectedAt + } + + if model.Identity != nil { + device.MAC = model.Identity.MAC + } + + if model.Position != nil { + device.Longitude = model.Position.Longitude + device.Latitude = model.Position.Latitude + } + + if model.Info != nil { + device.Identifier = model.Info.ID + device.PrettyName = model.Info.PrettyName + device.Version = model.Info.Version + device.Arch = model.Info.Arch + device.Platform = model.Info.Platform + } + + if len(model.Tags) > 0 { + device.Tags = make([]*Tag, len(model.Tags)) + for i, t := range model.Tags { + device.Tags[i] = TagFromModel(&t) + } + } + + return device +} + +func DeviceToModel(entity *Device) *models.Device { + device := &models.Device{ + UID: entity.ID, + TenantID: entity.NamespaceID, + CreatedAt: entity.CreatedAt, + LastSeen: entity.SeenAt, + Status: models.DeviceStatus(entity.Status), + Name: entity.Name, + PublicKey: entity.PublicKey, + Online: entity.Online, + Acceptable: entity.Acceptable, + Namespace: entity.Namespace.Name, + DisconnectedAt: nil, + RemoteAddr: "", + Taggable: models.Taggable{ + Tags: []models.Tag{}, + }, + Position: &models.DevicePosition{ + Longitude: entity.Longitude, + Latitude: entity.Latitude, + }, + Info: &models.DeviceInfo{ + ID: entity.Identifier, + PrettyName: entity.PrettyName, + Version: entity.Version, + Arch: entity.Arch, + Platform: entity.Platform, + }, + Identity: &models.DeviceIdentity{ + MAC: entity.MAC, + }, + } + + if !entity.DisconnectedAt.IsZero() { + disconnectedAt := entity.DisconnectedAt + device.DisconnectedAt = &disconnectedAt + } + + if len(entity.Tags) > 0 { + device.Tags = make([]models.Tag, len(entity.Tags)) + for i, t := range entity.Tags { + device.Tags[i] = *TagToModel(t) + } + } + + return device +} diff --git a/api/store/pg/entity/entity.go b/api/store/pg/entity/entity.go new file mode 100644 index 00000000000..ce625a48ff8 --- /dev/null +++ b/api/store/pg/entity/entity.go @@ -0,0 +1,22 @@ +package entity + +func Entities() []any { + return []any{ + // Register intermediary models first for many-to-many relationships + (*DeviceTag)(nil), + (*PublicKeyTag)(nil), + + (*APIKey)(nil), + (*Device)(nil), + (*Membership)(nil), + (*Namespace)(nil), + (*PrivateKey)(nil), + (*PublicKey)(nil), + (*Session)(nil), + (*ActiveSession)(nil), + (*SessionEvent)(nil), + (*System)(nil), + (*Tag)(nil), + (*User)(nil), + } +} diff --git a/api/store/pg/entity/membership.go b/api/store/pg/entity/membership.go new file mode 100644 index 00000000000..48f8262b4cf --- /dev/null +++ b/api/store/pg/entity/membership.go @@ -0,0 +1,44 @@ +package entity + +import ( + "time" + + "github.com/shellhub-io/shellhub/pkg/api/authorizer" + "github.com/shellhub-io/shellhub/pkg/models" + "github.com/uptrace/bun" +) + +type Membership struct { + bun.BaseModel `bun:"table:memberships"` + + UserID string `bun:"user_id,pk,type:uuid"` + NamespaceID string `bun:"namespace_id,pk,type:uuid"` + CreatedAt time.Time `bun:"created_at"` + UpdatedAt time.Time `bun:"updated_at"` + Status string `bun:"status"` + Role string `bun:"role"` + + User *User `bun:"rel:belongs-to,join:user_id=id"` + Namespace *Namespace `bun:"rel:belongs-to,join:namespace_id=id"` +} + +func MembershipFromModel(namespaceID string, member *models.Member) *Membership { + return &Membership{ + UserID: member.ID, + NamespaceID: namespaceID, + CreatedAt: member.AddedAt, + UpdatedAt: time.Time{}, + Status: string(member.Status), + Role: string(member.Role), + } +} + +func MembershipToModel(entity *Membership) *models.Member { + return &models.Member{ + ID: entity.UserID, + AddedAt: entity.CreatedAt, + Role: authorizer.Role(entity.Role), + Status: models.MemberStatus(entity.Status), + Email: entity.User.Email, + } +} diff --git a/api/store/pg/entity/namespace.go b/api/store/pg/entity/namespace.go new file mode 100644 index 00000000000..536445afdba --- /dev/null +++ b/api/store/pg/entity/namespace.go @@ -0,0 +1,77 @@ +package entity + +import ( + "time" + + "github.com/shellhub-io/shellhub/pkg/models" + "github.com/uptrace/bun" +) + +type Namespace struct { + bun.BaseModel `bun:"table:namespaces"` + + ID string `bun:"id,pk,type:uuid"` + CreatedAt time.Time `bun:"created_at"` + UpdatedAt time.Time `bun:"updated_at"` + Type string `bun:"scope"` + Name string `bun:"name"` + OwnerID string `bun:"owner_id"` // TODO: Remove this column in the future, owner should be determined by membership role + Memberships []Membership `json:"members" bun:"rel:has-many,join:id=namespace_id"` + Settings NamespaceSettings `bun:"embed:"` +} + +type NamespaceSettings struct { + MaxDevices int `bun:"max_devices"` + SessionRecord bool `bun:"record_sessions"` + ConnectionAnnouncement string `bun:"connection_announcement,type:text"` +} + +func NamespaceFromModel(model *models.Namespace) *Namespace { + namespace := &Namespace{ + ID: model.TenantID, + CreatedAt: model.CreatedAt, + Type: string(model.Type), + Name: model.Name, + OwnerID: model.Owner, + Settings: NamespaceSettings{ + MaxDevices: model.MaxDevices, + SessionRecord: model.Settings.SessionRecord, + ConnectionAnnouncement: model.Settings.ConnectionAnnouncement, + }, + } + + namespace.Memberships = make([]Membership, len(model.Members)) + for i, member := range model.Members { + namespace.Memberships[i] = Membership{ + UserID: member.ID, + NamespaceID: model.TenantID, + CreatedAt: member.AddedAt, + Status: string(member.Status), + Role: string(member.Role), + } + } + + return namespace +} + +func NamespaceToModel(entity *Namespace) *models.Namespace { + namespace := &models.Namespace{ + TenantID: entity.ID, + Name: entity.Name, + Owner: entity.OwnerID, + CreatedAt: entity.CreatedAt, + Type: models.Type(entity.Type), + MaxDevices: entity.Settings.MaxDevices, + Settings: &models.NamespaceSettings{ + SessionRecord: entity.Settings.SessionRecord, + ConnectionAnnouncement: entity.Settings.ConnectionAnnouncement, + }, + } + + namespace.Members = make([]models.Member, len(entity.Memberships)) + for i, membership := range entity.Memberships { + namespace.Members[i] = *MembershipToModel(&membership) + } + + return namespace +} diff --git a/api/store/pg/entity/private-key.go b/api/store/pg/entity/private-key.go new file mode 100644 index 00000000000..d7e44f12c19 --- /dev/null +++ b/api/store/pg/entity/private-key.go @@ -0,0 +1,34 @@ +package entity + +import ( + "time" + + "github.com/shellhub-io/shellhub/pkg/models" + "github.com/uptrace/bun" +) + +type PrivateKey struct { + bun.BaseModel `bun:"table:private_keys"` + + Fingerprint string `bun:"fingerprint,pk"` + CreatedAt time.Time `bun:"created_at"` + UpdatedAt time.Time `bun:"updated_at"` + Data []byte `bun:"data,type:bytea"` +} + +func PrivateKeyFromModel(model *models.PrivateKey) *PrivateKey { + return &PrivateKey{ + Fingerprint: model.Fingerprint, + Data: model.Data, + CreatedAt: model.CreatedAt, + UpdatedAt: time.Time{}, + } +} + +func PrivateKeyToModel(entity *PrivateKey) *models.PrivateKey { + return &models.PrivateKey{ + Fingerprint: entity.Fingerprint, + Data: entity.Data, + CreatedAt: entity.CreatedAt, + } +} diff --git a/api/store/pg/entity/public-key.go b/api/store/pg/entity/public-key.go new file mode 100644 index 00000000000..ac031f50813 --- /dev/null +++ b/api/store/pg/entity/public-key.go @@ -0,0 +1,71 @@ +package entity + +import ( + "time" + + "github.com/shellhub-io/shellhub/pkg/models" + "github.com/uptrace/bun" +) + +type PublicKey struct { + bun.BaseModel `bun:"table:public_keys"` + + ID string `bun:"id,pk"` + Fingerprint string `bun:"fingerprint"` + NamespaceID string `bun:"namespace_id"` + CreatedAt time.Time `bun:"created_at"` + UpdatedAt time.Time `bun:"updated_at"` + Name string `bun:"name"` + Data []byte `bun:"data,type:bytea"` + + Tags []*Tag `bun:"m2m:public_key_tags,join:PublicKey=Tag"` +} + +func PublicKeyFromModel(model *models.PublicKey) *PublicKey { + publicKey := &PublicKey{ + NamespaceID: model.TenantID, + Fingerprint: model.Fingerprint, + CreatedAt: model.CreatedAt, + UpdatedAt: time.Time{}, + Name: model.PublicKeyFields.Name, + Data: model.Data, + Tags: []*Tag{}, + } + + if len(model.Filter.Tags) > 0 { + publicKey.Tags = make([]*Tag, len(model.Filter.Tags)) + for i, t := range model.Filter.Tags { + publicKey.Tags[i] = TagFromModel(&t) + } + } + + return publicKey +} + +func PublicKeyToModel(entity *PublicKey) *models.PublicKey { + publicKey := &models.PublicKey{ + TenantID: entity.NamespaceID, + Fingerprint: entity.Fingerprint, + Data: entity.Data, + CreatedAt: entity.CreatedAt, + PublicKeyFields: models.PublicKeyFields{ + Name: entity.Name, + Username: "", + Filter: models.PublicKeyFilter{ + Hostname: "", + Taggable: models.Taggable{ + Tags: []models.Tag{}, + }, + }, + }, + } + + if len(entity.Tags) > 0 { + publicKey.Filter.Tags = make([]models.Tag, len(entity.Tags)) + for i, t := range entity.Tags { + publicKey.Filter.Tags[i] = *TagToModel(t) + } + } + + return publicKey +} diff --git a/api/store/pg/entity/session.go b/api/store/pg/entity/session.go new file mode 100644 index 00000000000..4b052ce2868 --- /dev/null +++ b/api/store/pg/entity/session.go @@ -0,0 +1,207 @@ +package entity + +import ( + "encoding/json" + "strconv" + "strings" + "time" + + "github.com/shellhub-io/shellhub/pkg/models" + "github.com/uptrace/bun" +) + +type Session struct { + bun.BaseModel `bun:"table:sessions"` + + ID string `bun:"id,pk"` + DeviceID string `bun:"device_id"` + Username string `bun:"username"` + IPAddress string `bun:"ip_address"` + StartedAt time.Time `bun:"started_at"` + SeenAt time.Time `bun:"seen_at"` + Closed bool `bun:"closed"` + Authenticated bool `bun:"authenticated"` + Recorded bool `bun:"recorded"` + Type string `bun:"type"` + Term string `bun:"term"` + Longitude float64 `bun:"longitude"` + Latitude float64 `bun:"latitude"` + CreatedAt time.Time `bun:"created_at"` + UpdatedAt time.Time `bun:"updated_at"` + // EventTypes is a comma-separated list of unique event types + EventTypes string `bun:"event_types,scanonly"` + // EventSeats is a comma-separated list of unique seats as integers + EventSeats string `bun:"event_seats,scanonly"` + + Device *Device `bun:"rel:belongs-to,join:device_id=id"` +} + +func SessionFromModel(model *models.Session) *Session { + session := &Session{ + ID: model.UID, + DeviceID: string(model.DeviceUID), + Username: model.Username, + IPAddress: model.IPAddress, + StartedAt: model.StartedAt, + SeenAt: model.LastSeen, + Closed: model.Closed, + Authenticated: model.Authenticated, + Recorded: model.Recorded, + Type: model.Type, + Term: model.Term, + Longitude: model.Position.Longitude, + Latitude: model.Position.Latitude, + CreatedAt: time.Now(), + UpdatedAt: time.Now(), + } + + return session +} + +func SessionToModel(entity *Session) *models.Session { + session := &models.Session{ + UID: entity.ID, + DeviceUID: models.UID(entity.DeviceID), + Username: entity.Username, + IPAddress: entity.IPAddress, + StartedAt: entity.StartedAt, + LastSeen: entity.SeenAt, + Closed: entity.Closed, + Authenticated: entity.Authenticated, + Recorded: entity.Recorded, + Type: entity.Type, + Term: entity.Term, + Position: models.SessionPosition{ + Longitude: entity.Longitude, + Latitude: entity.Latitude, + }, + Events: models.SessionEvents{ + Types: parseEventTypes(entity.EventTypes), + Seats: parseEventSeats(entity.EventSeats), + }, + } + + if entity.Device != nil { + session.Device = DeviceToModel(entity.Device) + session.TenantID = entity.Device.NamespaceID + } + + return session +} + +type ActiveSession struct { + bun.BaseModel `bun:"table:active_sessions"` + + SessionID string `bun:"session_id,pk"` + SeenAt time.Time `bun:"seen_at"` + CreatedAt time.Time `bun:"created_at"` + + Session *Session `bun:"rel:belongs-to,join:session_id=id"` +} + +func ActiveSessionFromModel(model *models.ActiveSession) *ActiveSession { + return &ActiveSession{ + SessionID: string(model.UID), + SeenAt: model.LastSeen, + CreatedAt: time.Now(), + } +} + +func ActiveSessionToModel(entity *ActiveSession) *models.ActiveSession { + activeSession := &models.ActiveSession{ + UID: models.UID(entity.SessionID), + LastSeen: entity.SeenAt, + } + + if entity.Session != nil && entity.Session.Device != nil { + activeSession.TenantID = entity.Session.Device.NamespaceID + } + + return activeSession +} + +type SessionEvent struct { + bun.BaseModel `bun:"table:session_events"` + + ID string `bun:"id,pk"` + SessionID string `bun:"session_id"` + Type string `bun:"type"` + Seat int `bun:"seat"` + Data string `bun:"data"` + CreatedAt time.Time `bun:"created_at"` + + Session *Session `bun:"rel:belongs-to,join:session_id=id"` +} + +func SessionEventFromModel(model *models.SessionEvent) *SessionEvent { + event := &SessionEvent{ + SessionID: model.Session, + Type: string(model.Type), + Seat: model.Seat, + CreatedAt: model.Timestamp, + } + + if model.Data != nil { + if dataBytes, err := json.Marshal(model.Data); err == nil { + event.Data = string(dataBytes) + } + } + + return event +} + +func SessionEventToModel(entity *SessionEvent) *models.SessionEvent { + event := &models.SessionEvent{ + Session: entity.SessionID, + Type: models.SessionEventType(entity.Type), + Timestamp: entity.CreatedAt, + Seat: entity.Seat, + } + + if entity.Data != "" { + var data interface{} + if err := json.Unmarshal([]byte(entity.Data), &data); err == nil { + event.Data = data + } + } + + return event +} + +// parseEventTypes converts a comma-separated string of event types into a slice of strings +func parseEventTypes(eventTypes string) []string { + if eventTypes == "" { + return []string{} + } + + types := strings.Split(eventTypes, ",") + result := make([]string, 0, len(types)) + + for _, t := range types { + if trimmed := strings.TrimSpace(t); trimmed != "" { + result = append(result, trimmed) + } + } + + return result +} + +// parseEventSeats converts a comma-separated string of seat numbers into a slice of integers +func parseEventSeats(eventSeats string) []int { + if eventSeats == "" { + return []int{} + } + + seats := strings.Split(eventSeats, ",") + result := make([]int, 0, len(seats)) + + for _, s := range seats { + if trimmed := strings.TrimSpace(s); trimmed != "" { + if seat, err := strconv.Atoi(trimmed); err == nil { + result = append(result, seat) + } + } + } + + return result +} diff --git a/api/store/pg/entity/system.go b/api/store/pg/entity/system.go new file mode 100644 index 00000000000..e39318af5fe --- /dev/null +++ b/api/store/pg/entity/system.go @@ -0,0 +1,121 @@ +package entity + +import ( + "github.com/shellhub-io/shellhub/pkg/models" + "github.com/uptrace/bun" +) + +type System struct { + bun.BaseModel `bun:"table:systems"` + + ID string `bun:"id,pk,type:uuid"` + Setup bool `bun:"setup"` + Authentication SystemAuthentication `bun:"embed:authentication_"` +} + +type SystemAuthentication struct { + Local SystemAuthenticationLocal `bun:"embed:local_"` + SAML SystemAuthenticationSAML `bun:"embed:saml_"` +} + +type SystemAuthenticationLocal struct { + Enabled bool `bun:"enabled"` +} + +type SystemAuthenticationSAML struct { + Enabled bool `bun:"enabled"` + Idp SystemIdpSAML `bun:"embed:idp_"` + Sp SystemSpSAML `bun:"embed:sp_"` +} + +type SystemAuthenticationBinding struct { + Post string `bun:"binding_post"` + Redirect string `bun:"binding_redirect"` + Preferred string `bun:"binding_preferred"` +} + +type SystemIdpSAML struct { + EntityID string `bun:"entity_id"` + Binding SystemAuthenticationBinding `bun:"embed:"` + Certificates []string `bun:"certificates,array"` + Mappings map[string]string `bun:"mappings,type:jsonb"` +} + +type SystemSpSAML struct { + SignAuthRequests bool `bun:"sign_auth_requests"` + Certificate string `bun:"certificate"` + PrivateKey string `bun:"private_key"` +} + +func SystemFromModel(model *models.System) *System { + if model == nil { + return &System{} + } + + entity := &System{ + Setup: model.Setup, + } + + if model.Authentication != nil { + if model.Authentication.Local != nil { + entity.Authentication.Local.Enabled = model.Authentication.Local.Enabled + } + + if model.Authentication.SAML != nil { + entity.Authentication.SAML.Enabled = model.Authentication.SAML.Enabled + + if model.Authentication.SAML.Idp != nil { + entity.Authentication.SAML.Idp.EntityID = model.Authentication.SAML.Idp.EntityID + entity.Authentication.SAML.Idp.Certificates = model.Authentication.SAML.Idp.Certificates + entity.Authentication.SAML.Idp.Mappings = model.Authentication.SAML.Idp.Mappings + + if model.Authentication.SAML.Idp.Binding != nil { + entity.Authentication.SAML.Idp.Binding.Post = model.Authentication.SAML.Idp.Binding.Post + entity.Authentication.SAML.Idp.Binding.Redirect = model.Authentication.SAML.Idp.Binding.Redirect + entity.Authentication.SAML.Idp.Binding.Preferred = model.Authentication.SAML.Idp.Binding.Preferred + } + } + + if model.Authentication.SAML.Sp != nil { + entity.Authentication.SAML.Sp.SignAuthRequests = model.Authentication.SAML.Sp.SignAuthRequests + entity.Authentication.SAML.Sp.Certificate = model.Authentication.SAML.Sp.Certificate + entity.Authentication.SAML.Sp.PrivateKey = model.Authentication.SAML.Sp.PrivateKey + } + } + } + + return entity +} + +func SystemToModel(entity *System) *models.System { + if entity == nil { + return &models.System{} + } + + return &models.System{ + Setup: entity.Setup, + Authentication: &models.SystemAuthentication{ + Local: &models.SystemAuthenticationLocal{ + Enabled: entity.Authentication.Local.Enabled, + }, + SAML: &models.SystemAuthenticationSAML{ + Enabled: entity.Authentication.SAML.Enabled, + Idp: &models.SystemIdpSAML{ + EntityID: entity.Authentication.SAML.Idp.EntityID, + Certificates: entity.Authentication.SAML.Idp.Certificates, + Mappings: entity.Authentication.SAML.Idp.Mappings, + Binding: &models.SystemAuthenticationBinding{ + Post: entity.Authentication.SAML.Idp.Binding.Post, + Redirect: entity.Authentication.SAML.Idp.Binding.Redirect, + Preferred: entity.Authentication.SAML.Idp.Binding.Preferred, + }, + }, + Sp: &models.SystemSpSAML{ + SignAuthRequests: entity.Authentication.SAML.Sp.SignAuthRequests, + Certificate: entity.Authentication.SAML.Sp.Certificate, + PrivateKey: entity.Authentication.SAML.Sp.PrivateKey, + }, + }, + }, + } +} diff --git a/api/store/pg/entity/tag.go b/api/store/pg/entity/tag.go new file mode 100644 index 00000000000..ba52c15897b --- /dev/null +++ b/api/store/pg/entity/tag.go @@ -0,0 +1,68 @@ +package entity + +import ( + "time" + + "github.com/shellhub-io/shellhub/pkg/models" + "github.com/uptrace/bun" +) + +type Tag struct { + bun.BaseModel `bun:"table:tags"` + + ID string `bun:"id,pk"` + NamespaceID string `bun:"namespace_id"` + Name string `bun:"name"` + CreatedAt time.Time `bun:"created_at"` + UpdatedAt time.Time `bun:"updated_at"` + + Namespace *Namespace `bun:"rel:belongs-to,join:namespace_id=id"` +} + +type DeviceTag struct { + bun.BaseModel `bun:"table:device_tags"` + DeviceID string `bun:"device_id,pk"` + TagID string `bun:"tag_id,pk"` + CreatedAt time.Time `bun:"created_at"` + + Device *Device `bun:"rel:belongs-to,join:device_id=id"` + Tag *Tag `bun:"rel:belongs-to,join:tag_id=id"` +} + +type PublicKeyTag struct { + bun.BaseModel `bun:"table:public_key_tags"` + PublicKeyID string `bun:"public_key_id,pk"` + TagID string `bun:"tag_id,pk"` + CreatedAt time.Time `bun:"created_at"` + + PublicKey *PublicKey `bun:"rel:belongs-to,join:public_key_id=id"` + Tag *Tag `bun:"rel:belongs-to,join:tag_id=id"` +} + +func TagFromModel(model *models.Tag) *Tag { + return &Tag{ + ID: model.ID, + NamespaceID: model.TenantID, + Name: model.Name, + CreatedAt: model.CreatedAt, + UpdatedAt: model.UpdatedAt, + } +} + +func TagToModel(entity *Tag) *models.Tag { + return &models.Tag{ + ID: entity.ID, + TenantID: entity.NamespaceID, + Name: entity.Name, + CreatedAt: entity.CreatedAt, + UpdatedAt: entity.UpdatedAt, + } +} + +func NewDeviceTag(tagID, deviceID string) *DeviceTag { + return &DeviceTag{TagID: tagID, DeviceID: tagID} +} + +func NewPublicKeyTag(tagID, publickeyID string) *PublicKeyTag { + return &PublicKeyTag{TagID: tagID, PublicKeyID: tagID} +} diff --git a/api/store/pg/entity/user.go b/api/store/pg/entity/user.go new file mode 100644 index 00000000000..d427150ab7e --- /dev/null +++ b/api/store/pg/entity/user.go @@ -0,0 +1,109 @@ +package entity + +import ( + "time" + + "github.com/shellhub-io/shellhub/pkg/models" + "github.com/uptrace/bun" +) + +type User struct { + bun.BaseModel `bun:"table:users"` + + ID string `bun:"id,pk,type:uuid"` + CreatedAt time.Time `bun:"created_at"` + UpdatedAt time.Time `bun:"updated_at"` + LastLogin time.Time `bun:"last_login,nullzero"` + Origin string `bun:"origin"` + ExternalID string `bun:"external_id,nullzero"` + Status string `bun:"status"` + Name string `bun:"name"` + Username string `bun:"username"` + Email string `bun:"email"` + PasswordDigest string `bun:"password_digest"` + Preferences UserPreferences `bun:"embed:"` + MFA UserMFA `bun:"-"` +} + +type UserPreferences struct { + PreferredNamespace string `bun:"preferred_namespace_id,nullzero"` + AuthMethods []string `bun:"auth_methods,array"` + SecurityEmail string `bun:"security_email,nullzero"` + MaxNamespaces int `bun:"namespace_ownership_limit"` + EmailMarketing bool `bun:"email_marketing"` +} + +type UserMFA struct { + Enabled bool `bun:"enabled"` + Secret string `bun:"secret,nullzero"` + RecoveryCodes []string `bun:"recovery_codes,nullzero,array"` +} + +func UserFromModel(model *models.User) *User { + authMethods := make([]string, len(model.Preferences.AuthMethods)) + for i, method := range model.Preferences.AuthMethods { + authMethods[i] = method.String() + } + + return &User{ + ID: model.ID, + CreatedAt: model.CreatedAt, + UpdatedAt: time.Time{}, + LastLogin: model.LastLogin, + Origin: model.Origin.String(), + ExternalID: model.ExternalID, + Status: model.Status.String(), + Name: model.Name, + Username: model.Username, + Email: model.Email, + PasswordDigest: model.Password.Hash, + Preferences: UserPreferences{ + PreferredNamespace: model.Preferences.PreferredNamespace, + AuthMethods: authMethods, + SecurityEmail: model.UserData.RecoveryEmail, + MaxNamespaces: model.MaxNamespaces, + EmailMarketing: model.EmailMarketing, + }, + MFA: UserMFA{ + Enabled: model.MFA.Enabled, + Secret: model.MFA.Secret, + RecoveryCodes: model.MFA.RecoveryCodes, + }, + } +} + +func UserToModel(entity *User) *models.User { + authMethods := make([]models.UserAuthMethod, len(entity.Preferences.AuthMethods)) + for i, method := range entity.Preferences.AuthMethods { + authMethods[i] = models.UserAuthMethod(method) + } + + return &models.User{ + ID: entity.ID, + Origin: models.UserOrigin(entity.Origin), + ExternalID: entity.ExternalID, + Status: models.UserStatus(entity.Status), + MaxNamespaces: entity.Preferences.MaxNamespaces, + CreatedAt: entity.CreatedAt, + LastLogin: entity.LastLogin, + EmailMarketing: entity.Preferences.EmailMarketing, + UserData: models.UserData{ + Name: entity.Name, + Username: entity.Username, + Email: entity.Email, + RecoveryEmail: entity.Preferences.SecurityEmail, + }, + Password: models.UserPassword{ + Hash: entity.PasswordDigest, + }, + MFA: models.UserMFA{ + Enabled: entity.MFA.Enabled, + Secret: entity.MFA.Secret, + RecoveryCodes: entity.MFA.RecoveryCodes, + }, + Preferences: models.UserPreferences{ + PreferredNamespace: entity.Preferences.PreferredNamespace, + AuthMethods: authMethods, + }, + } +} diff --git a/api/store/pg/internal/filters.go b/api/store/pg/internal/filters.go new file mode 100644 index 00000000000..6461bd08ee3 --- /dev/null +++ b/api/store/pg/internal/filters.go @@ -0,0 +1,134 @@ +package internal + +import ( + "errors" + "slices" + "strconv" + "strings" + + "github.com/shellhub-io/shellhub/pkg/api/query" + "github.com/uptrace/bun" +) + +var ( + ErrUnsupportedContainsType = errors.New("unsupported value type for contains comparison") // ErrInvalidContainsValue is returned when a 'contains' filter has an unsupported value type + ErrUnsupportedBoolType = errors.New("unsupported value type for boolean conversion") // ErrUnsupportedBoolType is returned when a 'bool' filter receives an unsupported value type + ErrUnsupportedNumericType = errors.New("unsupported value type for numeric comparison") // ErrUnsupportedNumericType is returned when a 'gt' filter receives an unsupported value type +) + +// ParseFilterOperator converts a filter operator to its SQL representation. Supported operators are "AND" and "OR". +// It returns the SQL operator string and a boolean indicating if the operator is valid. +func ParseFilterOperator(op *query.FilterOperator) (string, bool) { + return strings.ToUpper(op.Name), slices.Contains([]string{"AND", "OR"}, strings.ToUpper(op.Name)) +} + +// ParseFilterProperty constructs the SQL representation of a property filter. It returns a SQL condition string, SQL +// arguments array, boolean indicating if the operator is valid and an error, if any +func ParseFilterProperty(fp *query.FilterProperty) (string, []any, bool, error) { + var condition string + var args []any + var err error + + switch fp.Operator { + case "contains": + condition, args, err = fromContains(fp.Name, fp.Value) + case "eq": + condition, args, err = fromEq(fp.Name, fp.Value) + case "bool": + condition, args, err = fromBool(fp.Name, fp.Value) + case "gt": + condition, args, err = fromGt(fp.Name, fp.Value) + case "ne": + condition, args, err = fromNe(fp.Name, fp.Value) + default: + return "", nil, false, nil + } + + if err != nil { + return "", nil, false, err + } + + return condition, args, true, nil +} + +// fromContains converts a "contains" JSON expression to an SQL expression. For strings, it uses ILIKE with '%value%' +// for case-insensitive substring matching. For arrays, it uses the @> (contains) operator to check if the column +// contains all the values in the array. Returns SQL condition string, arguments array, and error if any. +func fromContains(column string, value any) (string, []any, error) { + switch v := value.(type) { + case string: + return "? ILIKE ?", []any{bun.Ident(column), "%" + v + "%"}, nil + case []any: + return "? @> ?", []any{bun.Ident(column), v}, nil + } + + return "", nil, ErrUnsupportedContainsType +} + +// fromEq converts an "eq" (equals) JSON expression to an SQL expression using =. +// Returns SQL condition string, arguments array, and error if any. +func fromEq(column string, value any) (string, []any, error) { + return "? = ?", []any{bun.Ident(column), value}, nil +} + +// fromBool converts a "bool" JSON expression to an SQL expression. It handles various input types (int, string, bool) +// and converts them to boolean values. +// +// - For integers: 0 is false, anything else is true +// +// - For strings: uses strconv.ParseBool +// +// - For booleans: uses the value directly +// +// Returns SQL condition string, arguments array, and error if any. +func fromBool(column string, value any) (string, []any, error) { + var boolValue bool + + switch v := value.(type) { + case int: + boolValue = v != 0 + case string: + var err error + boolValue, err = strconv.ParseBool(v) + if err != nil { + return "", nil, err + } + case bool: + boolValue = v + default: + return "", nil, ErrUnsupportedBoolType + } + + return "? = ?", []any{bun.Ident(column), boolValue}, nil +} + +// fromGt converts a "gt" (greater than) JSON expression to an SQL expression using >. It handles various numeric types +// (int, float, etc.) and string representations of numbers. For strings, it attempts to convert to int first, then to +// float if int conversion fails. Returns SQL condition string, arguments array, and error if any. +func fromGt(column string, value any) (string, []any, error) { + switch v := value.(type) { + case uint, uint8, uint16, uint32, uint64, int, int8, int16, int32, int64, float32, float64: + return "? > ?", []any{bun.Ident(column), v}, nil + case string: + var num any + var err error + + num, err = strconv.Atoi(v) + if err != nil { + num, err = strconv.ParseFloat(v, 64) + if err != nil { + return "", nil, err + } + } + + return "? > ?", []any{bun.Ident(column), num}, nil + default: + return "", nil, ErrUnsupportedNumericType + } +} + +// fromNe converts a "ne" (not equals) JSON expression to an SQL expression using <>. Returns SQL condition string, +// arguments array, and error if any. +func fromNe(column string, value any) (string, []any, error) { + return "? <> ?", []any{bun.Ident(column), value}, nil +} diff --git a/api/store/pg/member.go b/api/store/pg/member.go new file mode 100644 index 00000000000..288e66ea5ff --- /dev/null +++ b/api/store/pg/member.go @@ -0,0 +1,60 @@ +package pg + +import ( + "context" + + "github.com/shellhub-io/shellhub/api/store" + "github.com/shellhub-io/shellhub/api/store/pg/entity" + "github.com/shellhub-io/shellhub/pkg/clock" + "github.com/shellhub-io/shellhub/pkg/models" +) + +func (pg *Pg) NamespaceCreateMembership(ctx context.Context, tenantID string, membership *models.Member) error { + db := pg.getConnection(ctx) + + membership.AddedAt = clock.Now() + entity := entity.MembershipFromModel(tenantID, membership) + if _, err := db.NewInsert().Model(entity).Exec(ctx); err != nil { + return fromSQLError(err) + } + + return nil +} + +func (pg *Pg) NamespaceUpdateMembership(ctx context.Context, tenantID string, member *models.Member) error { + db := pg.getConnection(ctx) + + e := entity.MembershipFromModel(tenantID, member) + e.UpdatedAt = clock.Now() + _, err := db.NewUpdate().Model(e).WherePK().Exec(ctx) + + return fromSQLError(err) +} + +func (pg *Pg) NamespaceDeleteMembership(ctx context.Context, tenantID string, member *models.Member) error { + db := pg.getConnection(ctx) + + e := entity.MembershipFromModel(tenantID, member) + r, err := db.NewDelete().Model(e).WherePK().Exec(ctx) + if err != nil { + return fromSQLError(err) + } + + if count, err := r.RowsAffected(); err != nil || count == 0 { + return store.ErrNoDocuments + } + + user := new(entity.User) + if err := db.NewSelect().Model(user).Where("id = ? AND preferred_namespace_id = ?", member.ID, tenantID).Limit(1).Scan(ctx); err != nil { + return fromSQLError(err) + } + + if user != nil && user.ID != "" { + user.Preferences.PreferredNamespace = "" + if _, err := db.NewUpdate().Model(user).Column("preferred_namespace_id").WherePK().Exec(ctx); err != nil { + return fromSQLError(err) + } + } + + return nil +} diff --git a/api/store/pg/member_test.go b/api/store/pg/member_test.go new file mode 100644 index 00000000000..e1e51799ddd --- /dev/null +++ b/api/store/pg/member_test.go @@ -0,0 +1 @@ +package pg_test diff --git a/api/store/pg/migrations/001_create_namespaces_table.go b/api/store/pg/migrations/001_create_namespaces_table.go new file mode 100644 index 00000000000..09f297818db --- /dev/null +++ b/api/store/pg/migrations/001_create_namespaces_table.go @@ -0,0 +1,53 @@ +package migrations + +import ( + "context" + "time" + + log "github.com/sirupsen/logrus" + "github.com/uptrace/bun" +) + +func init() { + migrations.MustRegister(migration001Up, migration001Down) +} + +func migration001Up(ctx context.Context, db *bun.DB) error { + _, err := db.ExecContext(ctx, ` + DROP TYPE IF EXISTS namespace_scope; + CREATE TYPE namespace_scope AS ENUM ('personal', 'team'); + `) + if err != nil { + return err + } + + table := &struct { + bun.BaseModel `bun:"table:namespaces"` + ID string `bun:"id,type:uuid,pk"` + CreatedAt time.Time `bun:"created_at,type:timestamptz,notnull"` + UpdatedAt time.Time `bun:"updated_at,type:timestamptz,notnull"` + Scope string `bun:"scope,type:namespace_scope,notnull"` + Name string `bun:"name,type:varchar(64),notnull"` + OwnerID string `bun:"owner_id,type:uuid,notnull"` + MaxDevices int `bun:"max_devices,type:integer,notnull"` + RecordSessions bool `bun:"record_sessions,notnull"` + ConnectionAnnouncement string `bun:"connection_announcement,type:text,nullzero"` + }{} + + if _, err := db.NewCreateTable().Model(table).IfNotExists().Exec(ctx); err != nil { + log.WithError(err).Error("failed to apply migration 001") + + return err + } + + return nil +} + +func migration001Down(ctx context.Context, db *bun.DB) error { + _, err := db.ExecContext(ctx, ` + DROP TABLE IF EXISTS namespaces; + DROP TYPE IF EXISTS namespace_scope; + `) + + return err +} diff --git a/api/store/pg/migrations/002_create_users_table.go b/api/store/pg/migrations/002_create_users_table.go new file mode 100644 index 00000000000..25997de1a40 --- /dev/null +++ b/api/store/pg/migrations/002_create_users_table.go @@ -0,0 +1,73 @@ +package migrations + +import ( + "context" + "time" + + log "github.com/sirupsen/logrus" + "github.com/uptrace/bun" +) + +func init() { + migrations.MustRegister(migration002Up, migration002Down) +} + +func migration002Up(ctx context.Context, db *bun.DB) error { + _, err := db.ExecContext(ctx, ` + DROP TYPE IF EXISTS user_origin; + CREATE TYPE user_origin AS ENUM ('local', 'saml'); + + DROP TYPE IF EXISTS user_status; + CREATE TYPE user_status AS ENUM ('invited', 'pending', 'confirmed'); + + DROP TYPE IF EXISTS user_auth_method; + CREATE TYPE user_auth_method AS ENUM ('local', 'saml'); + `) + if err != nil { + return err + } + + table := &struct { + bun.BaseModel `bun:"table:users"` + ID string `bun:"id,type:uuid,pk"` + CreatedAt time.Time `bun:"created_at,type:timestamptz,notnull"` + UpdatedAt time.Time `bun:"updated_at,type:timestamptz,notnull"` + LastLogin time.Time `bun:"last_login,type:timestamptz,nullzero"` + Origin string `bun:"origin,type:user_origin,notnull"` + ExternalID string `bun:"external_id,type:varchar,nullzero"` + Status string `bun:"status,type:user_status,notnull"` + Name string `bun:"name,type:varchar(64),notnull"` + Username string `bun:"username,type:varchar(32),notnull,unique"` + Email string `bun:"email,type:varchar(320),notnull,unique"` + SecurityEmail string `bun:"security_email,type:varchar(320),nullzero"` + PasswordDigest string `bun:"password_digest,type:char(72),notnull"` + AuthMethods []string `bun:"auth_methods,type:user_auth_method[],array,notnull"` + NamespaceOwnershipLimit int `bun:"namespace_ownership_limit,type:integer,notnull"` + EmailMarketing bool `bun:"email_marketing,notnull,default:false"` + PreferredNamespaceID string `bun:"preferred_namespace_id,type:uuid,nullzero"` + }{} + + if _, err := db. + NewCreateTable(). + Model(table). + IfNotExists(). + ForeignKey(`("preferred_namespace_id") REFERENCES namespaces("id") ON DELETE SET NULL`). + Exec(ctx); err != nil { + log.WithError(err).Error("failed to apply migration 002") + + return err + } + + return nil +} + +func migration002Down(ctx context.Context, db *bun.DB) error { + _, err := db.ExecContext(ctx, ` + DROP TABLE IF EXISTS users; + DROP TYPE IF EXISTS user_origin; + DROP TYPE IF EXISTS user_status; + DROP TYPE IF EXISTS user_auth_method; + `) + + return err +} diff --git a/api/store/pg/migrations/003_create_memberships_table.go b/api/store/pg/migrations/003_create_memberships_table.go new file mode 100644 index 00000000000..64572bfb191 --- /dev/null +++ b/api/store/pg/migrations/003_create_memberships_table.go @@ -0,0 +1,59 @@ +package migrations + +import ( + "context" + "time" + + log "github.com/sirupsen/logrus" + "github.com/uptrace/bun" +) + +func init() { + migrations.MustRegister(migration003Up, migration003Down) +} + +func migration003Up(ctx context.Context, db *bun.DB) error { + _, err := db.ExecContext(ctx, ` + DROP TYPE IF EXISTS membership_status; + CREATE TYPE membership_status AS ENUM ('pending', 'accepted'); + + DROP TYPE IF EXISTS membership_role; + CREATE TYPE membership_role AS ENUM ('owner', 'administrator', 'operator', 'observer'); + `) + if err != nil { + return err + } + + table := &struct { + bun.BaseModel `bun:"table:memberships"` + UserID string `bun:"user_id,type:uuid,notnull,pk"` + NamespaceID string `bun:"namespace_id,type:uuid,notnull,pk"` + CreatedAt time.Time `bun:"created_at,type:timestamptz,notnull"` + UpdatedAt time.Time `bun:"updated_at,type:timestamptz,notnull"` + Status string `bun:"status,type:membership_status,notnull"` + Role string `bun:"role,type:membership_role,notnull"` + }{} + + if _, err := db.NewCreateTable(). + Model(table). + IfNotExists(). + ForeignKey(`("user_id") REFERENCES users("id") ON DELETE CASCADE`). + ForeignKey(`("namespace_id") REFERENCES namespaces("id") ON DELETE CASCADE`). + Exec(ctx); err != nil { + log.WithError(err).Error("failed to apply migration 003") + + return err + } + + return nil +} + +func migration003Down(ctx context.Context, db *bun.DB) error { + _, err := db.ExecContext(ctx, ` + DROP TABLE IF EXISTS memberships; + DROP TYPE IF EXISTS membership_status; + DROP TYPE IF EXISTS membership_role; + `) + + return err +} diff --git a/api/store/pg/migrations/004_create_devices_table.go b/api/store/pg/migrations/004_create_devices_table.go new file mode 100644 index 00000000000..d82ad4d4be3 --- /dev/null +++ b/api/store/pg/migrations/004_create_devices_table.go @@ -0,0 +1,106 @@ +package migrations + +import ( + "context" + "time" + + log "github.com/sirupsen/logrus" + "github.com/uptrace/bun" +) + +func init() { + migrations.MustRegister(migration004Up, migration004Down) +} + +func migration004Up(ctx context.Context, db *bun.DB) error { + _, err := db.ExecContext(ctx, ` + DROP TYPE IF EXISTS device_status; + CREATE TYPE device_status AS ENUM ('accepted', 'pending', 'rejected', 'removed', 'unused'); + `) + if err != nil { + return err + } + + deviceTable := &struct { + bun.BaseModel `bun:"table:devices"` + ID string `bun:"id,type:varchar,pk"` + NamespaceID string `bun:"namespace_id,type:uuid,notnull"` + CreatedAt time.Time `bun:"created_at,type:timestamptz,notnull"` + UpdatedAt time.Time `bun:"updated_at,type:timestamptz,notnull"` + RemovedAt *time.Time `bun:"removed_at,type:timestamptz"` + SeenAt time.Time `bun:"seen_at,type:timestamptz,notnull"` + DisconnectedAt time.Time `bun:"disconnected_at,type:timestamptz,nullzero"` + Status string `bun:"status,type:device_status,notnull"` + Name string `bun:"name,type:varchar(64),notnull"` + Mac string `bun:"mac,type:varchar(17),notnull"` + PublicKey string `bun:"public_key,type:text,notnull"` + Identifier string `bun:"identifier,type:varchar,nullzero"` + PrettyName string `bun:"pretty_name,type:varchar(64),nullzero"` + Version string `bun:"version,type:varchar(32),nullzero"` + Arch string `bun:"arch,type:varchar(16),nullzero"` + Platform string `bun:"platform,type:varchar(32),nullzero"` + Latitude float64 `bun:"latitude,type:numeric,nullzero"` + Longitude float64 `bun:"longitude,type:numeric,nullzero"` + }{} + + _, err = db.NewCreateTable(). + Model(deviceTable). + IfNotExists(). + ForeignKey(`("namespace_id") REFERENCES namespaces("id") ON DELETE CASCADE`). + Exec(ctx) + if err != nil { + log.WithError(err).Error("failed to apply migration 004") + + return err + } + + _, err = db.NewCreateIndex(). + Model((*struct { + bun.BaseModel `bun:"table:devices"` + })(nil)). + Index("devices_namespace_id"). + Column("namespace_id"). + Exec(ctx) + if err != nil { + log.WithError(err).Error("failed to apply migration 004") + + return err + } + + _, err = db.NewCreateIndex(). + Model((*struct { + bun.BaseModel `bun:"table:devices"` + })(nil)). + Index("devices_seen_at"). + Column("seen_at"). + Exec(ctx) + if err != nil { + log.WithError(err).Error("failed to apply migration 004") + + return err + } + + _, err = db.NewCreateIndex(). + Model((*struct { + bun.BaseModel `bun:"table:devices"` + })(nil)). + Index("devices_disconnected_at"). + Column("disconnected_at"). + Exec(ctx) + if err != nil { + log.WithError(err).Error("failed to apply migration 004") + + return err + } + + return nil +} + +func migration004Down(ctx context.Context, db *bun.DB) error { + _, err := db.ExecContext(ctx, ` + DROP TABLE IF EXISTS devices; + DROP TYPE IF EXISTS device_status; + `) + + return err +} diff --git a/api/store/pg/migrations/005_create_private_keys_table.go b/api/store/pg/migrations/005_create_private_keys_table.go new file mode 100644 index 00000000000..75d34ee22a1 --- /dev/null +++ b/api/store/pg/migrations/005_create_private_keys_table.go @@ -0,0 +1,39 @@ +package migrations + +import ( + "context" + "time" + + log "github.com/sirupsen/logrus" + "github.com/uptrace/bun" +) + +func init() { + migrations.MustRegister(migration005Up, migration005Down) +} + +func migration005Up(ctx context.Context, db *bun.DB) error { + table := &struct { + bun.BaseModel `bun:"table:private_keys"` + Fingerprint string `bun:"fingerprint,type:varchar,pk"` + CreatedAt time.Time `bun:"created_at,type:timestamptz,notnull"` + UpdatedAt time.Time `bun:"updated_at,type:timestamptz,notnull"` + Data []byte `bun:"data,type:bytea,nullzero"` + }{} + + if _, err := db.NewCreateTable().Model(table).IfNotExists().Exec(ctx); err != nil { + log.WithError(err).Error("failed to apply migration 005") + + return err + } + + return nil +} + +func migration005Down(ctx context.Context, db *bun.DB) error { + _, err := db.ExecContext(ctx, ` + DROP TABLE IF EXISTS private_keys; + `) + + return err +} diff --git a/api/store/pg/migrations/006_create_api_keys_table.go b/api/store/pg/migrations/006_create_api_keys_table.go new file mode 100644 index 00000000000..5b8326b8193 --- /dev/null +++ b/api/store/pg/migrations/006_create_api_keys_table.go @@ -0,0 +1,47 @@ +package migrations + +import ( + "context" + "time" + + log "github.com/sirupsen/logrus" + "github.com/uptrace/bun" +) + +func init() { + migrations.MustRegister(migration006Up, migration006Down) +} + +func migration006Up(ctx context.Context, db *bun.DB) error { + table := &struct { + bun.BaseModel `bun:"table:api_keys"` + KeyDigest string `bun:"key_digest,type:char(64),notnull,pk"` + NamespaceID string `bun:"namespace_id,type:uuid,notnull,pk"` + CreatedAt time.Time `bun:"created_at,type:timestamptz,notnull"` + UpdatedAt time.Time `bun:"updated_at,type:timestamptz,notnull"` + ExpiresIn int64 `bun:"expires_in,type:bigint,nullzero"` + Name string `bun:"name,type:varchar,notnull,unique"` + Role string `bun:"role,type:membership_role,notnull"` + UserID string `bun:"user_id,type:uuid,notnull"` + }{} + + if _, err := db.NewCreateTable(). + Model(table). + IfNotExists(). + ForeignKey(`("namespace_id") REFERENCES namespaces("id") ON DELETE CASCADE`). + Exec(ctx); err != nil { + log.WithError(err).Error("failed to apply migration 006") + + return err + } + + return nil +} + +func migration006Down(ctx context.Context, db *bun.DB) error { + _, err := db.ExecContext(ctx, ` + DROP TABLE IF EXISTS api_keys; + `) + + return err +} diff --git a/api/store/pg/migrations/007_create_public_keys_table.go b/api/store/pg/migrations/007_create_public_keys_table.go new file mode 100644 index 00000000000..c88bd9a4e5e --- /dev/null +++ b/api/store/pg/migrations/007_create_public_keys_table.go @@ -0,0 +1,46 @@ +package migrations + +import ( + "context" + "time" + + log "github.com/sirupsen/logrus" + "github.com/uptrace/bun" +) + +func init() { + migrations.MustRegister(migration007Up, migration007Down) +} + +func migration007Up(ctx context.Context, db *bun.DB) error { + table := &struct { + bun.BaseModel `bun:"table:public_keys"` + ID string `bun:"id,type:uuid,pk"` + Fingerprint string `bun:"fingerprint,type:varchar,notnull"` + NamespaceID string `bun:"namespace_id,type:uuid,notnull"` + CreatedAt time.Time `bun:"created_at,type:timestamptz,notnull"` + UpdatedAt time.Time `bun:"updated_at,type:timestamptz,notnull"` + Name string `bun:"name,type:varchar,notnull"` + Data []byte `bun:"data,type:bytea,nullzero"` + }{} + + if _, err := db.NewCreateTable(). + Model(table). + IfNotExists(). + ForeignKey(`("namespace_id") REFERENCES namespaces("id") ON DELETE CASCADE`). + Exec(ctx); err != nil { + log.WithError(err).Error("failed to apply migration 007") + + return err + } + + return nil +} + +func migration007Down(ctx context.Context, db *bun.DB) error { + _, err := db.ExecContext(ctx, ` + DROP TABLE IF EXISTS public_keys; + `) + + return err +} diff --git a/api/store/pg/migrations/008_create_tags_table.go b/api/store/pg/migrations/008_create_tags_table.go new file mode 100644 index 00000000000..92dc305aa2f --- /dev/null +++ b/api/store/pg/migrations/008_create_tags_table.go @@ -0,0 +1,56 @@ +package migrations + +import ( + "context" + "time" + + log "github.com/sirupsen/logrus" + "github.com/uptrace/bun" +) + +func init() { + migrations.MustRegister(migration008Up, migration008Down) +} + +func migration008Up(ctx context.Context, db *bun.DB) error { + table := &struct { + bun.BaseModel `bun:"table:tags"` + ID string `bun:"id,type:uuid,pk"` + NamespaceID string `bun:"namespace_id,type:uuid,notnull"` + Name string `bun:"name,type:varchar,notnull"` + CreatedAt time.Time `bun:"created_at,type:timestamptz,notnull"` + UpdatedAt time.Time `bun:"updated_at,type:timestamptz,notnull"` + }{} + + if _, err := db.NewCreateTable(). + Model(table). + IfNotExists(). + ForeignKey(`("namespace_id") REFERENCES namespaces("id") ON DELETE CASCADE`). + Exec(ctx); err != nil { + log.WithError(err).Error("failed to apply migration 008") + + return err + } + + _, err := db.NewCreateIndex(). + Model((*struct { + bun.BaseModel `bun:"table:tags"` + })(nil)). + Index("tags_namespace_id_name_unique"). + Column("namespace_id", "name"). + Unique(). + Exec(ctx) + if err != nil { + log.WithError(err).Error("failed to apply migration 008") + + return err + } + + return nil +} + +func migration008Down(ctx context.Context, db *bun.DB) error { + _, err := db.ExecContext(ctx, "DROP TABLE IF EXISTS tags") + + return err +} diff --git a/api/store/pg/migrations/009_create_device_tags_table.go b/api/store/pg/migrations/009_create_device_tags_table.go new file mode 100644 index 00000000000..ef03901e049 --- /dev/null +++ b/api/store/pg/migrations/009_create_device_tags_table.go @@ -0,0 +1,64 @@ +package migrations + +import ( + "context" + "time" + + log "github.com/sirupsen/logrus" + "github.com/uptrace/bun" +) + +func init() { + migrations.MustRegister(migration009Up, migration009Down) +} + +func migration009Up(ctx context.Context, db *bun.DB) error { + deviceTagsTable := &struct { + bun.BaseModel `bun:"table:device_tags"` + DeviceID string `bun:"device_id,type:varchar,pk"` + TagID string `bun:"tag_id,type:uuid,pk"` + CreatedAt time.Time `bun:"created_at,type:timestamptz,notnull"` + }{} + + if _, err := db.NewCreateTable(). + Model(deviceTagsTable). + IfNotExists(). + ForeignKey(`("device_id") REFERENCES devices("id") ON DELETE CASCADE`). + ForeignKey(`("tag_id") REFERENCES tags("id") ON DELETE CASCADE`). + Exec(ctx); err != nil { + log.WithError(err).Error("failed to create device_tags table in migration 009") + + return err + } + + if _, err := db.NewCreateIndex(). + Model(deviceTagsTable). + Index("device_tags_device_id"). + Column("device_id"). + Exec(ctx); err != nil { + log.WithError(err).Error("failed to create device_id index for device_tags in migration 009") + + return err + } + + if _, err := db.NewCreateIndex(). + Model(deviceTagsTable). + Index("device_tags_tag_id"). + Column("tag_id"). + Exec(ctx); err != nil { + log.WithError(err).Error("failed to create tag_id index for device_tags in migration 009") + + return err + } + + return nil +} + +func migration009Down(ctx context.Context, db *bun.DB) error { + _, err := db.ExecContext(ctx, ` + DROP TABLE IF EXISTS public_key_tags; + DROP TABLE IF EXISTS device_tags; + `) + + return err +} diff --git a/api/store/pg/migrations/010_create_public_key_tags_table.go b/api/store/pg/migrations/010_create_public_key_tags_table.go new file mode 100644 index 00000000000..fcc56a39fc7 --- /dev/null +++ b/api/store/pg/migrations/010_create_public_key_tags_table.go @@ -0,0 +1,61 @@ +package migrations + +import ( + "context" + "time" + + log "github.com/sirupsen/logrus" + "github.com/uptrace/bun" +) + +func init() { + migrations.MustRegister(migration010Up, migration010Down) +} + +func migration010Up(ctx context.Context, db *bun.DB) error { + publicKeyTagsTable := &struct { + bun.BaseModel `bun:"table:public_key_tags"` + PublicKeyID string `bun:"public_key_id,type:uuid,pk"` + TagID string `bun:"tag_id,type:uuid,pk"` + CreatedAt time.Time `bun:"created_at,type:timestamptz,notnull"` + }{} + + if _, err := db.NewCreateTable(). + Model(publicKeyTagsTable). + IfNotExists(). + ForeignKey(`("public_key_id") REFERENCES public_keys("id") ON DELETE CASCADE`). + ForeignKey(`("tag_id") REFERENCES tags("id") ON DELETE CASCADE`). + Exec(ctx); err != nil { + log.WithError(err).Error("failed to create public_key_tags table in migration 010") + + return err + } + + if _, err := db.NewCreateIndex(). + Model(publicKeyTagsTable). + Index("public_key_tags_public_key_id"). + Column("public_key_id"). + Exec(ctx); err != nil { + log.WithError(err).Error("failed to create public_key_id index for public_key_tags in migration 010") + + return err + } + + if _, err := db.NewCreateIndex(). + Model(publicKeyTagsTable). + Index("public_key_tags_tag_id"). + Column("tag_id"). + Exec(ctx); err != nil { + log.WithError(err).Error("failed to create tag_id index for public_key_tags in migration 010") + + return err + } + + return nil +} + +func migration010Down(ctx context.Context, db *bun.DB) error { + _, err := db.ExecContext(ctx, "DROP TABLE IF EXISTS public_key_tags") + + return err +} diff --git a/api/store/pg/migrations/011_create_systems_table.go b/api/store/pg/migrations/011_create_systems_table.go new file mode 100644 index 00000000000..7d23b1e0ebb --- /dev/null +++ b/api/store/pg/migrations/011_create_systems_table.go @@ -0,0 +1,50 @@ +package migrations + +import ( + "context" + + log "github.com/sirupsen/logrus" + "github.com/uptrace/bun" +) + +func init() { + migrations.MustRegister(migration011Up, migration011Down) +} + +func migration011Up(ctx context.Context, db *bun.DB) error { + table := &struct { + bun.BaseModel `bun:"table:systems"` + + ID string `bun:"id,type:uuid,pk"` + Setup bool `bun:"setup,notnull,default:false"` + AuthenticationLocalEnabled bool `bun:"authentication_local_enabled,notnull,default:true"` + AuthenticationSamlEnabled bool `bun:"authentication_saml_enabled,notnull,default:false"` + AuthenticationSamlIdpEntityID string `bun:"authentication_saml_idp_entity_id,type:text,nullzero"` + AuthenticationSamlIdpBindingPost string `bun:"authentication_saml_idp_binding_post,type:text,nullzero"` + AuthenticationSamlIdpBindingRedirect string `bun:"authentication_saml_idp_binding_redirect,type:text,nullzero"` + AuthenticationSamlIdpBindingPreferred string `bun:"authentication_saml_idp_binding_preferred,type:text,nullzero"` + AuthenticationSamlIdpCertificates []string `bun:"authentication_saml_idp_certificates,array,nullzero"` + AuthenticationSamlIdpMappings map[string]string `bun:"authentication_saml_idp_mappings,type:jsonb,nullzero"` + AuthenticationSamlSpSignAuthRequests bool `bun:"authentication_saml_sp_sign_auth_requests,notnull,default:false"` + AuthenticationSamlSpCertificate string `bun:"authentication_saml_sp_certificate,type:text,nullzero"` + AuthenticationSamlSpPrivateKey string `bun:"authentication_saml_sp_private_key,type:text,nullzero"` + }{} + + if _, err := db. + NewCreateTable(). + Model(table). + IfNotExists(). + Exec(ctx); err != nil { + log.WithError(err).Error("failed to apply migration 011") + + return err + } + + return nil +} + +func migration011Down(ctx context.Context, db *bun.DB) error { + _, err := db.ExecContext(ctx, `DROP TABLE IF EXISTS systems;`) + + return err +} diff --git a/api/store/pg/migrations/012_create_sessions_table.go b/api/store/pg/migrations/012_create_sessions_table.go new file mode 100644 index 00000000000..d28ecc16650 --- /dev/null +++ b/api/store/pg/migrations/012_create_sessions_table.go @@ -0,0 +1,128 @@ +package migrations + +import ( + "context" + "time" + + log "github.com/sirupsen/logrus" + "github.com/uptrace/bun" +) + +func init() { + migrations.MustRegister(migration012Up, migration012Down) +} + +func migration012Up(ctx context.Context, db *bun.DB) error { + _, err := db.ExecContext(ctx, ` + DROP TYPE IF EXISTS session_type; + CREATE TYPE session_type AS ENUM ('shell', 'exec'); + `) + if err != nil { + return err + } + + table := &struct { + bun.BaseModel `bun:"table:sessions"` + ID string `bun:"id,type:uuid,pk"` + DeviceID string `bun:"device_id,type:varchar,notnull"` + Username string `bun:"username,type:varchar(64),notnull"` + IPAddress string `bun:"ip_address,type:inet,notnull"` + StartedAt time.Time `bun:"started_at,type:timestamptz,notnull"` + SeenAt time.Time `bun:"seen_at,type:timestamptz,notnull"` + Closed bool `bun:"closed,notnull,default:false"` + Authenticated bool `bun:"authenticated,notnull,default:false"` + Recorded bool `bun:"recorded,notnull,default:false"` + Type string `bun:"type,type:session_type,nullzero"` + Term string `bun:"term,type:varchar(32),nullzero"` + Longitude float64 `bun:"longitude,type:numeric(10,7),nullzero"` + Latitude float64 `bun:"latitude,type:numeric(10,7),nullzero"` + CreatedAt time.Time `bun:"created_at,type:timestamptz,notnull"` + UpdatedAt time.Time `bun:"updated_at,type:timestamptz,notnull"` + }{} + + if _, err := db.NewCreateTable(). + Model(table). + IfNotExists(). + ForeignKey(`("device_id") REFERENCES devices("id") ON DELETE CASCADE`). + Exec(ctx); err != nil { + log.WithError(err).Error("failed to apply migration 012") + + return err + } + + _, err = db.NewCreateIndex(). + Model((*struct { + bun.BaseModel `bun:"table:sessions"` + })(nil)). + Index("sessions_device_id_idx"). + Column("device_id"). + Exec(ctx) + if err != nil { + log.WithError(err).Error("failed to create device_id index in migration 012") + + return err + } + + _, err = db.NewCreateIndex(). + Model((*struct { + bun.BaseModel `bun:"table:sessions"` + })(nil)). + Index("sessions_started_at_idx"). + Column("started_at"). + Exec(ctx) + if err != nil { + log.WithError(err).Error("failed to create started_at index in migration 012") + + return err + } + + _, err = db.NewCreateIndex(). + Model((*struct { + bun.BaseModel `bun:"table:sessions"` + })(nil)). + Index("sessions_username_idx"). + Column("username"). + Exec(ctx) + if err != nil { + log.WithError(err).Error("failed to create username index in migration 012") + + return err + } + + _, err = db.NewCreateIndex(). + Model((*struct { + bun.BaseModel `bun:"table:sessions"` + })(nil)). + Index("sessions_type_idx"). + Column("type"). + Exec(ctx) + if err != nil { + log.WithError(err).Error("failed to create type index in migration 012") + + return err + } + + _, err = db.NewCreateIndex(). + Model((*struct { + bun.BaseModel `bun:"table:sessions"` + })(nil)). + Index("sessions_closed_started_idx"). + Column("closed", "started_at"). + Exec(ctx) + if err != nil { + log.WithError(err).Error("failed to create closed_started_at index in migration 012") + + return err + } + + return nil +} + +func migration012Down(ctx context.Context, db *bun.DB) error { + _, err := db.ExecContext(ctx, ` + DROP TABLE IF EXISTS sessions; + DROP TYPE IF EXISTS session_type; + `) + + return err +} diff --git a/api/store/pg/migrations/013_create_active_sessions_table.go b/api/store/pg/migrations/013_create_active_sessions_table.go new file mode 100644 index 00000000000..59cddbbe86b --- /dev/null +++ b/api/store/pg/migrations/013_create_active_sessions_table.go @@ -0,0 +1,40 @@ +package migrations + +import ( + "context" + "time" + + log "github.com/sirupsen/logrus" + "github.com/uptrace/bun" +) + +func init() { + migrations.MustRegister(migration013Up, migration013Down) +} + +func migration013Up(ctx context.Context, db *bun.DB) error { + table := &struct { + bun.BaseModel `bun:"table:active_sessions"` + SessionID string `bun:"session_id,type:uuid,pk"` + SeenAt time.Time `bun:"seen_at,type:timestamptz,notnull"` + CreatedAt time.Time `bun:"created_at,type:timestamptz,notnull"` + }{} + + if _, err := db.NewCreateTable(). + Model(table). + IfNotExists(). + ForeignKey(`("session_id") REFERENCES sessions("id") ON DELETE CASCADE`). + Exec(ctx); err != nil { + log.WithError(err).Error("failed to apply migration 013") + + return err + } + + return nil +} + +func migration013Down(ctx context.Context, db *bun.DB) error { + _, err := db.ExecContext(ctx, `DROP TABLE IF EXISTS active_sessions;`) + + return err +} diff --git a/api/store/pg/migrations/014_create_session_events_table.go b/api/store/pg/migrations/014_create_session_events_table.go new file mode 100644 index 00000000000..ba657205d65 --- /dev/null +++ b/api/store/pg/migrations/014_create_session_events_table.go @@ -0,0 +1,106 @@ +package migrations + +import ( + "context" + "time" + + log "github.com/sirupsen/logrus" + "github.com/uptrace/bun" +) + +func init() { + migrations.MustRegister(migration014Up, migration014Down) +} + +func migration014Up(ctx context.Context, db *bun.DB) error { + _, err := db.ExecContext(ctx, ` + DROP TYPE IF EXISTS session_event_type; + CREATE TYPE session_event_type AS ENUM ( + 'pty-output', 'pty-req', 'window-change', 'exit-code', + 'exit-status', 'exit-signal', 'env', 'shell', 'exec', + 'subsystem', 'signal', 'tcpip-forward', 'auth-agent-req' + ); + `) + if err != nil { + return err + } + + table := &struct { + bun.BaseModel `bun:"table:session_events"` + ID string `bun:"id,type:uuid,pk"` + SessionID string `bun:"session_id,type:uuid,notnull"` + Type string `bun:"type,type:session_event_type,notnull"` + Seat int `bun:"seat,type:integer,notnull"` + Data string `bun:"data,type:jsonb,nullzero"` + CreatedAt time.Time `bun:"created_at,type:timestamptz,notnull,default:now()"` + }{} + + if _, err := db.NewCreateTable(). + Model(table). + IfNotExists(). + ForeignKey(`("session_id") REFERENCES sessions("id") ON DELETE CASCADE`). + Exec(ctx); err != nil { + log.WithError(err).Error("failed to apply migration 014") + + return err + } + + _, err = db.NewCreateIndex(). + Model((*struct { + bun.BaseModel `bun:"table:session_events"` + })(nil)). + Index("session_events_session_id_created_at_idx"). + Column("session_id", "created_at"). + Exec(ctx) + if err != nil { + log.WithError(err).Error("failed to create session_id_created_at index in migration 014") + + return err + } + + _, err = db.NewCreateIndex(). + Model((*struct { + bun.BaseModel `bun:"table:session_events"` + })(nil)). + Index("session_events_type_created_at_idx"). + Column("type", "created_at"). + Exec(ctx) + if err != nil { + log.WithError(err).Error("failed to create type_created_at index in migration 014") + + return err + } + + _, err = db.NewCreateIndex(). + Model((*struct { + bun.BaseModel `bun:"table:session_events"` + })(nil)). + Index("session_events_seat_idx"). + Column("seat"). + Exec(ctx) + if err != nil { + log.WithError(err).Error("failed to create seat index in migration 014") + + return err + } + + _, err = db.ExecContext(ctx, ` + CREATE INDEX session_events_data_gin_idx ON session_events USING GIN (data); + `) + if err != nil { + log.WithError(err).Error("failed to create data GIN index in migration 014") + + return err + } + + return nil +} + +func migration014Down(ctx context.Context, db *bun.DB) error { + _, err := db.ExecContext(ctx, ` + DROP TABLE IF EXISTS session_events; + DROP TYPE IF EXISTS session_event_type; + `) + + return err +} diff --git a/api/store/pg/migrations/migrations.go b/api/store/pg/migrations/migrations.go new file mode 100644 index 00000000000..9162d7f92c7 --- /dev/null +++ b/api/store/pg/migrations/migrations.go @@ -0,0 +1,15 @@ +package migrations + +import ( + "github.com/uptrace/bun/migrate" +) + +var migrations = migrate.NewMigrations() + +func FetchMigrations() (*migrate.Migrations, error) { + if err := migrations.DiscoverCaller(); err != nil { + return nil, err + } + + return migrations, nil +} diff --git a/api/store/pg/namespace.go b/api/store/pg/namespace.go new file mode 100644 index 00000000000..1064e36c608 --- /dev/null +++ b/api/store/pg/namespace.go @@ -0,0 +1,233 @@ +package pg + +import ( + "context" + + "github.com/shellhub-io/shellhub/api/store" + "github.com/shellhub-io/shellhub/api/store/pg/entity" + "github.com/shellhub-io/shellhub/pkg/clock" + "github.com/shellhub-io/shellhub/pkg/models" + "github.com/shellhub-io/shellhub/pkg/uuid" + "github.com/uptrace/bun" +) + +func (pg *Pg) NamespaceCreate(ctx context.Context, namespace *models.Namespace) (string, error) { + db := pg.getConnection(ctx) + + if namespace.TenantID == "" { + namespace.TenantID = uuid.Generate() + } + + namespace.CreatedAt = clock.Now() + + if _, err := db.NewInsert().Model(entity.NamespaceFromModel(namespace)).Exec(ctx); err != nil { + return "", fromSQLError(err) + } + + return namespace.TenantID, nil +} + +func (pg *Pg) NamespaceConflicts(ctx context.Context, target *models.NamespaceConflicts) ([]string, bool, error) { + db := pg.getConnection(ctx) + + namespaces := make([]map[string]any, 0) + if err := db.NewSelect().Model((*entity.Namespace)(nil)).Column("name").Where("name = ?", target.Name).Scan(ctx, &namespaces); err != nil { + return nil, false, fromSQLError(err) + } + + conflicts := make([]string, 0) + for _, user := range namespaces { + if user["name"] == target.Name { + conflicts = append(conflicts, "name") + } + } + + return conflicts, len(conflicts) > 0, nil +} + +func (pg *Pg) NamespaceList(ctx context.Context, opts ...store.QueryOption) ([]models.Namespace, int, error) { + db := pg.getConnection(ctx) + + entities := make([]entity.Namespace, 0) + query := db.NewSelect().Model(&entities).Relation("Memberships.User") + if err := applyOptions(ctx, query, opts...); err != nil { + return nil, 0, fromSQLError(err) + } + + count, err := query.ScanAndCount(ctx) + if err != nil { + return nil, 0, fromSQLError(err) + } + + namespaces := make([]models.Namespace, len(entities)) + for i, e := range entities { + namespaces[i] = *entity.NamespaceToModel(&e) + } + + return namespaces, count, nil +} + +func (pg *Pg) NamespaceResolve(ctx context.Context, resolver store.NamespaceResolver, val string) (*models.Namespace, error) { + db := pg.getConnection(ctx) + + column, err := NamespaceResolverToString(resolver) + if err != nil { + return nil, err + } + + ns := new(entity.Namespace) + query := db.NewSelect().Model(ns).Relation("Memberships.User").Where("? = ?", bun.Ident(column), val) + if err := query.Scan(ctx); err != nil { + return nil, fromSQLError(err) + } + + return entity.NamespaceToModel(ns), nil +} + +func (pg *Pg) NamespaceGetPreferred(ctx context.Context, userID string) (*models.Namespace, error) { + db := pg.getConnection(ctx) + + ns := new(entity.Namespace) + if err := db.NewSelect(). + Model(ns). + Relation("Memberships.User"). + Join("JOIN users"). + JoinOn("namespace.id = users.preferred_namespace_id OR namespace.id IN (SELECT namespace_id FROM memberships WHERE user_id = users.id)"). + Where("users.id = ?", userID). + OrderExpr("CASE WHEN namespace.id = users.preferred_namespace_id THEN 0 ELSE 1 END"). + Limit(1). + Scan(ctx); err != nil { + return nil, fromSQLError(err) + } + + return entity.NamespaceToModel(ns), nil +} + +func (pg *Pg) NamespaceUpdate(ctx context.Context, namespace *models.Namespace) error { + db := pg.getConnection(ctx) + + n := entity.NamespaceFromModel(namespace) + n.UpdatedAt = clock.Now() + + _, err := db.NewUpdate().Model(n).WherePK().Exec(ctx) + + return fromSQLError(err) +} + +func (pg *Pg) NamespaceIncrementDeviceCount(ctx context.Context, tenantID string, status models.DeviceStatus, count int64) error { + db := pg.getConnection(ctx) + + column := "devices" + string(status) + "count" + result, err := db.NewUpdate(). + Model((*entity.Namespace)(nil)). + Set("? = ? + ?", bun.Ident(column), bun.Ident(column), count). + Where("id = ?", tenantID). + Exec(ctx) + if err != nil { + return fromSQLError(err) + } + + rowsAffected, err := result.RowsAffected() + if err != nil { + return fromSQLError(err) + } + + if rowsAffected == 0 { + return store.ErrNoDocuments + } + + return nil +} + +func (pg *Pg) NamespaceDelete(ctx context.Context, namespace *models.Namespace) error { + deletedCount, err := pg.NamespaceDeleteMany(ctx, []string{namespace.TenantID}) + switch { + case err != nil: + return err + case deletedCount < 1: + return store.ErrNoDocuments + default: + return nil + } +} + +func (pg *Pg) NamespaceDeleteMany(ctx context.Context, tenantIDs []string) (int64, error) { + db := pg.getConnection(ctx) + fn := pg.namespaceDeleteManyFn(ctx, tenantIDs) + + if tx, ok := db.(bun.Tx); ok { + return fn(tx) + } else { // nolint:revive + tx, err := pg.driver.BeginTx(ctx, nil) + if err != nil { + return 0, fromSQLError(err) + } + + defer func() { + if p := recover(); p != nil { + _ = tx.Rollback() + panic(p) + } + }() + + count, err := fn(tx) + if err != nil { + _ = tx.Rollback() + + return 0, err + } + + if err := tx.Commit(); err != nil { + return 0, fromSQLError(err) + } + + return count, nil + } +} + +func (pg *Pg) namespaceDeleteManyFn(ctx context.Context, tenantIDs []string) func(tx bun.Tx) (int64, error) { + return func(tx bun.Tx) (int64, error) { + res, err := tx.NewDelete().Model((*entity.Namespace)(nil)).Where("id IN (?)", bun.In(tenantIDs)).Exec(ctx) + if err != nil { + return 0, fromSQLError(err) + } + + count, _ := res.RowsAffected() + + entities := []any{ + (*entity.Device)(nil), + // (*entity.Session)(nil), + // (*entity.FirewallRule)(nil), + (*entity.PublicKey)(nil), + // (*entity.RecordedSession)(nil), + (*entity.APIKey)(nil), + } + + for _, e := range entities { + if _, err := tx.NewDelete().Model(e).Where("namespace_id IN (?)", bun.In(tenantIDs)).Exec(ctx); err != nil { + return 0, fromSQLError(err) + } + } + + if _, err := tx.NewUpdate(). + Model((*entity.User)(nil)). + Set("preferred_namespace_id = NULL"). + Where("preferred_namespace_id IN (?)", bun.In(tenantIDs)). + Exec(ctx); err != nil { + return 0, fromSQLError(err) + } + + return count, nil + } +} + +func NamespaceResolverToString(resolver store.NamespaceResolver) (string, error) { + switch resolver { + case store.NamespaceTenantIDResolver: + return "id", nil + case store.NamespaceNameResolver: + return "name", nil + default: + return "", store.ErrResolverNotFound + } +} diff --git a/api/store/pg/namespace_test.go b/api/store/pg/namespace_test.go new file mode 100644 index 00000000000..e1e51799ddd --- /dev/null +++ b/api/store/pg/namespace_test.go @@ -0,0 +1 @@ +package pg_test diff --git a/api/store/pg/options/log.go b/api/store/pg/options/log.go new file mode 100644 index 00000000000..f2601b32b7c --- /dev/null +++ b/api/store/pg/options/log.go @@ -0,0 +1,39 @@ +package options + +import ( + "context" + "os" + + "github.com/oiime/logrusbun" + "github.com/sirupsen/logrus" + "github.com/uptrace/bun" +) + +func Log(level string, verbose bool) Option { + return func(ctx context.Context, db *bun.DB) error { + level, err := logrus.ParseLevel(level) + if err != nil { + return err + } + + logger := &logrus.Logger{ + Out: os.Stderr, + Formatter: new(logrus.TextFormatter), + Hooks: make(logrus.LevelHooks), + Level: level, + } + + db.AddQueryHook(logrusbun.NewQueryHook( + logrusbun.WithEnabled(true), + logrusbun.WithVerbose(verbose), + logrusbun.WithQueryHookOptions(logrusbun.QueryHookOptions{ + Logger: logger, + QueryLevel: logrus.DebugLevel, + ErrorLevel: logrus.ErrorLevel, + SlowLevel: logrus.WarnLevel, + }), + )) + + return nil + } +} diff --git a/api/store/pg/options/migrate.go b/api/store/pg/options/migrate.go new file mode 100644 index 00000000000..e6270a3cf61 --- /dev/null +++ b/api/store/pg/options/migrate.go @@ -0,0 +1,61 @@ +package options + +import ( + "context" + + "github.com/shellhub-io/shellhub/api/store/pg/migrations" + log "github.com/sirupsen/logrus" + "github.com/uptrace/bun" + "github.com/uptrace/bun/migrate" +) + +func Migrate() Option { + return func(ctx context.Context, db *bun.DB) error { + log.Info("starting database migration") + + migrations, err := migrations.FetchMigrations() + if err != nil { + log.WithError(err).Error("failed to fetch migrations") + + return err + } + + migrator := migrate.NewMigrator(db, migrations) + if err := migrator.Init(context.Background()); err != nil { + log.WithError(err).Error("failed to start migrations tables") + + return err + } + + if err := migrator.Lock(ctx); err != nil { + log.WithError(err).Error("failed to acquire migration lock") + + return err + } + + defer func() { + if err := migrator.Unlock(ctx); err != nil { + log.WithError(err).Error("failed to release migration lock") + } else { + log.Debug("migration lock released successfully") + } + }() + + group, err := migrator.Migrate(ctx) + if err != nil { + log.WithError(err).Error("migration failed") + + return err + } + + if group.IsZero() { + log.Info("no new migrations to run (database is up to date)") + + return nil + } + + log.Info("migration completed successfully") + + return nil + } +} diff --git a/api/store/pg/options/options.go b/api/store/pg/options/options.go new file mode 100644 index 00000000000..e166dd55e24 --- /dev/null +++ b/api/store/pg/options/options.go @@ -0,0 +1,9 @@ +package options + +import ( + "context" + + "github.com/uptrace/bun" +) + +type Option func(ctx context.Context, db *bun.DB) error diff --git a/api/store/pg/pg.go b/api/store/pg/pg.go new file mode 100644 index 00000000000..b620305c2c6 --- /dev/null +++ b/api/store/pg/pg.go @@ -0,0 +1,58 @@ +package pg + +import ( + "context" + "fmt" + + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgxpool" + "github.com/jackc/pgx/v5/stdlib" + "github.com/shellhub-io/shellhub/api/store" + "github.com/shellhub-io/shellhub/api/store/pg/entity" + "github.com/shellhub-io/shellhub/api/store/pg/options" + "github.com/uptrace/bun" + "github.com/uptrace/bun/dialect/pgdialect" +) + +type queryOptions struct{} + +type Pg struct { + driver *bun.DB + options *queryOptions +} + +func URI(host, port, user, password, db string) string { + return fmt.Sprintf("postgres://%s:%s@%s:%s/%s", user, password, host, port, db) +} + +func New(ctx context.Context, uri string, opts ...options.Option) (store.Store, error) { + config, err := pgxpool.ParseConfig(uri) + if err != nil { + return nil, err + } + + config.ConnConfig.DefaultQueryExecMode = pgx.QueryExecModeSimpleProtocol + + pool, err := pgxpool.NewWithConfig(ctx, config) + if err != nil { + return nil, err + } + + pg := &Pg{driver: bun.NewDB(stdlib.OpenDBFromPool(pool), pgdialect.New()), options: &queryOptions{}} + if err := pg.driver.Ping(); err != nil { + return nil, err + } + + pg.driver.RegisterModel(entity.Entities()...) // We need to register models so we can apply fixtures and relations later + for _, opt := range opts { + if err := opt(ctx, pg.driver); err != nil { + return nil, err + } + } + + return pg, nil +} + +func (pg *Pg) Driver() *bun.DB { + return pg.driver +} diff --git a/api/store/pg/pg_test.go b/api/store/pg/pg_test.go new file mode 100644 index 00000000000..c51569fd0ad --- /dev/null +++ b/api/store/pg/pg_test.go @@ -0,0 +1,82 @@ +package pg_test + +import ( + "context" + "os" + "testing" + + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgxpool" + "github.com/jackc/pgx/v5/stdlib" + "github.com/shellhub-io/shellhub/api/store" + "github.com/shellhub-io/shellhub/api/store/pg" + "github.com/shellhub-io/shellhub/api/store/pg/dbtest" + "github.com/shellhub-io/shellhub/api/store/pg/options" + log "github.com/sirupsen/logrus" + "github.com/uptrace/bun" + "github.com/uptrace/bun/dialect/pgdialect" +) + +var ( + srv = (*dbtest.Server)(nil) + s = (store.Store)(nil) + driver = (*bun.DB)(nil) +) + +func TestMain(m *testing.M) { + log.Info("Starting store tests") + + ctx := context.Background() + + srv = &dbtest.Server{} + + if err := srv.Up(ctx); err != nil { + log.WithError(err).Error("Failed to UP the postgres container") + os.Exit(1) + } + + c, err := srv.ConnectionString(ctx) + if err != nil { + log.WithError(err).Error("Failed to parse postgres connection string") + } + + log.Info("Connecting to ", c) + + s, err = pg.New(ctx, c, options.Migrate()) + if err != nil { + log.WithError(err).Error("Failed to create the postgres store") + os.Exit(1) + } + + driver, err = connectBun(ctx, c) + if err != nil { + log.WithError(err).Error("Failed to create a test driver") + os.Exit(1) + } + + code := m.Run() + + log.Info("Stopping store tests") + if err := srv.Down(ctx); err != nil { + log.WithError(err).Error("Failed to DOWN the postgres container") + os.Exit(1) + } + + os.Exit(code) +} + +func connectBun(ctx context.Context, uri string) (*bun.DB, error) { + config, err := pgxpool.ParseConfig(uri) + if err != nil { + return nil, err + } + + config.ConnConfig.DefaultQueryExecMode = pgx.QueryExecModeSimpleProtocol + + pool, err := pgxpool.NewWithConfig(ctx, config) + if err != nil { + return nil, err + } + + return bun.NewDB(stdlib.OpenDBFromPool(pool), pgdialect.New()), nil +} diff --git a/api/store/pg/private-key.go b/api/store/pg/private-key.go new file mode 100644 index 00000000000..f07c55a6999 --- /dev/null +++ b/api/store/pg/private-key.go @@ -0,0 +1,32 @@ +package pg + +import ( + "context" + + "github.com/shellhub-io/shellhub/api/store/pg/entity" + "github.com/shellhub-io/shellhub/pkg/clock" + "github.com/shellhub-io/shellhub/pkg/models" +) + +func (pg *Pg) PrivateKeyCreate(ctx context.Context, privateKey *models.PrivateKey) error { + db := pg.getConnection(ctx) + + privateKey.CreatedAt = clock.Now() + + if _, err := db.NewInsert().Model(entity.PrivateKeyFromModel(privateKey)).Exec(ctx); err != nil { + return fromSQLError(err) + } + + return nil +} + +func (pg *Pg) PrivateKeyGet(ctx context.Context, fingerprint string) (*models.PrivateKey, error) { + db := pg.getConnection(ctx) + + privateKey := new(entity.PrivateKey) + if err := db.NewSelect().Model(privateKey).Where("fingerprint = ?", fingerprint).Scan(ctx); err != nil { + return nil, fromSQLError(err) + } + + return entity.PrivateKeyToModel(privateKey), nil +} diff --git a/api/store/pg/private-key_test.go b/api/store/pg/private-key_test.go new file mode 100644 index 00000000000..e1e51799ddd --- /dev/null +++ b/api/store/pg/private-key_test.go @@ -0,0 +1 @@ +package pg_test diff --git a/api/store/pg/public-key.go b/api/store/pg/public-key.go new file mode 100644 index 00000000000..f12c2493e96 --- /dev/null +++ b/api/store/pg/public-key.go @@ -0,0 +1,98 @@ +package pg + +import ( + "context" + + "github.com/shellhub-io/shellhub/api/store" + "github.com/shellhub-io/shellhub/api/store/pg/entity" + "github.com/shellhub-io/shellhub/pkg/clock" + "github.com/shellhub-io/shellhub/pkg/models" + "github.com/shellhub-io/shellhub/pkg/uuid" + "github.com/uptrace/bun" +) + +func (pg *Pg) PublicKeyCreate(ctx context.Context, publicKey *models.PublicKey) (string, error) { + db := pg.getConnection(ctx) + + publicKey.CreatedAt = clock.Now() + e := entity.PublicKeyFromModel(publicKey) + e.ID = uuid.Generate() + + if _, err := db.NewInsert().Model(e).Exec(ctx); err != nil { + return "", fromSQLError(err) + } + + return e.ID, nil // TODO: ID no model +} + +func (pg *Pg) PublicKeyList(ctx context.Context, opts ...store.QueryOption) ([]models.PublicKey, int, error) { + db := pg.getConnection(ctx) + + entities := make([]entity.PublicKey, 0) + + query := db.NewSelect().Model(&entities) + if err := applyOptions(ctx, query, opts...); err != nil { + return nil, 0, fromSQLError(err) + } + + count, err := query.ScanAndCount(ctx) + if err != nil { + return nil, 0, fromSQLError(err) + } + + publicKeys := make([]models.PublicKey, len(entities)) + for i, e := range entities { + publicKeys[i] = *entity.PublicKeyToModel(&e) + } + + return publicKeys, count, nil +} + +func (pg *Pg) PublicKeyUpdate(ctx context.Context, publicKey *models.PublicKey) error { + db := pg.getConnection(ctx) + + a := entity.PublicKeyFromModel(publicKey) + a.UpdatedAt = clock.Now() + _, err := db.NewUpdate().Model(a).WherePK().Exec(ctx) + + return fromSQLError(err) +} + +func (pg *Pg) PublicKeyResolve(ctx context.Context, resolver store.PublicKeyResolver, value string, opts ...store.QueryOption) (*models.PublicKey, error) { + db := pg.getConnection(ctx) + + column, err := PublicKeyResolverToString(resolver) + if err != nil { + return nil, err + } + + a := new(entity.PublicKey) + query := db.NewSelect().Model(a).Where("? = ?", bun.Ident(column), value) + if err := applyOptions(ctx, query, opts...); err != nil { + return nil, fromSQLError(err) + } + + if err := query.Scan(ctx); err != nil { + return nil, fromSQLError(err) + } + + return entity.PublicKeyToModel(a), nil +} + +func (pg *Pg) PublicKeyDelete(ctx context.Context, publicKey *models.PublicKey) error { + db := pg.getConnection(ctx) + + a := entity.PublicKeyFromModel(publicKey) + _, err := db.NewDelete().Model(a).WherePK().Exec(ctx) + + return fromSQLError(err) +} + +func PublicKeyResolverToString(resolver store.PublicKeyResolver) (string, error) { + switch resolver { + case store.PublicKeyFingerprintResolver: + return "fingerprint", nil + default: + return "", store.ErrResolverNotFound + } +} diff --git a/api/store/pg/public-key_test.go b/api/store/pg/public-key_test.go new file mode 100644 index 00000000000..e1e51799ddd --- /dev/null +++ b/api/store/pg/public-key_test.go @@ -0,0 +1 @@ +package pg_test diff --git a/api/store/pg/query-options.go b/api/store/pg/query-options.go new file mode 100644 index 00000000000..414850db92d --- /dev/null +++ b/api/store/pg/query-options.go @@ -0,0 +1,157 @@ +package pg + +import ( + "context" + "errors" + "strings" + + "github.com/shellhub-io/shellhub/api/store" + "github.com/shellhub-io/shellhub/api/store/pg/internal" + "github.com/shellhub-io/shellhub/pkg/api/query" + "github.com/shellhub-io/shellhub/pkg/models" + "github.com/uptrace/bun" +) + +// ErrQueryNotFound is returned when the query context value is not found or has the wrong type +var ErrQueryNotFound = errors.New("query not found in context") + +func (pg *Pg) Options() store.QueryOptions { + return pg.options +} + +func (*queryOptions) Paginate(page *query.Paginator) store.QueryOption { + return func(ctx context.Context) error { + query, ok := ctx.Value("query").(*bun.SelectQuery) + if !ok { + return ErrQueryNotFound + } + + query = query.Offset(page.PerPage * (page.Page - 1)).Limit(page.PerPage) //nolint:staticcheck + + return nil + } +} + +func (*queryOptions) Sort(sorter *query.Sorter) store.QueryOption { + return func(ctx context.Context) error { + if sorter.By == "" { + return nil + } + + query, ok := ctx.Value("query").(*bun.SelectQuery) + if !ok { + return ErrQueryNotFound + } + + query = query.OrderExpr("? ?", bun.Ident(sorter.By), bun.Safe(strings.ToUpper(sorter.Order))) //nolint:staticcheck + + return nil + } +} + +func (*queryOptions) Match(filters *query.Filters) store.QueryOption { + return func(ctx context.Context) error { + if len(filters.Data) < 1 { + return nil + } + + bunQuery, ok := ctx.Value("query").(*bun.SelectQuery) + if !ok { + return ErrQueryNotFound + } + + var filterErr error + bunQuery = bunQuery.WhereGroup("", func(q *bun.SelectQuery) *bun.SelectQuery { //nolint:staticcheck + currentOperator := "OR" //nolint:staticcheck + firstCondition := true + + for _, filter := range filters.Data { + switch filter.Type { + case query.FilterTypeOperator: + param, ok := filter.Params.(*query.FilterOperator) + if !ok { + return nil + } + + op, valid := internal.ParseFilterOperator(param) + if !valid { + continue + } + + currentOperator = op + case query.FilterTypeProperty: + param, ok := filter.Params.(*query.FilterProperty) + if !ok { + return nil + } + + condition, args, valid, err := internal.ParseFilterProperty(param) + if err != nil || !valid { + filterErr = err + + continue + } + + switch { + case firstCondition: // The first condition always applies a WHERE + q = q.Where(condition, args...) + firstCondition = false + case currentOperator == "AND": + q = q.Where(condition, args...) + case currentOperator == "OR": + q = q.WhereOr(condition, args...) + } + default: + return nil + } + } + + return q + }) + + if filterErr != nil { + return filterErr + } + + return nil + } +} + +func (*queryOptions) WithMember(userID string) store.QueryOption { + return func(ctx context.Context) error { + query, ok := ctx.Value("query").(*bun.SelectQuery) + if !ok { + return ErrQueryNotFound + } + + query = query.Where("EXISTS (SELECT 1 FROM memberships WHERE memberships.namespace_id = namespace.id AND memberships.user_id = ?)", userID) //nolint:staticcheck + + return nil + } +} + +func (*queryOptions) InNamespace(namespaceID string) store.QueryOption { + return func(ctx context.Context) error { + query, ok := ctx.Value("query").(*bun.SelectQuery) + if !ok { + return ErrQueryNotFound + } + + query = query.Where("namespace_id = ?", namespaceID) //nolint:staticcheck + + return nil + } +} + +func (*queryOptions) WithDeviceStatus(status models.DeviceStatus) store.QueryOption { + return func(ctx context.Context) error { + query, ok := ctx.Value("query").(*bun.SelectQuery) + if !ok { + return ErrQueryNotFound + } + + query = query.Where("status = ?", string(status)) //nolint:staticcheck + + return nil + } +} diff --git a/api/store/pg/session.go b/api/store/pg/session.go new file mode 100644 index 00000000000..3fcad1c3a98 --- /dev/null +++ b/api/store/pg/session.go @@ -0,0 +1,301 @@ +package pg + +import ( + "context" + + "github.com/shellhub-io/shellhub/api/store" + "github.com/shellhub-io/shellhub/api/store/pg/entity" + "github.com/shellhub-io/shellhub/pkg/clock" + "github.com/shellhub-io/shellhub/pkg/models" + "github.com/shellhub-io/shellhub/pkg/uuid" +) + +func (pg *Pg) SessionList(ctx context.Context, opts ...store.QueryOption) ([]models.Session, int, error) { + db := pg.getConnection(ctx) + + entities := make([]entity.Session, 0) + query := db.NewSelect(). + Model(&entities). + Relation("Device"). + Relation("Device.Namespace") + + if err := applyOptions(ctx, query, opts...); err != nil { + return nil, 0, fromSQLError(err) + } + + count, err := query.Count(ctx) + if err != nil { + return nil, 0, fromSQLError(err) + } + + query = db.NewSelect(). + Model(&entities). + Relation("Device"). + Relation("Device.Namespace"). + ColumnExpr("sessions.*"). + ColumnExpr("CASE WHEN active_sessions.session_id IS NOT NULL THEN true ELSE false END as active"). + ColumnExpr("COALESCE(event_types.types, '') as event_types"). + ColumnExpr("COALESCE(event_seats.seats, '') as event_seats"). + Join("LEFT JOIN active_sessions ON sessions.id = active_sessions.session_id"). + Join(`LEFT JOIN ( + SELECT session_id, string_agg(DISTINCT type, ',') as types + FROM session_events + GROUP BY session_id + ) event_types ON sessions.id = event_types.session_id`). + Join(`LEFT JOIN ( + SELECT session_id, string_agg(DISTINCT seat::text, ',') as seats + FROM session_events + GROUP BY session_id + ) event_seats ON sessions.id = event_seats.session_id`) + + if err := applyOptions(ctx, query, opts...); err != nil { + return nil, 0, fromSQLError(err) + } + + if err := query.Scan(ctx); err != nil { + return nil, 0, fromSQLError(err) + } + + sessions := make([]models.Session, len(entities)) + for i, e := range entities { + sessions[i] = *entity.SessionToModel(&e) + } + + return sessions, count, nil +} + +func (pg *Pg) SessionResolve(ctx context.Context, resolver store.SessionResolver, value string, opts ...store.QueryOption) (*models.Session, error) { + db := pg.getConnection(ctx) + + var sessionID string + switch resolver { + case store.SessionUIDResolver: + sessionID = value + default: + return nil, store.ErrNoDocuments + } + + e := &entity.Session{} + query := db.NewSelect(). + Model(e). + Relation("Device"). + Relation("Device.Namespace"). + ColumnExpr("sessions.*"). + ColumnExpr("CASE WHEN active_sessions.session_id IS NOT NULL THEN true ELSE false END as active"). + ColumnExpr("COALESCE(event_types.types, '') as event_types"). + ColumnExpr("COALESCE(event_seats.seats, '') as event_seats"). + Join("LEFT JOIN active_sessions ON sessions.id = active_sessions.session_id"). + Join(`LEFT JOIN ( + SELECT session_id, string_agg(DISTINCT type, ',') as types + FROM session_events + WHERE session_id = ? + GROUP BY session_id + ) event_types ON sessions.id = event_types.session_id`, sessionID). + Join(`LEFT JOIN ( + SELECT session_id, string_agg(DISTINCT seat::text, ',') as seats + FROM session_events + WHERE session_id = ? + GROUP BY session_id + ) event_seats ON sessions.id = event_seats.session_id`, sessionID). + Where("sessions.id = ?", sessionID) + + if err := applyOptions(ctx, query, opts...); err != nil { + return nil, fromSQLError(err) + } + + if err := query.Scan(ctx); err != nil { + return nil, fromSQLError(err) + } + + return entity.SessionToModel(e), nil +} + +func (pg *Pg) SessionCreate(ctx context.Context, session models.Session) (string, error) { + db := pg.getConnection(ctx) + + session.StartedAt = clock.Now() + session.LastSeen = session.StartedAt + session.Recorded = false + + device, err := pg.DeviceResolve(ctx, store.DeviceUIDResolver, string(session.DeviceUID)) + if err != nil { + return "", fromSQLError(err) + } + + session.TenantID = device.TenantID + session.UID = uuid.Generate() + + e := entity.SessionFromModel(&session) + if _, err := db.NewInsert().Model(e).Exec(ctx); err != nil { + return "", fromSQLError(err) + } + + return e.ID, nil +} + +func (pg *Pg) SessionUpdate(ctx context.Context, session *models.Session) error { + db := pg.getConnection(ctx) + + e := entity.SessionFromModel(session) + result, err := db.NewUpdate().Model(e).Where("id = ?", e.ID).Exec(ctx) + if err != nil { + return fromSQLError(err) + } + + rowsAffected, err := result.RowsAffected() + if err != nil { + return fromSQLError(err) + } + + if rowsAffected < 1 { + return store.ErrNoDocuments + } + + return nil +} + +func (pg *Pg) ActiveSessionCreate(ctx context.Context, session *models.Session) error { + db := pg.getConnection(ctx) + + activeSession := &models.ActiveSession{UID: models.UID(session.UID), LastSeen: session.StartedAt, TenantID: session.TenantID} + e := entity.ActiveSessionFromModel(activeSession) + if _, err := db.NewInsert().Model(e).Exec(ctx); err != nil { + return fromSQLError(err) + } + + return nil +} + +func (pg *Pg) ActiveSessionResolve(ctx context.Context, resolver store.SessionResolver, value string) (*models.ActiveSession, error) { + db := pg.getConnection(ctx) + + var sessionID string + switch resolver { + case store.SessionUIDResolver: + sessionID = value + default: + return nil, store.ErrNoDocuments + } + + e := &entity.ActiveSession{} + if err := db.NewSelect().Model(e).Relation("Session").Relation("Session.Device").Where("session_id = ?", sessionID).Scan(ctx); err != nil { + return nil, fromSQLError(err) + } + + return entity.ActiveSessionToModel(e), nil +} + +func (pg *Pg) ActiveSessionUpdate(ctx context.Context, activeSession *models.ActiveSession) error { + db := pg.getConnection(ctx) + + e := entity.ActiveSessionFromModel(activeSession) + result, err := db.NewUpdate().Model(e).Where("session_id = ?", e.SessionID).Exec(ctx) + if err != nil { + return fromSQLError(err) + } + + rowsAffected, err := result.RowsAffected() + if err != nil { + return fromSQLError(err) + } + + if rowsAffected < 1 { + return store.ErrNoDocuments + } + + return nil +} + +func (pg *Pg) ActiveSessionDelete(ctx context.Context, uid models.UID) error { + db := pg.getConnection(ctx) + + if _, err := db.NewDelete().Model((*entity.ActiveSession)(nil)).Where("session_id = ?", string(uid)).Exec(ctx); err != nil { + return fromSQLError(err) + } + + return nil +} + +func (pg *Pg) SessionEventsCreate(ctx context.Context, event *models.SessionEvent) error { + db := pg.getConnection(ctx) + + e := entity.SessionEventFromModel(event) + e.ID = uuid.Generate() + + if _, err := db.NewInsert().Model(e).Exec(ctx); err != nil { + return fromSQLError(err) + } + + return nil +} + +func (pg *Pg) SessionEventsList(ctx context.Context, uid models.UID, seat int, event models.SessionEventType, opts ...store.QueryOption) ([]models.SessionEvent, int, error) { + db := pg.getConnection(ctx) + + entities := make([]entity.SessionEvent, 0) + query := db.NewSelect(). + Model(&entities). + Where("session_id = ?", string(uid)). + Where("seat = ?", seat). + Where("type = ?", string(event)). + Order("created_at ASC") + + if err := applyOptions(ctx, query, opts...); err != nil { + return nil, 0, fromSQLError(err) + } + + count, err := query.Count(ctx) + if err != nil { + return nil, 0, fromSQLError(err) + } + + if err := query.Scan(ctx); err != nil { + return nil, 0, fromSQLError(err) + } + + events := make([]models.SessionEvent, len(entities)) + for i, e := range entities { + events[i] = *entity.SessionEventToModel(&e) + } + + return events, count, nil +} + +func (pg *Pg) SessionEventsDelete(ctx context.Context, uid models.UID, seat int, event models.SessionEventType) error { + db := pg.getConnection(ctx) + + if _, err := db.NewDelete(). + Model((*entity.SessionEvent)(nil)). + Where("session_id = ?", string(uid)). + Where("seat = ?", seat). + Where("type = ?", string(event)). + Exec(ctx); err != nil { + return fromSQLError(err) + } + + return nil +} + +func (pg *Pg) SessionUpdateDeviceUID(ctx context.Context, oldUID models.UID, newUID models.UID) error { + db := pg.getConnection(ctx) + + result, err := db.NewUpdate(). + Model((*entity.Session)(nil)). + Set("device_id = ?", string(newUID)). + Where("device_id = ?", string(oldUID)). + Exec(ctx) + if err != nil { + return fromSQLError(err) + } + + rowsAffected, err := result.RowsAffected() + if err != nil { + return fromSQLError(err) + } + + if rowsAffected < 1 { + return store.ErrNoDocuments + } + + return nil +} diff --git a/api/store/pg/session_test.go b/api/store/pg/session_test.go new file mode 100644 index 00000000000..e1e51799ddd --- /dev/null +++ b/api/store/pg/session_test.go @@ -0,0 +1 @@ +package pg_test diff --git a/api/store/pg/stats.go b/api/store/pg/stats.go new file mode 100644 index 00000000000..16600cf9741 --- /dev/null +++ b/api/store/pg/stats.go @@ -0,0 +1,117 @@ +package pg + +import ( + "context" + "time" + + "github.com/shellhub-io/shellhub/api/store/pg/entity" + "github.com/shellhub-io/shellhub/pkg/models" + "github.com/uptrace/bun" +) + +func (pg *Pg) GetStats(ctx context.Context, tenantID string) (*models.Stats, error) { + db := pg.getConnection(ctx) + + onlineDevicesQuery := buildOnlineDevicesQuery(db, tenantID) + onlineDevices, err := onlineDevicesQuery.Count(ctx) + if err != nil { + return nil, fromSQLError(err) + } + + registeredDevicesQuery := buildRegisteredDevicesQuery(db, tenantID) + registeredDevices, err := registeredDevicesQuery.Count(ctx) + if err != nil { + return nil, fromSQLError(err) + } + + pendingDevicesQuery := buildPendingDevicesQuery(db, tenantID) + pendingDevices, err := pendingDevicesQuery.Count(ctx) + if err != nil { + return nil, fromSQLError(err) + } + + rejectedDevicesQuery := buildRejectedDevicesQuery(db, tenantID) + rejectedDevices, err := rejectedDevicesQuery.Count(ctx) + if err != nil { + return nil, fromSQLError(err) + } + + activeSessionsQuery := buildActiveSessionsQuery(db, tenantID) + activeSessions, err := activeSessionsQuery.Count(ctx) + if err != nil { + return nil, fromSQLError(err) + } + + stats := &models.Stats{ + RegisteredDevices: registeredDevices, + OnlineDevices: onlineDevices, + PendingDevices: pendingDevices, + RejectedDevices: rejectedDevices, + ActiveSessions: activeSessions, + } + + return stats, nil +} + +func buildOnlineDevicesQuery(db bun.IDB, tenantID string) *bun.SelectQuery { + query := db.NewSelect(). + Model((*entity.Device)(nil)). + Where("disconnected_at IS NULL"). + Where("seen_at > ?", time.Now().Add(-2*time.Minute)). + Where("status = ?", "accepted") + + if tenantID != "" { + query = query.Where("namespace_id = (SELECT id FROM namespaces WHERE id = ?)", tenantID) + } + + return query +} + +func buildRegisteredDevicesQuery(db bun.IDB, tenantID string) *bun.SelectQuery { + query := db.NewSelect(). + Model((*entity.Device)(nil)). + Where("status = ?", "accepted") + + if tenantID != "" { + query = query.Where("namespace_id = (SELECT id FROM namespaces WHERE id = ?)", tenantID) + } + + return query +} + +func buildPendingDevicesQuery(db bun.IDB, tenantID string) *bun.SelectQuery { + query := db.NewSelect(). + Model((*entity.Device)(nil)). + Where("status = ?", "pending") + + if tenantID != "" { + query = query.Where("namespace_id = (SELECT id FROM namespaces WHERE id = ?)", tenantID) + } + + return query +} + +func buildRejectedDevicesQuery(db bun.IDB, tenantID string) *bun.SelectQuery { + query := db.NewSelect(). + Model((*entity.Device)(nil)). + Where("status = ?", "rejected") + + if tenantID != "" { + query = query.Where("namespace_id = (SELECT id FROM namespaces WHERE id = ?)", tenantID) + } + + return query +} + +func buildActiveSessionsQuery(db bun.IDB, tenantID string) *bun.SelectQuery { + query := db.NewSelect(). + Model((*entity.ActiveSession)(nil)). + Join("JOIN sessions ON active_sessions.session_id = sessions.id"). + Join("JOIN devices ON sessions.device_id = devices.id") + + if tenantID != "" { + query = query.Where("devices.namespace_id = (SELECT id FROM namespaces WHERE id = ?)", tenantID) + } + + return query +} diff --git a/api/store/pg/stats_test.go b/api/store/pg/stats_test.go new file mode 100644 index 00000000000..e1e51799ddd --- /dev/null +++ b/api/store/pg/stats_test.go @@ -0,0 +1 @@ +package pg_test diff --git a/api/store/pg/system.go b/api/store/pg/system.go new file mode 100644 index 00000000000..181d1955533 --- /dev/null +++ b/api/store/pg/system.go @@ -0,0 +1,58 @@ +package pg + +import ( + "context" + "database/sql" + "errors" + + "github.com/shellhub-io/shellhub/api/store/pg/entity" + "github.com/shellhub-io/shellhub/pkg/models" + "github.com/shellhub-io/shellhub/pkg/uuid" +) + +func (pg *Pg) SystemGet(ctx context.Context) (*models.System, error) { + db := pg.getConnection(ctx) + + system := new(entity.System) + if err := db.NewSelect().Model(system).Limit(1).Scan(ctx); err != nil { + if errors.Is(err, sql.ErrNoRows) { + system := &models.System{ + Setup: false, + Authentication: &models.SystemAuthentication{ + Local: &models.SystemAuthenticationLocal{ + Enabled: true, + }, + SAML: &models.SystemAuthenticationSAML{ + Enabled: false, + Idp: &models.SystemIdpSAML{Binding: &models.SystemAuthenticationBinding{}}, + Sp: &models.SystemSpSAML{}, + }, + }, + } + + return system, nil + } + + return nil, err + } + + return entity.SystemToModel(system), nil +} + +func (pg *Pg) SystemSet(ctx context.Context, system *models.System) error { + systemEntity := entity.SystemFromModel(system) + if systemEntity.ID == "" { + systemEntity.ID = uuid.Generate() + } + + db := pg.getConnection(ctx) + exists, err := db.NewSelect().Model((*entity.System)(nil)).Where("id = ?", systemEntity.ID).Exists(ctx) + switch { + case err == nil && !exists: + _, err = pg.driver.NewInsert().Model(systemEntity).Exec(ctx) + case err == nil && exists: + _, err = pg.driver.NewUpdate().Model(systemEntity).Where("id = ?", systemEntity.ID).Exec(ctx) + } + + return err +} diff --git a/api/store/pg/system_test.go b/api/store/pg/system_test.go new file mode 100644 index 00000000000..e1e51799ddd --- /dev/null +++ b/api/store/pg/system_test.go @@ -0,0 +1 @@ +package pg_test diff --git a/api/store/pg/tag.go b/api/store/pg/tag.go new file mode 100644 index 00000000000..12273ffd95f --- /dev/null +++ b/api/store/pg/tag.go @@ -0,0 +1,207 @@ +package pg + +import ( + "context" + + "github.com/shellhub-io/shellhub/api/store" + "github.com/shellhub-io/shellhub/api/store/pg/entity" + "github.com/shellhub-io/shellhub/pkg/clock" + "github.com/shellhub-io/shellhub/pkg/models" + "github.com/shellhub-io/shellhub/pkg/uuid" + "github.com/uptrace/bun" +) + +func (pg *Pg) TagCreate(ctx context.Context, tag *models.Tag) (string, error) { + db := pg.getConnection(ctx) + + tag.CreatedAt = clock.Now() + tag.UpdatedAt = clock.Now() + + e := entity.TagFromModel(tag) + e.ID = uuid.Generate() + + var result entity.Tag + err := db.NewInsert(). + Model(e). + On("CONFLICT (namespace_id, name) DO UPDATE SET updated_at = EXCLUDED.updated_at"). + Returning("*"). + Scan(ctx, &result) + if err != nil { + return "", fromSQLError(err) + } + + return result.ID, nil +} + +func (pg *Pg) TagConflicts(ctx context.Context, tenantID string, target *models.TagConflicts) ([]string, bool, error) { + db := pg.getConnection(ctx) + + query := db.NewSelect().Model((*entity.Tag)(nil)).Column("name").Where("namespace_id = ?", tenantID) + if target.Name != "" { + query = query.Where("name = ?", target.Name) + } + + tags := make([]map[string]any, 0) + if err := query.Scan(ctx, &tags); err != nil { + return nil, false, fromSQLError(err) + } + + conflicts := make([]string, 0) + for _, tag := range tags { + if tag["name"] == target.Name { + conflicts = append(conflicts, "name") + } + } + + return conflicts, len(conflicts) > 0, nil +} + +func (pg *Pg) TagList(ctx context.Context, opts ...store.QueryOption) ([]models.Tag, int, error) { + db := pg.getConnection(ctx) + + entities := make([]entity.Tag, 0) + query := db.NewSelect().Model(&entities).Column("tag.*") + if err := applyOptions(ctx, query, opts...); err != nil { + return nil, 0, fromSQLError(err) + } + + count, err := query.ScanAndCount(ctx) + if err != nil { + return nil, 0, fromSQLError(err) + } + + tags := make([]models.Tag, len(entities)) + for i, e := range entities { + tags[i] = *entity.TagToModel(&e) + } + + return tags, count, nil +} + +func (pg *Pg) TagResolve(ctx context.Context, resolver store.TagResolver, value string, opts ...store.QueryOption) (*models.Tag, error) { + db := pg.getConnection(ctx) + + column, err := TagResolverToString(resolver) + if err != nil { + return nil, err + } + + tag := new(entity.Tag) + query := db.NewSelect().Model(tag).Column("tag.*").Relation("Namespace").Where("tag.? = ?", bun.Ident(column), value) + + if err := applyOptions(ctx, query, opts...); err != nil { + return nil, fromSQLError(err) + } + + if err := query.Scan(ctx); err != nil { + return nil, fromSQLError(err) + } + + return entity.TagToModel(tag), nil +} + +func (pg *Pg) TagUpdate(ctx context.Context, tag *models.Tag) error { + db := pg.getConnection(ctx) + + t := entity.TagFromModel(tag) + t.UpdatedAt = clock.Now() + + r, err := db.NewUpdate().Model(t).WherePK().Exec(ctx) + if err != nil { + return fromSQLError(err) + } + + if count, err := r.RowsAffected(); err != nil || count == 0 { + return store.ErrNoDocuments + } + + return nil +} + +func (pg *Pg) TagPushToTarget(ctx context.Context, id string, target store.TagTarget, targetID string) error { + db := pg.getConnection(ctx) + + tag := new(entity.Tag) + if err := db.NewSelect().Model(tag).Where("id = ?", id).Scan(ctx); err != nil { + return fromSQLError(err) + } + + switch target { + case store.TagTargetDevice: + deviceTag := entity.NewDeviceTag(tag.ID, targetID) + deviceTag.CreatedAt = clock.Now() + + if _, err := db.NewInsert().Model(deviceTag).On("CONFLICT (device_id, tag_id) DO NOTHING").Exec(ctx); err != nil { + return fromSQLError(err) + } + case store.TagTargetPublicKey: + publickeyTag := entity.NewPublicKeyTag(tag.ID, targetID) + publickeyTag.CreatedAt = clock.Now() + + if _, err := db.NewInsert().Model(publickeyTag).On("CONFLICT (public_key_id, tag_id) DO NOTHING").Exec(ctx); err != nil { + return fromSQLError(err) + } + } + + return nil +} + +func (pg *Pg) TagPullFromTarget(ctx context.Context, id string, target store.TagTarget, targetIDs ...string) error { + db := pg.getConnection(ctx) + + tag := new(entity.Tag) + if err := db.NewSelect().Model(tag).Where("id = ?", id).Scan(ctx); err != nil { + return fromSQLError(err) + } + + switch target { + case store.TagTargetDevice: + query := db.NewDelete().Model((*entity.DeviceTag)(nil)).Where("tag_id = ?", id) + if len(targetIDs) > 0 { + query = query.Where("device_id IN (?)", bun.In(targetIDs)) + } + + if _, err := query.Exec(ctx); err != nil { + return fromSQLError(err) + } + case store.TagTargetPublicKey: + query := db.NewDelete().Model((*entity.PublicKeyTag)(nil)).Where("tag_id = ?", id) + if len(targetIDs) > 0 { + query = query.Where("public_key_id IN (?)", bun.In(targetIDs)) + } + + if _, err := query.Exec(ctx); err != nil { + return fromSQLError(err) + } + } + + return nil +} + +func (pg *Pg) TagDelete(ctx context.Context, tag *models.Tag) error { + db := pg.getConnection(ctx) + + t := entity.TagFromModel(tag) + + r, err := db.NewDelete().Model(t).WherePK().Exec(ctx) + if err != nil { + return fromSQLError(err) + } + + if rowsAffected, err := r.RowsAffected(); err != nil || rowsAffected == 0 { + return store.ErrNoDocuments + } + + return fromSQLError(err) +} + +func TagResolverToString(resolver store.TagResolver) (string, error) { + switch resolver { + case store.TagIDResolver: + return "id", nil + case store.TagNameResolver: + return "name", nil + default: + return "", store.ErrResolverNotFound + } +} diff --git a/api/store/pg/tag_test.go b/api/store/pg/tag_test.go new file mode 100644 index 00000000000..e1e51799ddd --- /dev/null +++ b/api/store/pg/tag_test.go @@ -0,0 +1 @@ +package pg_test diff --git a/api/store/pg/transaction.go b/api/store/pg/transaction.go new file mode 100644 index 00000000000..11c58f23aaa --- /dev/null +++ b/api/store/pg/transaction.go @@ -0,0 +1,70 @@ +package pg + +import ( + "context" + + "github.com/shellhub-io/shellhub/api/store" + log "github.com/sirupsen/logrus" + "github.com/uptrace/bun" +) + +type txKeyType struct{} + +var txKey = txKeyType{} + +// getConnection returns the appropriate executor for the given context. +// If the context contains an active transaction, it returns the transaction handle. +// Otherwise, it returns the base database driver. +// +// This allows store methods to be written agnostic of whether they are +// running inside a transaction or not. +func (pg *Pg) getConnection(ctx context.Context) bun.IDB { + if tx, ok := ctx.Value(txKey).(bun.Tx); ok { + log.Debug("reusing existing SQL transaction from context") + + return tx + } + + return pg.driver +} + +// Example: +// +// err := store.WithTransaction(ctx, func(ctx context.Context) error { +// db := store.getExecutor(ctx) +// if _, err := db.NewDelete().Model(&Device{}).Where("id = ?", id).Exec(ctx); err != nil { +// return err +// } +// +// return store.NamespaceIncrementDeviceCount(ctx, tenantID, models.DeviceStatusRemoved, -1) +// }) +// +// TODO: The transaction handle is stored in the context for simplicity. +// This hides the dependency and makes it less explicit. +// Consider refactoring to expose a typed TxStore in the future for better clarity. +func (pg *Pg) WithTransaction(ctx context.Context, fn store.TransactionCb) (err error) { + tx, err := pg.driver.BeginTx(ctx, nil) + if err != nil { + return err + } + + defer func() { + if p := recover(); p != nil { + if rollbackErr := tx.Rollback(); rollbackErr != nil { + log.WithError(rollbackErr).Error("transaction rollback failed after panic") + } + + panic(p) + } + }() + + if err := fn(context.WithValue(ctx, txKey, tx)); err != nil { + if rollbackErr := tx.Rollback(); rollbackErr != nil { + log.WithError(rollbackErr).Error("transaction rollback failed after error") + } + + return err + } + + return tx.Commit() +} diff --git a/api/store/pg/tunnel.go b/api/store/pg/tunnel.go new file mode 100644 index 00000000000..5715f91eb8e --- /dev/null +++ b/api/store/pg/tunnel.go @@ -0,0 +1,7 @@ +package pg + +import "context" + +func (pg *Pg) TunnelUpdateDeviceUID(ctx context.Context, tenantID, oldUID, newUID string) error { + return nil +} diff --git a/api/store/pg/tunnel_test.go b/api/store/pg/tunnel_test.go new file mode 100644 index 00000000000..e1e51799ddd --- /dev/null +++ b/api/store/pg/tunnel_test.go @@ -0,0 +1 @@ +package pg_test diff --git a/api/store/pg/user.go b/api/store/pg/user.go new file mode 100644 index 00000000000..f32f2d06108 --- /dev/null +++ b/api/store/pg/user.go @@ -0,0 +1,164 @@ +package pg + +import ( + "context" + + "github.com/shellhub-io/shellhub/api/store" + "github.com/shellhub-io/shellhub/api/store/pg/entity" + "github.com/shellhub-io/shellhub/pkg/clock" + "github.com/shellhub-io/shellhub/pkg/models" + "github.com/shellhub-io/shellhub/pkg/uuid" + "github.com/uptrace/bun" +) + +func (pg *Pg) UserCreate(ctx context.Context, user *models.User) (string, error) { + db := pg.getConnection(ctx) + + user.ID = uuid.Generate() + user.CreatedAt = clock.Now() + + if _, err := db.NewInsert().Model(entity.UserFromModel(user)).Exec(ctx); err != nil { + return "", err + } + + return user.ID, nil +} + +func (pg *Pg) UserCreateInvited(ctx context.Context, email string) (string, error) { + return "", nil +} + +func (pg *Pg) UserConflicts(ctx context.Context, target *models.UserConflicts) ([]string, bool, error) { + db := pg.getConnection(ctx) + + users := make([]map[string]any, 0) + if err := db.NewSelect().Model((*entity.User)(nil)).Column("email").Where("email = ?", target.Email).Scan(ctx, &users); err != nil { + return nil, false, err + } + + conflicts := make([]string, 0) + for _, user := range users { + if user["email"] == target.Email { + conflicts = append(conflicts, "email") + } + } + + return conflicts, len(conflicts) > 0, nil +} + +func (pg *Pg) UserList(ctx context.Context, opts ...store.QueryOption) ([]models.User, int, error) { + db := pg.getConnection(ctx) + + entities := make([]entity.User, 0) + query := db.NewSelect().Model(&entities) + if err := applyOptions(ctx, query, opts...); err != nil { + return nil, 0, fromSQLError(err) + } + + count, err := query.ScanAndCount(ctx) + if err != nil { + return nil, 0, fromSQLError(err) + } + + users := make([]models.User, len(entities)) + for i, e := range entities { + users[i] = *entity.UserToModel(&e) + } + + return users, count, nil +} + +func (pg *Pg) UserResolve(ctx context.Context, resolver store.UserResolver, val string, opts ...store.QueryOption) (*models.User, error) { + db := pg.getConnection(ctx) + + column, err := UserResolverToString(resolver) + if err != nil { + return nil, err + } + + u := new(entity.User) + if err := db.NewSelect().Model(u).Where("? = ?", bun.Ident(column), val).Scan(ctx); err != nil { + return nil, fromSQLError(err) + } + + return entity.UserToModel(u), nil +} + +func (pg *Pg) UserGetInfo(ctx context.Context, userID string) (userInfo *models.UserInfo, err error) { + db := pg.getConnection(ctx) + + var namespaceEntities []entity.Namespace + err = db.NewSelect(). + Model(&namespaceEntities). + Relation("Memberships.User"). + Where("owner_id = ? OR EXISTS (SELECT 1 FROM memberships WHERE memberships.namespace_id = namespace.id AND memberships.user_id = ?)", userID, userID). + Scan(ctx) + if err != nil { + return nil, fromSQLError(err) + } + + userInfo = &models.UserInfo{ + OwnedNamespaces: make([]models.Namespace, 0), + AssociatedNamespaces: make([]models.Namespace, 0), + } + + for _, nsEntity := range namespaceEntities { + ns := entity.NamespaceToModel(&nsEntity) + + if nsEntity.OwnerID == userID { + userInfo.OwnedNamespaces = append(userInfo.OwnedNamespaces, *ns) + } else { + userInfo.AssociatedNamespaces = append(userInfo.AssociatedNamespaces, *ns) + } + } + + return userInfo, nil +} + +func (pg *Pg) UserUpdate(ctx context.Context, user *models.User) error { + db := pg.getConnection(ctx) + + u := entity.UserFromModel(user) + u.UpdatedAt = clock.Now() + + r, err := db.NewUpdate().Model(u).WherePK().Exec(ctx) + if err != nil { + return fromSQLError(err) + } + + if rowsAffected, err := r.RowsAffected(); err != nil || rowsAffected == 0 { + return store.ErrNoDocuments + } + + return fromSQLError(err) +} + +func (pg *Pg) UserDelete(ctx context.Context, user *models.User) error { + db := pg.getConnection(ctx) + + u := entity.UserFromModel(user) + + r, err := db.NewDelete().Model(u).WherePK().Exec(ctx) + if err != nil { + return fromSQLError(err) + } + + if rowsAffected, err := r.RowsAffected(); err != nil || rowsAffected == 0 { + return store.ErrNoDocuments + } + + return fromSQLError(err) +} + +func UserResolverToString(resolver store.UserResolver) (string, error) { + switch resolver { + case store.UserIDResolver: + return "id", nil + case store.UserEmailResolver: + return "email", nil + case store.UserUsernameResolver: + return "username", nil + default: + return "", store.ErrResolverNotFound + } +} diff --git a/api/store/pg/user_test.go b/api/store/pg/user_test.go new file mode 100644 index 00000000000..e1e51799ddd --- /dev/null +++ b/api/store/pg/user_test.go @@ -0,0 +1 @@ +package pg_test diff --git a/api/store/pg/utils.go b/api/store/pg/utils.go new file mode 100644 index 00000000000..6d778366096 --- /dev/null +++ b/api/store/pg/utils.go @@ -0,0 +1,32 @@ +package pg + +import ( + "context" + "database/sql" + "io" + + "github.com/shellhub-io/shellhub/api/store" + "github.com/uptrace/bun" +) + +func fromSQLError(err error) error { + switch err { + case nil: + return nil + case sql.ErrNoRows, io.EOF: + return store.ErrNoDocuments + default: + return err + } +} + +func applyOptions(ctx context.Context, query *bun.SelectQuery, opts ...store.QueryOption) error { + ctxWithQuery := context.WithValue(ctx, "query", query) + for _, opt := range opts { + if err := opt(ctxWithQuery); err != nil { + return fromSQLError(err) + } + } + + return nil +} diff --git a/bin/utils b/bin/utils index 124bfa4962e..c29180141c7 100644 --- a/bin/utils +++ b/bin/utils @@ -71,6 +71,18 @@ COMPOSE_FILE="docker-compose.yml" [ "$SHELLHUB_ENTERPRISE" = "true" ] && [ "$SHELLHUB_ENV" != "development" ] && COMPOSE_FILE="${COMPOSE_FILE}:docker-compose.enterprise.yml" [ -f docker-compose.override.yml ] && COMPOSE_FILE="${COMPOSE_FILE}:docker-compose.override.yml" +SHELLHUB_DATABASE=${SHELLHUB_DATABASE:-mongo} +case "$SHELLHUB_DATABASE" in + mongo|postgres) + COMPOSE_FILE="${COMPOSE_FILE}:docker-compose.${SHELLHUB_DATABASE}.yml" + ;; + *) + echo "⚠️ WARNING: Unknown SHELLHUB_DATABASE '$SHELLHUB_DATABASE'. Defaulting to mongodb." + SHELLHUB_DATABASE="mongo" + COMPOSE_FILE="${COMPOSE_FILE}:docker-compose.mongo.yml" + ;; +esac + [ -f "$EXTRA_COMPOSE_FILE" ] && COMPOSE_FILE="${COMPOSE_FILE}:${EXTRA_COMPOSE_FILE}" export COMPOSE_FILE diff --git a/cli/go.mod b/cli/go.mod index 6045ed9488a..eb2559c5ef4 100644 --- a/cli/go.mod +++ b/cli/go.mod @@ -31,6 +31,11 @@ require ( github.com/hashicorp/errwrap v1.1.0 // indirect github.com/hashicorp/go-multierror v1.1.1 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect + github.com/jackc/pgpassfile v1.0.0 // indirect + github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect + github.com/jackc/pgx/v5 v5.7.6 // indirect + github.com/jackc/puddle/v2 v2.2.2 // indirect + github.com/jinzhu/inflection v1.0.0 // indirect github.com/klauspost/compress v1.18.0 // indirect github.com/klauspost/pgzip v1.2.5 // indirect github.com/labstack/echo/v4 v4.13.4 // indirect @@ -41,21 +46,26 @@ require ( github.com/mholt/archiver/v4 v4.0.0-alpha.8 // indirect github.com/montanaflynn/stats v0.7.1 // indirect github.com/nwaples/rardecode/v2 v2.2.0 // indirect + github.com/oiime/logrusbun v0.1.2-0.20241011112815-4df3a0fb0e11 // indirect github.com/oschwald/geoip2-golang v1.8.0 // indirect github.com/oschwald/maxminddb-golang v1.10.0 // indirect github.com/pierrec/lz4/v4 v4.1.17 // indirect github.com/pkg/errors v0.9.1 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/puzpuzpuz/xsync/v3 v3.5.1 // indirect github.com/sethvargo/go-envconfig v0.9.0 // indirect github.com/spf13/pflag v1.0.9 // indirect github.com/square/mongo-lock v0.0.0-20230808145049-cfcf499f6bf0 // indirect github.com/stretchr/objx v0.5.2 // indirect github.com/therootcompany/xz v1.0.1 // indirect + github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc // indirect github.com/ulikunitz/xz v0.5.14 // indirect + github.com/uptrace/bun v1.2.15 // indirect + github.com/uptrace/bun/dialect/pgdialect v1.2.15 // indirect github.com/valyala/bytebufferpool v1.0.0 // indirect github.com/valyala/fasttemplate v1.2.2 // indirect github.com/vmihailenco/go-tinylfu v0.2.2 // indirect - github.com/vmihailenco/msgpack/v5 v5.3.5 // indirect + github.com/vmihailenco/msgpack/v5 v5.4.1 // indirect github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect github.com/xakep666/mongo-migrate v0.3.2 // indirect github.com/xdg-go/pbkdf2 v1.0.0 // indirect @@ -67,7 +77,7 @@ require ( golang.org/x/crypto v0.43.0 // indirect golang.org/x/net v0.45.0 // indirect golang.org/x/sync v0.17.0 // indirect - golang.org/x/sys v0.37.0 // indirect + golang.org/x/sys v0.38.0 // indirect golang.org/x/text v0.30.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/cli/go.sum b/cli/go.sum index 51fc4de7d74..5326d87cfb9 100644 --- a/cli/go.sum +++ b/cli/go.sum @@ -167,6 +167,16 @@ github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpO github.com/ianlancetaylor/demangle v0.0.0-20181102032728-5e5cf60278f6/go.mod h1:aSSvb/t6k1mPoxDqO4vJh6VOCGPwU4O0C2/Eqndh1Sc= github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= +github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= +github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= +github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo= +github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= +github.com/jackc/pgx/v5 v5.7.6 h1:rWQc5FwZSPX58r1OQmkuaNicxdmExaEz5A2DO2hUuTk= +github.com/jackc/pgx/v5 v5.7.6/go.mod h1:aruU7o91Tc2q2cFp5h4uP3f6ztExVpyVv88Xl/8Vl8M= +github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo= +github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= +github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= +github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= github.com/jstemmer/go-junit-report v0.0.0-20190106144839-af01ea7f8024/go.mod h1:6v2b51hI/fHJwM22ozAgKL4VKDeJcHhJFhtBdhmNjmU= github.com/jstemmer/go-junit-report v0.9.1/go.mod h1:Brl9GWCQeLvo8nXZwPNNblvFj/XSXhF0NWZEnDohbsk= github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= @@ -224,6 +234,8 @@ github.com/nwaples/rardecode/v2 v2.2.0/go.mod h1:7uz379lSxPe6j9nvzxUZ+n7mnJNgjsR github.com/nxadm/tail v1.4.4/go.mod h1:kenIhsEOeOJmVchQTgglprH7qJGnHDVpk1VPCcaMI8A= github.com/nxadm/tail v1.4.8 h1:nPr65rt6Y5JFSKQO7qToXr7pePgD6Gwiw05lkbyAQTE= github.com/nxadm/tail v1.4.8/go.mod h1:+ncqLTQzXmGhMZNUePPaPqPvBxHAIsmXswZKocGu+AU= +github.com/oiime/logrusbun v0.1.2-0.20241011112815-4df3a0fb0e11 h1:rAqW9sGcM0VsfBwgeBzHk0yebrRwfeSJFy9Egqi0fmM= +github.com/oiime/logrusbun v0.1.2-0.20241011112815-4df3a0fb0e11/go.mod h1:HH9akx9teKgQPX41TYpLLRNxaL8q9R+ltzABnwUHfBM= github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= github.com/onsi/ginkgo v1.12.1/go.mod h1:zj2OWP4+oCPe1qIXoGWkgMRwljMUYCdkwsT2108oapk= github.com/onsi/ginkgo v1.16.4/go.mod h1:dX+/inL/fNMqNlz0e9LfyB9TswhZpCVdJM/Z6Vvnwo0= @@ -251,6 +263,8 @@ github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZN github.com/power-devops/perfstat v0.0.0-20240221224432-82ca36839d55 h1:o4JXh1EVt9k/+g42oCprj/FisM4qX9L3sZB3upGN2ZU= github.com/power-devops/perfstat v0.0.0-20240221224432-82ca36839d55/go.mod h1:OmDBASR4679mdNQnz2pUhc2G8CO2JrUAVFDRBDP/hJE= github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= +github.com/puzpuzpuz/xsync/v3 v3.5.1 h1:GJYJZwO6IdxN/IKbneznS6yPkVC+c3zyY/j19c++5Fg= +github.com/puzpuzpuz/xsync/v3 v3.5.1/go.mod h1:VjzYrABPabuM4KyBh1Ftq6u8nhwY5tBPKP9jpmh0nnA= github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= github.com/rogpeppe/go-internal v1.6.1/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTEfhy4qGm1nDQc= github.com/rogpeppe/go-internal v1.8.0 h1:FCbCCtXNOY3UtUuHUYaghJg4y7Fd14rXifAYUAtL9R8= @@ -264,6 +278,7 @@ github.com/shellhub-io/mongotest v0.0.0-20230928124937-e33b07010742 h1:sIFW1zdZv github.com/shellhub-io/mongotest v0.0.0-20230928124937-e33b07010742/go.mod h1:6J6yfW5oIvAZ6VjxmV9KyFZyPFVM3B4V3Epbb+1c0oo= github.com/shirou/gopsutil/v4 v4.25.6 h1:kLysI2JsKorfaFPcYmcJqbzROzsBWEOAtw6A7dIfqXs= github.com/shirou/gopsutil/v4 v4.25.6/go.mod h1:PfybzyydfZcN+JMMjkF6Zb8Mq1A/VcogFFg7hj50W9c= +github.com/sirupsen/logrus v1.8.1/go.mod h1:yWOB1SBYBC5VeMP7gHvWumXLIWorT60ONWic61uBYv0= github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ= github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= github.com/spf13/cobra v1.10.1 h1:lJeBwCfmrnXthfAupyUTzJ/J4Nc1RsHC/mSRU2dll/s= @@ -277,6 +292,8 @@ github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSS github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY= github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= +github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= @@ -290,6 +307,8 @@ github.com/testcontainers/testcontainers-go v0.40.0 h1:pSdJYLOVgLE8YdUY2FHQ1Fxu+ github.com/testcontainers/testcontainers-go v0.40.0/go.mod h1:FSXV5KQtX2HAMlm7U3APNyLkkap35zNLxukw9oBi/MY= github.com/testcontainers/testcontainers-go/modules/mongodb v0.40.0 h1:z/1qHeliTLDKNaJ7uOHOx1FjwghbcbYfga4dTFkF0hU= github.com/testcontainers/testcontainers-go/modules/mongodb v0.40.0/go.mod h1:GaunAWwMXLtsMKG3xn2HYIBDbKddGArfcGsF2Aog81E= +github.com/testcontainers/testcontainers-go/modules/postgres v0.40.0 h1:s2bIayFXlbDFexo96y+htn7FzuhpXLYJNnIuglNKqOk= +github.com/testcontainers/testcontainers-go/modules/postgres v0.40.0/go.mod h1:h+u/2KoREGTnTl9UwrQ/g+XhasAT8E6dClclAADeXoQ= github.com/therootcompany/xz v1.0.1 h1:CmOtsn1CbtmyYiusbfmhmkpAAETj0wBIH6kCYaX+xzw= github.com/therootcompany/xz v1.0.1/go.mod h1:3K3UH1yCKgBneZYhuQUvJ9HPD19UEXEI0BWbMn8qNMY= github.com/tidwall/pretty v1.0.0/go.mod h1:XNkn88O1ChpSDQmQeStsy+sBenx6DDtFZJxhVysOjyk= @@ -301,9 +320,16 @@ github.com/tkuchiki/go-timezone v0.2.2 h1:MdHR65KwgVTwWFQrota4SKzc4L5EfuH5SdZZGt github.com/tkuchiki/go-timezone v0.2.2/go.mod h1:oFweWxYl35C/s7HMVZXiA19Jr9Y0qJHMaG/J2TES4LY= github.com/tkuchiki/parsetime v0.3.0 h1:cvblFQlPeAPJL8g6MgIGCHnnmHSZvluuY+hexoZCNqc= github.com/tkuchiki/parsetime v0.3.0/go.mod h1:OJkQmIrf5Ao7R+WYIdITPOfDVj8LmnHGCfQ8DTs3LCA= +github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc h1:9lRDQMhESg+zvGYmW5DyG0UqvY96Bu5QYsTLvCHdrgo= +github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc/go.mod h1:bciPuU6GHm1iF1pBvUfxfsH0Wmnc2VbpgvbI9ZWuIRs= github.com/ulikunitz/xz v0.5.8/go.mod h1:nbz6k7qbPmH4IRqmfOplQw/tblSgqTqBwxkY0oWt/14= github.com/ulikunitz/xz v0.5.14 h1:uv/0Bq533iFdnMHZdRBTOlaNMdb1+ZxXIlHDZHIHcvg= github.com/ulikunitz/xz v0.5.14/go.mod h1:nbz6k7qbPmH4IRqmfOplQw/tblSgqTqBwxkY0oWt/14= +github.com/uptrace/bun v0.3.9/go.mod h1:aL6D9vPw8DXaTQTwGrEPtUderBYXx7ShUmPfnxnqscw= +github.com/uptrace/bun v1.2.15 h1:Ut68XRBLDgp9qG9QBMa9ELWaZOmzHNdczHQdrOZbEFE= +github.com/uptrace/bun v1.2.15/go.mod h1:Eghz7NonZMiTX/Z6oKYytJ0oaMEJ/eq3kEV4vSqG038= +github.com/uptrace/bun/dialect/pgdialect v1.2.15 h1:er+/3giAIqpfrXJw+KP9B7ujyQIi5XkPnFmgjAVL6bA= +github.com/uptrace/bun/dialect/pgdialect v1.2.15/go.mod h1:QSiz6Qpy9wlGFsfpf7UMSL6mXAL1jDJhFwuOVacCnOQ= github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw= github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= github.com/valyala/fasttemplate v1.2.2 h1:lxLXG0uE3Qnshl9QyaK6XJxMXlQZELvChBOCmQD0Loo= @@ -311,8 +337,8 @@ github.com/valyala/fasttemplate v1.2.2/go.mod h1:KHLXt3tVN2HBp8eijSv/kGJopbvo7S+ github.com/vmihailenco/go-tinylfu v0.2.2 h1:H1eiG6HM36iniK6+21n9LLpzx1G9R3DJa2UjUjbynsI= github.com/vmihailenco/go-tinylfu v0.2.2/go.mod h1:CutYi2Q9puTxfcolkliPq4npPuofg9N9t8JVrjzwa3Q= github.com/vmihailenco/msgpack/v5 v5.3.4/go.mod h1:7xyJ9e+0+9SaZT0Wt1RGleJXzli6Q/V5KbhBonMG9jc= -github.com/vmihailenco/msgpack/v5 v5.3.5 h1:5gO0H1iULLWGhs2H5tbAHIZTV8/cYafcFOr9znI5mJU= -github.com/vmihailenco/msgpack/v5 v5.3.5/go.mod h1:7xyJ9e+0+9SaZT0Wt1RGleJXzli6Q/V5KbhBonMG9jc= +github.com/vmihailenco/msgpack/v5 v5.4.1 h1:cQriyiUvjTwOHg8QZaPihLWeRAAVoCpE00IUPn0Bjt8= +github.com/vmihailenco/msgpack/v5 v5.4.1/go.mod h1:GaZTsDaehaPpQVyxrf5mtQlH+pc21PIudVV/E3rRQok= github.com/vmihailenco/tagparser/v2 v2.0.0 h1:y09buUbR+b5aycVFQs/g70pqKVZNBmxwAhO7/IwNM9g= github.com/vmihailenco/tagparser/v2 v2.0.0/go.mod h1:Wri+At7QHww0WTrCBeu4J6bNtoV6mEfg5OIWRZA9qds= github.com/xakep666/mongo-migrate v0.3.2 h1:qmDtIGiMRIwMvc84fOlsDoP+08S6NWLJDPqa4wPfQ1U= @@ -451,12 +477,13 @@ golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20210112080510-489259a85091/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20210616094352-59db8d763f22/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.37.0 h1:fdNQudmxPjkdUTPnLn5mdQv7Zwvbvpaxqs831goi9kQ= -golang.org/x/sys v0.37.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= +golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc= +golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= diff --git a/cli/main.go b/cli/main.go index 075dbea77c2..de8508aeffd 100644 --- a/cli/main.go +++ b/cli/main.go @@ -3,7 +3,10 @@ package main import ( "context" + "github.com/shellhub-io/shellhub/api/store" "github.com/shellhub-io/shellhub/api/store/mongo" + "github.com/shellhub-io/shellhub/api/store/pg" + pgoptions "github.com/shellhub-io/shellhub/api/store/pg/options" "github.com/shellhub-io/shellhub/cli/cmd" "github.com/shellhub-io/shellhub/cli/services" "github.com/shellhub-io/shellhub/pkg/cache" @@ -14,7 +17,21 @@ import ( ) type config struct { + Database string `env:"DATABASE,default=mongo"` + MongoURI string `env:"MONGO_URI,default=mongodb://mongo:27017/main"` + + // PostgresHost specifies the host for PostgreSQL. + PostgresHost string `env:"POSTGRES_HOST,default=postgres"` + // PostgresPort specifies the port for PostgreSQL. + PostgresPort string `env:"POSTGRES_PORT,default=5432"` + // PostgresUsername specifies the username for authenticate PostgreSQL. + PostgresUsername string `env:"POSTGRES_USERNAME,default=admin"` + // PostgresUser specifies the password for authenticate PostgreSQL. + PostgresPassword string `env:"POSTGRES_PASSWORD,default=admin"` + // PostgresDatabase especifica o nome do banco de dados PostgreSQL a ser utilizado. + PostgresDatabase string `env:"POSTGRES_DATABASE,default=main"` + RedisURI string `env:"REDIS_URI,default=redis://redis:6379"` } @@ -41,11 +58,19 @@ func main() { log.Trace("Connecting to MongoDB") - store, err := mongo.NewStore(ctx, cfg.MongoURI, cache) + var store store.Store + switch cfg.Database { + case "mongo": + store, err = mongo.NewStore(ctx, cfg.MongoURI, cache) + case "postgres": + uri := pg.URI(cfg.PostgresHost, cfg.PostgresPort, cfg.PostgresUsername, cfg.PostgresPassword, cfg.PostgresDatabase) + store, err = pg.New(ctx, uri, pgoptions.Log("INFO", true)) // TODO: Log envs + default: + log.WithField("database", cfg.Database).Fatal("invalid database") + } + if err != nil { - log. - WithError(err). - Fatal("failed to create the store") + log.WithError(err).Fatal("failed to create the store") } service := services.NewService(store) diff --git a/docker-compose.mongo.test.yml b/docker-compose.mongo.test.yml new file mode 100644 index 00000000000..c57fbceb0cf --- /dev/null +++ b/docker-compose.mongo.test.yml @@ -0,0 +1,15 @@ +services: + mongo: + image: mongo:4.4.29 + restart: unless-stopped + healthcheck: + test: 'test $$(echo "rs.initiate({ _id: ''rs'', members: [ { _id: 0, host: ''mongo:27017'' } ] }).ok || rs.status().ok" | mongo --quiet) -eq 1' + interval: 30s + start_period: 10s + command: ["--replSet", "rs", "--bind_ip_all"] + networks: + - shellhub + api: + depends_on: + mongo: + condition: service_healthy diff --git a/docker-compose.mongo.yml b/docker-compose.mongo.yml new file mode 100644 index 00000000000..c57fbceb0cf --- /dev/null +++ b/docker-compose.mongo.yml @@ -0,0 +1,15 @@ +services: + mongo: + image: mongo:4.4.29 + restart: unless-stopped + healthcheck: + test: 'test $$(echo "rs.initiate({ _id: ''rs'', members: [ { _id: 0, host: ''mongo:27017'' } ] }).ok || rs.status().ok" | mongo --quiet) -eq 1' + interval: 30s + start_period: 10s + command: ["--replSet", "rs", "--bind_ip_all"] + networks: + - shellhub + api: + depends_on: + mongo: + condition: service_healthy diff --git a/docker-compose.postgres.test.yml b/docker-compose.postgres.test.yml new file mode 100644 index 00000000000..8ce73e17bd2 --- /dev/null +++ b/docker-compose.postgres.test.yml @@ -0,0 +1,24 @@ +services: + postgres: + image: postgres:18.0 + command: postgres -c io_method=io_uring + security_opt: + # Disable seccomp to allow io_uring syscalls (io_uring_setup, io_uring_enter, io_uring_register) + # which are blocked by Docker's default seccomp profile for security reasons. + - seccomp=unconfined + healthcheck: + start_period: 90s + interval: 5s + timeout: 5s + retries: 5 + test: ["CMD-SHELL", "pg_isready -U ${SHELLHUB_POSTGRES_USERNAME} -d ${SHELLHUB_POSTGRES_DATABASE}"] + environment: + - POSTGRES_USER=${SHELLHUB_POSTGRES_USERNAME} + - POSTGRES_PASSWORD=${SHELLHUB_POSTGRES_PASSWORD} + - POSTGRES_DB=${SHELLHUB_POSTGRES_DATABASE} + networks: + - shellhub + api: + depends_on: + postgres: + condition: service_healthy diff --git a/docker-compose.postgres.yml b/docker-compose.postgres.yml new file mode 100644 index 00000000000..8ce73e17bd2 --- /dev/null +++ b/docker-compose.postgres.yml @@ -0,0 +1,24 @@ +services: + postgres: + image: postgres:18.0 + command: postgres -c io_method=io_uring + security_opt: + # Disable seccomp to allow io_uring syscalls (io_uring_setup, io_uring_enter, io_uring_register) + # which are blocked by Docker's default seccomp profile for security reasons. + - seccomp=unconfined + healthcheck: + start_period: 90s + interval: 5s + timeout: 5s + retries: 5 + test: ["CMD-SHELL", "pg_isready -U ${SHELLHUB_POSTGRES_USERNAME} -d ${SHELLHUB_POSTGRES_DATABASE}"] + environment: + - POSTGRES_USER=${SHELLHUB_POSTGRES_USERNAME} + - POSTGRES_PASSWORD=${SHELLHUB_POSTGRES_PASSWORD} + - POSTGRES_DB=${SHELLHUB_POSTGRES_DATABASE} + networks: + - shellhub + api: + depends_on: + postgres: + condition: service_healthy diff --git a/docker-compose.test.yml b/docker-compose.test.yml index e2efe1fbd28..3d08de33770 100644 --- a/docker-compose.test.yml +++ b/docker-compose.test.yml @@ -34,9 +34,3 @@ services: start_period: 10s retries: 20 ports: [] - mongo: - healthcheck: - interval: 5s - start_period: 10s - retries: 20 - ports: [] diff --git a/docker-compose.yml b/docker-compose.yml index 062441c6996..e29ae15f428 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -33,10 +33,16 @@ services: restart: unless-stopped environment: - SHELLHUB_VERSION=${SHELLHUB_VERSION} + - DATABASE=${SHELLHUB_DATABASE} - PRIVATE_KEY=/run/secrets/api_private_key - PUBLIC_KEY=/run/secrets/api_public_key - SHELLHUB_ENTERPRISE=${SHELLHUB_ENTERPRISE} - SHELLHUB_CLOUD=${SHELLHUB_CLOUD} + - POSTGRES_HOST=${SHELLHUB_POSTGRES_HOST} + - POSTGRES_PORT=${SHELLHUB_POSTGRES_PORT} + - POSTGRES_USERNAME=${SHELLHUB_POSTGRES_USERNAME} + - POSTGRES_PASSWORD=${SHELLHUB_POSTGRES_PASSWORD} + - POSTGRES_DATABASE=${SHELLHUB_POSTGRES_DATABASE} - MAXMIND_MIRROR=${SHELLHUB_MAXMIND_MIRROR} - MAXMIND_LICENSE=${SHELLHUB_MAXMIND_LICENSE} - TELEMETRY=${SHELLHUB_TELEMETRY:-} @@ -60,10 +66,8 @@ services: - SHELLHUB_INTERNAL_HTTP_CLIENT_API_BASE_URL=${SHELLHUB_INTERNAL_HTTP_CLIENT_API_BASE_URL} - SHELLHUB_INTERNAL_HTTP_CLIENT_ENTERPRISE_BASE_URL=${SHELLHUB_INTERNAL_HTTP_CLIENT_ENTERPRISE_BASE_URL} depends_on: - - mongo - redis links: - - mongo - redis secrets: - api_private_key @@ -141,18 +145,14 @@ services: stop_signal: SIGKILL command: /bin/sleep infinity environment: + - DATABASE=${SHELLHUB_DATABASE} - SHELLHUB_LOG_LEVEL=${SHELLHUB_LOG_LEVEL} - SHELLHUB_LOG_FORMAT=${SHELLHUB_LOG_FORMAT} - networks: - - shellhub - mongo: - image: mongo:4.4.29 - restart: unless-stopped - healthcheck: - test: 'test $$(echo "rs.initiate({ _id: ''rs'', members: [ { _id: 0, host: ''mongo:27017'' } ] }).ok || rs.status().ok" | mongo --quiet) -eq 1' - interval: 30s - start_period: 10s - command: ["--replSet", "rs", "--bind_ip_all"] + - POSTGRES_HOST=${SHELLHUB_POSTGRES_HOST} + - POSTGRES_PORT=${SHELLHUB_POSTGRES_PORT} + - POSTGRES_USERNAME=${SHELLHUB_POSTGRES_USERNAME} + - POSTGRES_PASSWORD=${SHELLHUB_POSTGRES_PASSWORD} + - POSTGRES_DATABASE=${SHELLHUB_POSTGRES_DATABASE} networks: - shellhub redis: diff --git a/tests/environment/configurator.go b/tests/environment/configurator.go index d179817c770..ff4ece11a17 100644 --- a/tests/environment/configurator.go +++ b/tests/environment/configurator.go @@ -96,10 +96,15 @@ func (dcc *DockerComposeConfigurator) Up(ctx context.Context) *DockerCompose { down: nil, } - tcDc, err := compose.NewDockerComposeWith( - compose.WithStackFiles("../docker-compose.yml", "../docker-compose.test.yml"), - compose.WithLogger(log.New(io.Discard, "", log.LstdFlags)), - ) + dockerFiles := []string{"../docker-compose.yml", "../docker-compose.test.yml"} + switch dc.envs["SHELLHUB_DATABASE"] { + case "postgres": + dockerFiles = append(dockerFiles, "../docker-compose.postgres.test.yml") + default: + dockerFiles = append(dockerFiles, "../docker-compose.mongo.test.yml") + } + + tcDc, err := compose.NewDockerComposeWith(compose.WithStackFiles(dockerFiles...), compose.WithLogger(log.New(io.Discard, "", log.LstdFlags))) if !assert.NoError(dcc.t, err) { assert.FailNow(dcc.t, err.Error()) } diff --git a/tests/ssh_test.go b/tests/ssh_test.go index 072b349047b..4e2ae60403a 100644 --- a/tests/ssh_test.go +++ b/tests/ssh_test.go @@ -1324,87 +1324,88 @@ func TestSSH(t *testing.T) { ctx := context.Background() - compose := environment.New(t).Up(ctx) - t.Cleanup(func() { - compose.Down() - }) + databases := []string{"mongo", "postgres"} + for _, db := range databases { + compose := environment.New(t).WithEnv("SHELLHUB_DATABASE", db).Up(ctx) + compose.NewUser(t, ShellHubUsername, ShellHubEmail, ShellHubPassword) + compose.NewNamespace(t, ShellHubUsername, ShellHubNamespaceName, ShellHubNamespace) - compose.NewUser(t, ShellHubUsername, ShellHubEmail, ShellHubPassword) - compose.NewNamespace(t, ShellHubUsername, ShellHubNamespaceName, ShellHubNamespace) - - auth := models.UserAuthResponse{} - - require.EventuallyWithT(t, func(tt *assert.CollectT) { - resp, err := compose.R(ctx). - SetBody(map[string]string{ - "username": ShellHubUsername, - "password": ShellHubPassword, - }). - SetResult(&auth). - Post("/api/login") - assert.Equal(tt, 200, resp.StatusCode()) - assert.NoError(tt, err) - }, 30*time.Second, 1*time.Second) - - compose.JWT(auth.Token) - - for _, tc := range tests { - test := tc - t.Run(test.name, func(tt *testing.T) { - agent, err := NewAgentContainer( - ctx, - compose.Env("SHELLHUB_HTTP_PORT"), - test.options..., - ) - require.NoError(tt, err) - - agent.Stop(ctx, nil) - - err = agent.Start(ctx) - require.NoError(tt, err) - - tt.Cleanup(func() { - agent.Stop(context.Background(), nil) - }) - - t.Cleanup(func() { - agent.Terminate(context.Background()) - }) + auth := models.UserAuthResponse{} - devices := []models.Device{} + require.EventuallyWithT(t, func(tt *assert.CollectT) { + resp, err := compose.R(ctx). + SetBody(map[string]string{ + "username": ShellHubUsername, + "password": ShellHubPassword, + }). + SetResult(&auth). + Post("/api/login") + assert.Equal(tt, 200, resp.StatusCode()) + assert.NoError(tt, err) + }, 30*time.Second, 1*time.Second) + + compose.JWT(auth.Token) + + for _, tc := range tests { + test := tc + t.Run(db+" "+test.name, func(tt *testing.T) { + agent, err := NewAgentContainer( + ctx, + compose.Env("SHELLHUB_HTTP_PORT"), + test.options..., + ) + require.NoError(tt, err) + + agent.Stop(ctx, nil) + + err = agent.Start(ctx) + require.NoError(tt, err) + + tt.Cleanup(func() { + agent.Stop(context.Background(), nil) + }) - require.EventuallyWithT(tt, func(tt *assert.CollectT) { - resp, err := compose.R(ctx).SetResult(&devices). - Get("/api/devices?status=pending") - assert.Equal(tt, 200, resp.StatusCode()) - assert.NoError(tt, err) + t.Cleanup(func() { + agent.Terminate(context.Background()) + }) - assert.Len(tt, devices, 1) - }, 30*time.Second, 1*time.Second) + devices := []models.Device{} - resp, err := compose.R(ctx). - Patch(fmt.Sprintf("/api/devices/%s/accept", devices[0].UID)) - require.Equal(tt, 200, resp.StatusCode()) - require.NoError(tt, err) + require.EventuallyWithT(tt, func(tt *assert.CollectT) { + resp, err := compose.R(ctx).SetResult(&devices). + Get("/api/devices?status=pending") + assert.Equal(tt, 200, resp.StatusCode()) + assert.NoError(tt, err) - device := models.Device{} + assert.Len(tt, devices, 1) + }, 30*time.Second, 1*time.Second) - require.EventuallyWithT(tt, func(tt *assert.CollectT) { resp, err := compose.R(ctx). - SetResult(&device). - Get(fmt.Sprintf("/api/devices/%s", devices[0].UID)) - assert.Equal(tt, 200, resp.StatusCode()) - assert.NoError(tt, err) + Patch(fmt.Sprintf("/api/devices/%s/accept", devices[0].UID)) + require.Equal(tt, 200, resp.StatusCode()) + require.NoError(tt, err) + + device := models.Device{} + + require.EventuallyWithT(tt, func(tt *assert.CollectT) { + resp, err := compose.R(ctx). + SetResult(&device). + Get(fmt.Sprintf("/api/devices/%s", devices[0].UID)) + assert.Equal(tt, 200, resp.StatusCode()) + assert.NoError(tt, err) - assert.True(tt, device.Online) - }, 30*time.Second, 1*time.Second) + assert.True(tt, device.Online) + }, 30*time.Second, 1*time.Second) - // -- + // -- - test.run(tt, &Environment{ - services: compose, - agent: agent, - }, &device) - }) + test.run(tt, &Environment{ + services: compose, + agent: agent, + }, &device) + }) + } + + compose.Down() } }