@@ -17,6 +17,15 @@ namespace hlsl {
17
17
18
18
using TokenKind = RootSignatureToken::Kind;
19
19
20
+ static const TokenKind RootElementKeywords[] = {
21
+ TokenKind::kw_RootFlags,
22
+ TokenKind::kw_CBV,
23
+ TokenKind::kw_UAV,
24
+ TokenKind::kw_SRV,
25
+ TokenKind::kw_DescriptorTable,
26
+ TokenKind::kw_StaticSampler,
27
+ };
28
+
20
29
RootSignatureParser::RootSignatureParser (
21
30
llvm::dxbc::RootSignatureVersion Version,
22
31
SmallVector<RootSignatureElement> &Elements, StringLiteral *Signature,
@@ -27,51 +36,76 @@ RootSignatureParser::RootSignatureParser(
27
36
bool RootSignatureParser::parse () {
28
37
// Iterate as many RootSignatureElements as possible, until we hit the
29
38
// end of the stream
39
+ bool HadError = false ;
30
40
while (!peekExpectedToken (TokenKind::end_of_stream)) {
31
41
if (tryConsumeExpectedToken (TokenKind::kw_RootFlags)) {
32
42
SourceLocation ElementLoc = getTokenLocation (CurToken);
33
43
auto Flags = parseRootFlags ();
34
- if (!Flags.has_value ())
35
- return true ;
44
+ if (!Flags.has_value ()) {
45
+ HadError = true ;
46
+ skipUntilExpectedToken (RootElementKeywords);
47
+ continue ;
48
+ }
49
+
36
50
Elements.emplace_back (ElementLoc, *Flags);
37
51
} else if (tryConsumeExpectedToken (TokenKind::kw_RootConstants)) {
38
52
SourceLocation ElementLoc = getTokenLocation (CurToken);
39
53
auto Constants = parseRootConstants ();
40
- if (!Constants.has_value ())
41
- return true ;
54
+ if (!Constants.has_value ()) {
55
+ HadError = true ;
56
+ skipUntilExpectedToken (RootElementKeywords);
57
+ continue ;
58
+ }
42
59
Elements.emplace_back (ElementLoc, *Constants);
43
60
} else if (tryConsumeExpectedToken (TokenKind::kw_DescriptorTable)) {
44
61
SourceLocation ElementLoc = getTokenLocation (CurToken);
45
62
auto Table = parseDescriptorTable ();
46
- if (!Table.has_value ())
47
- return true ;
63
+ if (!Table.has_value ()) {
64
+ HadError = true ;
65
+ // We are within a DescriptorTable, we will do our best to recover
66
+ // by skipping until we encounter the expected closing ')'.
67
+ skipUntilClosedParens ();
68
+ consumeNextToken ();
69
+ skipUntilExpectedToken (RootElementKeywords);
70
+ continue ;
71
+ }
48
72
Elements.emplace_back (ElementLoc, *Table);
49
73
} else if (tryConsumeExpectedToken (
50
74
{TokenKind::kw_CBV, TokenKind::kw_SRV, TokenKind::kw_UAV})) {
51
75
SourceLocation ElementLoc = getTokenLocation (CurToken);
52
76
auto Descriptor = parseRootDescriptor ();
53
- if (!Descriptor.has_value ())
54
- return true ;
77
+ if (!Descriptor.has_value ()) {
78
+ HadError = true ;
79
+ skipUntilExpectedToken (RootElementKeywords);
80
+ continue ;
81
+ }
55
82
Elements.emplace_back (ElementLoc, *Descriptor);
56
83
} else if (tryConsumeExpectedToken (TokenKind::kw_StaticSampler)) {
57
84
SourceLocation ElementLoc = getTokenLocation (CurToken);
58
85
auto Sampler = parseStaticSampler ();
59
- if (!Sampler.has_value ())
60
- return true ;
86
+ if (!Sampler.has_value ()) {
87
+ HadError = true ;
88
+ skipUntilExpectedToken (RootElementKeywords);
89
+ continue ;
90
+ }
61
91
Elements.emplace_back (ElementLoc, *Sampler);
62
92
} else {
93
+ HadError = true ;
63
94
consumeNextToken (); // let diagnostic be at the start of invalid token
64
95
reportDiag (diag::err_hlsl_invalid_token)
65
96
<< /* parameter=*/ 0 << /* param of*/ TokenKind::kw_RootSignature;
66
- return true ;
97
+ skipUntilExpectedToken (RootElementKeywords);
98
+ continue ;
67
99
}
68
100
69
- // ',' denotes another element, otherwise, expected to be at end of stream
70
- if (! tryConsumeExpectedToken (TokenKind::pu_comma))
101
+ if (! tryConsumeExpectedToken (TokenKind::pu_comma)) {
102
+ // ',' denotes another element, otherwise, expected to be at end of stream
71
103
break ;
104
+ }
72
105
}
73
106
74
- return consumeExpectedToken (TokenKind::end_of_stream,
107
+ return HadError ||
108
+ consumeExpectedToken (TokenKind::end_of_stream,
75
109
diag::err_expected_either, TokenKind::pu_comma);
76
110
}
77
111
@@ -262,8 +296,13 @@ std::optional<DescriptorTable> RootSignatureParser::parseDescriptorTable() {
262
296
// DescriptorTableClause - CBV, SRV, UAV, or Sampler
263
297
SourceLocation ElementLoc = getTokenLocation (CurToken);
264
298
auto Clause = parseDescriptorTableClause ();
265
- if (!Clause.has_value ())
299
+ if (!Clause.has_value ()) {
300
+ // We are within a DescriptorTableClause, we will do our best to recover
301
+ // by skipping until we encounter the expected closing ')'
302
+ skipUntilExpectedToken (TokenKind::pu_r_paren);
303
+ consumeNextToken ();
266
304
return std::nullopt;
305
+ }
267
306
Elements.emplace_back (ElementLoc, *Clause);
268
307
Table.NumClauses ++;
269
308
} else if (tryConsumeExpectedToken (TokenKind::kw_visibility)) {
@@ -1371,6 +1410,40 @@ bool RootSignatureParser::tryConsumeExpectedToken(
1371
1410
return true ;
1372
1411
}
1373
1412
1413
+ bool RootSignatureParser::skipUntilExpectedToken (TokenKind Expected) {
1414
+ return skipUntilExpectedToken (ArrayRef{Expected});
1415
+ }
1416
+
1417
+ bool RootSignatureParser::skipUntilExpectedToken (
1418
+ ArrayRef<TokenKind> AnyExpected) {
1419
+
1420
+ while (!peekExpectedToken (AnyExpected)) {
1421
+ if (peekExpectedToken (TokenKind::end_of_stream))
1422
+ return false ;
1423
+ consumeNextToken ();
1424
+ }
1425
+
1426
+ return true ;
1427
+ }
1428
+
1429
+ bool RootSignatureParser::skipUntilClosedParens (uint32_t NumParens) {
1430
+ TokenKind ParenKinds[] = {
1431
+ TokenKind::pu_l_paren,
1432
+ TokenKind::pu_r_paren,
1433
+ };
1434
+ while (skipUntilExpectedToken (ParenKinds)) {
1435
+ consumeNextToken ();
1436
+ if (CurToken.TokKind == TokenKind::pu_r_paren)
1437
+ NumParens--;
1438
+ else
1439
+ NumParens++;
1440
+ if (NumParens == 0 )
1441
+ return true ;
1442
+ }
1443
+
1444
+ return false ;
1445
+ }
1446
+
1374
1447
SourceLocation RootSignatureParser::getTokenLocation (RootSignatureToken Tok) {
1375
1448
return Signature->getLocationOfByte (Tok.LocOffset , PP.getSourceManager (),
1376
1449
PP.getLangOpts (), PP.getTargetInfo ());
0 commit comments