@@ -11,9 +11,13 @@ import type {
11
11
StickyOffsets ,
12
12
} from '../interface' ;
13
13
import HeaderRow from './HeaderRow' ;
14
+ import cls from 'classnames' ;
15
+ import { TableProps } from '..' ;
14
16
15
17
function parseHeaderRows < RecordType > (
16
18
rootColumns : ColumnsType < RecordType > ,
19
+ classNames : TableProps [ 'classNames' ] [ 'header' ] ,
20
+ styles : TableProps [ 'styles' ] [ 'header' ] ,
17
21
) : CellType < RecordType > [ ] [ ] {
18
22
const rows : CellType < RecordType > [ ] [ ] = [ ] ;
19
23
@@ -29,7 +33,8 @@ function parseHeaderRows<RecordType>(
29
33
const colSpans : number [ ] = columns . filter ( Boolean ) . map ( column => {
30
34
const cell : CellType < RecordType > = {
31
35
key : column . key ,
32
- className : column . className || '' ,
36
+ className : cls ( column . className , classNames . cell ) || '' ,
37
+ style : styles . cell ,
33
38
children : column . title ,
34
39
column,
35
40
colStart : currentColIndex ,
@@ -97,18 +102,33 @@ const Header = <RecordType extends any>(props: HeaderProps<RecordType>) => {
97
102
98
103
const { stickyOffsets, columns, flattenColumns, onHeaderRow } = props ;
99
104
100
- const { prefixCls, getComponent } = useContext ( TableContext , [ 'prefixCls' , 'getComponent' ] ) ;
101
- const rows = React . useMemo < CellType < RecordType > [ ] [ ] > ( ( ) => parseHeaderRows ( columns ) , [ columns ] ) ;
105
+ const { prefixCls, getComponent, classNames, styles } = useContext ( TableContext , [
106
+ 'prefixCls' ,
107
+ 'getComponent' ,
108
+ 'classNames' ,
109
+ 'styles' ,
110
+ ] ) ;
111
+ const { header : headerCls = { } } = classNames || { } ;
112
+ const { header : headerStyles = { } } = styles || { } ;
113
+ const rows = React . useMemo < CellType < RecordType > [ ] [ ] > (
114
+ ( ) => parseHeaderRows ( columns , headerCls , headerStyles ) ,
115
+ [ columns , headerCls , headerStyles ] ,
116
+ ) ;
102
117
103
118
const WrapperComponent = getComponent ( [ 'header' , 'wrapper' ] , 'thead' ) ;
104
119
const trComponent = getComponent ( [ 'header' , 'row' ] , 'tr' ) ;
105
120
const thComponent = getComponent ( [ 'header' , 'cell' ] , 'th' ) ;
106
121
107
122
return (
108
- < WrapperComponent className = { `${ prefixCls } -thead` } >
123
+ < WrapperComponent
124
+ className = { cls ( `${ prefixCls } -thead` , headerCls . wrapper ) }
125
+ style = { headerStyles . wrapper }
126
+ >
109
127
{ rows . map ( ( row , rowIndex ) => {
110
128
const rowNode = (
111
129
< HeaderRow
130
+ classNames = { headerCls }
131
+ styles = { headerStyles }
112
132
key = { rowIndex }
113
133
flattenColumns = { flattenColumns }
114
134
cells = { row }
0 commit comments