Skip to content

Commit 4d70592

Browse files
committed
'0407'
1 parent 60992cd commit 4d70592

File tree

7 files changed

+976
-1634
lines changed

7 files changed

+976
-1634
lines changed

README.md

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -452,16 +452,16 @@ sphere = Sphere(nInput = 10)
452452
from UQPyL.DoE import LHS
453453

454454
# Generate 200 training samples in the input space using LHS
455-
lhs = LHS(problem)
456-
xTrain = lhs.sample(200, problem.nInput)
455+
lhs = LHS()
456+
xTrain = lhs.sample(200, problem.nInput, problem = sphere)
457457

458458
# Evaluate the true objective function at training points
459-
yTrain = problem.objFunc(xTrain)
459+
yTrain = sphere.objFunc(xTrain)
460460

461461
# Generate 50 test samples for model validation
462-
xTest = lhs.sample(50, problem.nInput)
462+
xTest = lhs.sample(50, problem.nInput, problem = sphere)
463463
# Evaluate the true function at test points
464-
yTest = problem.evaluate(xTest)
464+
yTest = sphere.objFunc(xTest)
465465

466466
# Import Radial Basis Function (RBF) surrogate model
467467
from UQPyL.surrogate.rbf import RBF

README_CN.md

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -455,17 +455,17 @@ sphere = Sphere(nInput = 10)
455455
from UQPyL.DoE import LHS
456456

457457
# 使用 LHS 方法生成 200 个训练样本
458-
lhs = LHS(problem)
458+
lhs = LHS()
459459
xTrain = lhs.sample(200, problem.nInput)
460460

461461
# 计算训练样本的目标函数值
462-
yTrain = problem.objFunc(xTrain)
462+
yTrain = sphere.objFunc(xTrain)
463463

464464
# 使用相同方法生成 50 个测试样本
465465
xTest = lhs.sample(50, problem.nInput)
466466

467467
# 计算测试样本的真实目标值
468-
yTest = problem.evaluate(xTest)
468+
yTest = sphere.objFunc(xTest)
469469

470470
# 从surrogate模块导入 RBF 替代模型
471471
from UQPyL.surrogate.rbf import RBF
@@ -482,6 +482,7 @@ yPred = rbf.predict(xTest)
482482
# 导入 R² 评估指标
483483
from UQPyL.utility.metric import r_square
484484
# 计算预测结果的 R² 分数,衡量模型拟合效果
485+
485486
r2 = r_square(yTest, yPred)
486487
# 输出 R² 分数
487488
print(r2)

0 commit comments

Comments
 (0)