[Chapter13&Inheritance] - Adding the inheritance logic

This commit is contained in:
Adnan Ioricce 2024-10-06 12:18:30 +00:00
parent 05a8ac6127
commit 2117e9ed07
7 changed files with 241 additions and 100 deletions

@ -7,9 +7,13 @@ abstract class Expr {
R visitAssignExpr(Assign expr);
R visitBinaryExpr(Binary expr);
R visitCallExpr(Call expr);
R visitGetExpr(Get expr);
R visitGroupingExpr(Grouping expr);
R visitLiteralExpr(Literal expr);
R visitLogicalExpr(Logical expr);
R visitSetExpr(Set expr);
R visitSuperExpr(Super expr);
R visitThisExpr(This expr);
R visitUnaryExpr(Unary expr);
R visitVariableExpr(Variable expr);
}
@ -59,6 +63,20 @@ abstract class Expr {
final Token paren;
final List<Expr> arguments;
}
static class Get extends Expr {
Get(Expr object, Token name) {
this.object = object;
this.name = name;
}
@Override
<R> R accept(Visitor<R> visitor) {
return visitor.visitGetExpr(this);
}
final Expr object;
final Token name;
}
static class Grouping extends Expr {
Grouping(Expr expression) {
this.expression = expression;
@ -99,6 +117,48 @@ abstract class Expr {
final Token operator;
final Expr right;
}
static class Set extends Expr {
Set(Expr object, Token name, Expr value) {
this.object = object;
this.name = name;
this.value = value;
}
@Override
<R> R accept(Visitor<R> visitor) {
return visitor.visitSetExpr(this);
}
final Expr object;
final Token name;
final Expr value;
}
static class Super extends Expr {
Super(Token keyword, Token method) {
this.keyword = keyword;
this.method = method;
}
@Override
<R> R accept(Visitor<R> visitor) {
return visitor.visitSuperExpr(this);
}
final Token keyword;
final Token method;
}
static class This extends Expr {
This(Token keyword) {
this.keyword = keyword;
}
@Override
<R> R accept(Visitor<R> visitor) {
return visitor.visitThisExpr(this);
}
final Token keyword;
}
static class Unary extends Expr {
Unary(Token operator, Expr right) {
this.operator = operator;

@ -39,108 +39,118 @@ class Interpreter implements Expr.Visitor<Object>,Stmt.Visitor<Void> {
return evaluate(expr.right);
}
@Override
public Object visitSetExpr(Expr.Set expr) {
Object object = evaluate(expr.object);
public Object visitSetExpr(Expr.Set expr) {
Object object = evaluate(expr.object);
if (!(object instanceof LoxInstance)) {
throw new RuntimeError(expr.name,
"Only instances have fields.");
}
Object value = evaluate(expr.value);
((LoxInstance)object).set(expr.name, value);
return value;
}
@Override
public Object visitThisExpr(Expr.This expr) {
return lookUpVariable(expr.keyword, expr);
}
@Override
public Object visitUnaryExpr(Expr.Unary expr) {
Object right = evaluate(expr.right);
switch (expr.operator.type) {
case BANG:
return !isTruthy(right);
case MINUS:
checkNumberOperand(expr.operator, right);
return -(double)right;
}
// Unreachable.
return null;
}
@Override
public Object visitVariableExpr(Expr.Variable expr) {
return lookUpVariable(expr.name, expr);
}
private Object lookUpVariable(Token name, Expr expr) {
Integer distance = locals.get(expr);
if (distance != null) {
return environment.getAt(distance, name.lexeme);
} else {
return globals.get(name);
}
}
private void checkNumberOperand(Token operator, Object operand) {
if (operand instanceof Double) return;
throw new RuntimeError(operator, "Operand must be a number.");
}
private void checkNumberOperands(Token operator,
Object left, Object right) {
if (left instanceof Double && right instanceof Double) return;
throw new RuntimeError(operator, "Operands must be numbers.");
if (!(object instanceof LoxInstance)) {
throw new RuntimeError(expr.name,"Only instances have fields.");
}
@Override
public Object visitGroupingExpr(Expr.Grouping expr) {
return evaluate(expr.expression);
Object value = evaluate(expr.value);
((LoxInstance)object).set(expr.name, value);
return value;
}
@Override
public Object visitSuperExpr(Expr.Super expr) {
int distance = locals.get(expr);
LoxInstance object = (LoxInstance)environment.getAt(
distance - 1, "this");
LoxFunction method = superclass.findMethod(expr.method.lexeme);
if (method == null) {
throw new RuntimeError(expr.method,
"Undefined property '" + expr.method.lexeme + "'.");
}
@Override
public Object visitBinaryExpr(Expr.Binary expr) {
Object left = evaluate(expr.left);
Object right = evaluate(expr.right);
return method.bind(object);
}
switch (expr.operator.type) {
case GREATER:
checkNumberOperands(expr.operator, left, right);
return (double)left > (double)right;
case GREATER_EQUAL:
checkNumberOperands(expr.operator, left, right);
return (double)left >= (double)right;
case LESS:
checkNumberOperands(expr.operator, left, right);
return (double)left < (double)right;
case LESS_EQUAL:
checkNumberOperands(expr.operator, left, right);
return (double)left <= (double)right;
case MINUS:
checkNumberOperands(expr.operator, left, right);
return (double)left - (double)right;
case PLUS:
if (left instanceof Double && right instanceof Double) {
return (double)left + (double)right;
}
@Override
public Object visitThisExpr(Expr.This expr) {
return lookUpVariable(expr.keyword, expr);
}
@Override
public Object visitUnaryExpr(Expr.Unary expr) {
Object right = evaluate(expr.right);
switch (expr.operator.type) {
case BANG:
return !isTruthy(right);
case MINUS:
checkNumberOperand(expr.operator, right);
return -(double)right;
}
if (left instanceof String && right instanceof String) {
return (String)left + (String)right;
}
throw new RuntimeError(expr.operator,
// Unreachable.
return null;
}
@Override
public Object visitVariableExpr(Expr.Variable expr) {
return lookUpVariable(expr.name, expr);
}
private Object lookUpVariable(Token name, Expr expr) {
Integer distance = locals.get(expr);
if (distance != null) {
return environment.getAt(distance, name.lexeme);
} else {
return globals.get(name);
}
}
private void checkNumberOperand(Token operator, Object operand) {
if (operand instanceof Double) return;
throw new RuntimeError(operator, "Operand must be a number.");
}
private void checkNumberOperands(Token operator,Object left, Object right) {
if (left instanceof Double && right instanceof Double) return;
throw new RuntimeError(operator, "Operands must be numbers.");
}
@Override
public Object visitGroupingExpr(Expr.Grouping expr) {
return evaluate(expr.expression);
}
@Override
public Object visitBinaryExpr(Expr.Binary expr) {
Object left = evaluate(expr.left);
Object right = evaluate(expr.right);
switch (expr.operator.type) {
case GREATER:
checkNumberOperands(expr.operator, left, right);
return (double)left > (double)right;
case GREATER_EQUAL:
checkNumberOperands(expr.operator, left, right);
return (double)left >= (double)right;
case LESS:
checkNumberOperands(expr.operator, left, right);
return (double)left < (double)right;
case LESS_EQUAL:
checkNumberOperands(expr.operator, left, right);
return (double)left <= (double)right;
case MINUS:
checkNumberOperands(expr.operator, left, right);
return (double)left - (double)right;
case PLUS:
if (left instanceof Double && right instanceof Double) {
return (double)left + (double)right;
}
if (left instanceof String && right instanceof String) {
return (String)left + (String)right;
}
throw new RuntimeError(expr.operator,
"Operands must be two numbers or two strings.");
case SLASH:
checkNumberOperands(expr.operator, left, right);
return (double)left / (double)right;
case STAR:
checkNumberOperands(expr.operator, left, right);
return (double)left * (double)right;
case BANG_EQUAL: return !isEqual(left, right);
case EQUAL_EQUAL: return isEqual(left, right);
}
case SLASH:
checkNumberOperands(expr.operator, left, right);
return (double)left / (double)right;
case STAR:
checkNumberOperands(expr.operator, left, right);
return (double)left * (double)right;
case BANG_EQUAL: return !isEqual(left, right);
case EQUAL_EQUAL: return isEqual(left, right);
}
// Unreachable.
return null;
}
}
@Override
public Object visitCallExpr(Expr.Call expr){
Object callee = evaluate(expr.callee);
@ -281,7 +291,19 @@ class Interpreter implements Expr.Visitor<Object>,Stmt.Visitor<Void> {
}
@Override
public void visitClassStmt(Stmt.Class stmt){
Object superclass = null;
if (stmt.superclass != null) {
superclass = evaluate(stmt.superclass);
if (!(superclass instanceof LoxClass)) {
throw new RuntimeError(stmt.superclass.name,
"Superclass must be a class.");
}
}
environment.define(stmt.name.lexeme, null);
if (stmt.superclass != null) {
environment = new Environment(environment);
environment.define("super", superclass);
}
Map<String, LoxFunction> methods = new HashMap<>();
for (Stmt.Function method : stmt.methods) {
LoxFunction function = new LoxFunction(method, environment,
@ -289,7 +311,10 @@ class Interpreter implements Expr.Visitor<Object>,Stmt.Visitor<Void> {
methods.put(method.name.lexeme, function);
}
LoxClass klass = new LoxClass(stmt.name.lexeme, methods);
LoxClass klass = new LoxClass(stmt.name.lexeme, methods,(LoxClass)superclass,methods);
if (superclass != null) {
environment = environment.enclosing;
}
environment.assign(stmt.name, klass);
return null;
}

@ -5,9 +5,11 @@ import java.util.Map;
class LoxClass implements LoxCallable {
final String name;
final LoxClass superclass;
private final Map<String, LoxFunction> methods;
LoxClass(String name, Map<String, LoxFunction> methods) {
LoxClass(String name, LoxClass superclass, Map<String, LoxFunction> methods) {
this.superclass = superclass;
this.name = name;
this.methods = methods;
}
@ -15,7 +17,9 @@ class LoxClass implements LoxCallable {
if (methods.containsKey(name)) {
return methods.get(name);
}
if (superclass != null) {
return superclass.findMethod(name);
}
return null;
}
@Override

@ -33,6 +33,11 @@ class Parser {
}
private Stmt classDeclaration(){
Token name = consume(IDENTIFIER, "Expect class name.");
Expr.Variable superclass = null;
if(match(LESS)){
consume(IDENTIFIER, "Expect superclass name.");
superclass = new Expr.Variable(previous());
}
consume(LEFT_BRACE, "Expect '{' before class body");
List<Stmt.Function> methods = new ArrayList<>();
@ -41,8 +46,7 @@ class Parser {
}
consume(RIGHT_BRACE, "Expect '}' after class body.");
return new Stmt.Class(name,)
return new Stmt.Class(name, superclass, methods);
}
private Expr expression(){
return assignment();
@ -322,7 +326,13 @@ class Parser {
if (match(IDENTIFIER)) {
return new Expr.Variable(previous());
}
if (match(SUPER)) {
Token keyword = previous();
consume(DOT, "Expect '.' after 'super'.");
Token method = consume(IDENTIFIER,
"Expect superclass method name.");
return new Expr.Super(keyword, method);
}
if (match(THIS)) return new Expr.This(previous());
throw error(peek(), "Expect expression.");

@ -21,7 +21,8 @@ class Resolver implements Expr.Visitor<Void>, Stmt.Visitor<Void> {
}
private enum ClassType {
NONE,
CLASS
CLASS,
SUBCLASS
}
private ClassType currentClass = ClassType.NONE;
@ -57,6 +58,19 @@ class Resolver implements Expr.Visitor<Void>, Stmt.Visitor<Void> {
declare(stmt.name);
define(stmt.name);
if (stmt.superclass != null &&
stmt.name.lexeme.equals(stmt.superclass.name.lexeme)) {
Lox.error(stmt.superclass.name,
"A class can't inherit from itself.");
}
if (stmt.superclass != null) {
currentClass = ClassType.SUBCLASS;
resolve(stmt.superclass);
}
if (stmt.superclass != null) {
beginScope();
scopes.peek().put("super", true);
}
beginScope();
scopes.peek().put("this", true);
for (Stmt.Function method : stmt.methods) {
@ -67,7 +81,7 @@ class Resolver implements Expr.Visitor<Void>, Stmt.Visitor<Void> {
resolveFunction(method, declaration);
}
endScope();
if (stmt.superclass != null) endScope();
currentClass = enclosingClass;
return null;
}
@ -196,6 +210,18 @@ class Resolver implements Expr.Visitor<Void>, Stmt.Visitor<Void> {
return null;
}
@Override
public Void visitSuperExpr(Expr.Super expr) {
if (currentClass == ClassType.NONE) {
Lox.error(expr.keyword,
"Can't use 'super' outside of a class.");
} else if (currentClass != ClassType.SUBCLASS) {
Lox.error(expr.keyword,
"Can't use 'super' in a class with no superclass.");
}
resolveLocal(expr, expr.keyword);
return null;
}
@Override
public Void visitThisExpr(Expr.This expr) {
if (currentClass == ClassType.NONE) {
Lox.error(expr.keyword,

@ -5,6 +5,7 @@ import java.util.List;
abstract class Stmt {
interface Visitor<R> {
R visitBlockStmt(Block stmt);
R visitClassStmt(Class stmt);
R visitIfStmt(If stmt);
R visitExpressionStmt(Expression stmt);
R visitFunctionStmt(Function stmt);
@ -25,6 +26,20 @@ abstract class Stmt {
final List<Stmt> statements;
}
static class Class extends Stmt {
Class(Token name, List<Stmt.Function> methods) {
this.name = name;
this.methods = methods;
}
@Override
<R> R accept(Visitor<R> visitor) {
return visitor.visitClassStmt(this);
}
final Token name;
final List<Stmt.Function> methods;
}
static class If extends Stmt {
If(Expr condition, Stmt thenBranch, Stmt elseBranch) {
this.condition = condition;

@ -21,6 +21,7 @@ public class GenerateAst {
"Literal : Object value",
"Logical : Expr left, Token operator, Expr right",
"Set : Expr object, Token name, Expr value",
"Super : Token keyword, Token method",
"This : Token keyword",
"Unary : Token operator, Expr right",
"Variable : Token name"