Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Generate equals and hashcode for @ConfigMapping implementations #1181

Merged
merged 1 commit into from
Jun 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package io.smallrye.config;

import static org.objectweb.asm.Opcodes.AASTORE;
import static org.objectweb.asm.Opcodes.ACC_ABSTRACT;
import static org.objectweb.asm.Opcodes.ACC_FINAL;
import static org.objectweb.asm.Opcodes.ACC_INTERFACE;
Expand All @@ -8,22 +9,39 @@
import static org.objectweb.asm.Opcodes.ACC_STATIC;
import static org.objectweb.asm.Opcodes.ACONST_NULL;
import static org.objectweb.asm.Opcodes.ALOAD;
import static org.objectweb.asm.Opcodes.ANEWARRAY;
import static org.objectweb.asm.Opcodes.ARETURN;
import static org.objectweb.asm.Opcodes.ASM7;
import static org.objectweb.asm.Opcodes.ASTORE;
import static org.objectweb.asm.Opcodes.BIPUSH;
import static org.objectweb.asm.Opcodes.CHECKCAST;
import static org.objectweb.asm.Opcodes.DCMPL;
import static org.objectweb.asm.Opcodes.DRETURN;
import static org.objectweb.asm.Opcodes.DUP;
import static org.objectweb.asm.Opcodes.FCMPL;
import static org.objectweb.asm.Opcodes.FRETURN;
import static org.objectweb.asm.Opcodes.F_SAME;
import static org.objectweb.asm.Opcodes.GETFIELD;
import static org.objectweb.asm.Opcodes.GETSTATIC;
import static org.objectweb.asm.Opcodes.GOTO;
import static org.objectweb.asm.Opcodes.I2C;
import static org.objectweb.asm.Opcodes.ICONST_0;
import static org.objectweb.asm.Opcodes.ICONST_1;
import static org.objectweb.asm.Opcodes.IFEQ;
import static org.objectweb.asm.Opcodes.IFNE;
import static org.objectweb.asm.Opcodes.IFNULL;
import static org.objectweb.asm.Opcodes.IF_ACMPEQ;
import static org.objectweb.asm.Opcodes.IF_ACMPNE;
import static org.objectweb.asm.Opcodes.IF_ICMPNE;
import static org.objectweb.asm.Opcodes.ILOAD;
import static org.objectweb.asm.Opcodes.INVOKEINTERFACE;
import static org.objectweb.asm.Opcodes.INVOKESPECIAL;
import static org.objectweb.asm.Opcodes.INVOKESTATIC;
import static org.objectweb.asm.Opcodes.INVOKEVIRTUAL;
import static org.objectweb.asm.Opcodes.IRETURN;
import static org.objectweb.asm.Opcodes.ISTORE;
import static org.objectweb.asm.Opcodes.LCMP;
import static org.objectweb.asm.Opcodes.LRETURN;
import static org.objectweb.asm.Opcodes.NEW;
import static org.objectweb.asm.Opcodes.POP;
import static org.objectweb.asm.Opcodes.PUTFIELD;
Expand Down Expand Up @@ -53,7 +71,6 @@
import org.objectweb.asm.Handle;
import org.objectweb.asm.Label;
import org.objectweb.asm.MethodVisitor;
import org.objectweb.asm.Opcodes;
import org.objectweb.asm.Type;

import io.smallrye.config.ConfigMapping.NamingStrategy;
Expand Down Expand Up @@ -160,12 +177,6 @@ static byte[] generate(final ConfigMappingInterface mapping) {
ctor.visitVarInsn(ASTORE, V_NAMING_STRATEGY);

addProperties(visitor, ctor, new HashSet<>(), mapping, mapping.getClassInternalName());
if (mapping.getToStringMethod().generate()) {
addToString(visitor, mapping);
}

generateNames(visitor, mapping);
generateDefaults(visitor, mapping);

ctor.visitInsn(RETURN);
ctor.visitLabel(ctorEnd);
Expand All @@ -178,6 +189,12 @@ static byte[] generate(final ConfigMappingInterface mapping) {
ctor.visitMaxs(0, 0);
visitor.visitEnd();

generateEquals(visitor, mapping);
generateHashCode(visitor, mapping);
generateToString(visitor, mapping);
generateNames(visitor, mapping);
generateDefaults(visitor, mapping);

return writer.toByteArray();
}

Expand Down Expand Up @@ -680,13 +697,13 @@ private static int getReturnInstruction(Property property) {
}

if (primitiveProperty.getPrimitiveType() == float.class) {
return Opcodes.FRETURN;
return FRETURN;
} else if (primitiveProperty.getPrimitiveType() == double.class) {
return Opcodes.DRETURN;
return DRETURN;
} else if (primitiveProperty.getPrimitiveType() == long.class) {
return Opcodes.LRETURN;
return LRETURN;
} else {
return Opcodes.IRETURN;
return IRETURN;
}
}

Expand All @@ -707,7 +724,11 @@ private static String getSignature(final Field field) {
return null;
}

private static void addToString(final ClassVisitor visitor, final ConfigMappingInterface mapping) {
private static void generateToString(final ClassVisitor visitor, final ConfigMappingInterface mapping) {
if (!mapping.getToStringMethod().generate()) {
return;
}

MethodVisitor ts = visitor.visitMethod(ACC_PUBLIC, "toString", "()L" + I_STRING + ";", null, null);
ts.visitCode();
ts.visitTypeInsn(NEW, I_STRING_BUILDER);
Expand Down Expand Up @@ -754,6 +775,131 @@ private static void addToString(final ClassVisitor visitor, final ConfigMappingI
ts.visitMaxs(0, 0);
}

private static void generateEquals(final ClassVisitor visitor, final ConfigMappingInterface mapping) {
MethodVisitor eq = visitor.visitMethod(ACC_PUBLIC, "equals", "(Ljava/lang/Object;)Z", null, null);
eq.visitCode();
int V_O = 1;
int V_THAT = 2;

// if (this == o) {return true;}
eq.visitVarInsn(ALOAD, V_THIS);
eq.visitVarInsn(ALOAD, V_O);
Label _ifRef = new Label();
eq.visitJumpInsn(IF_ACMPNE, _ifRef);
eq.visitInsn(ICONST_1);
eq.visitInsn(IRETURN);
eq.visitLabel(_ifRef);

// if (o == null || getClass() != o.getClass()) {return false;}
eq.visitVarInsn(ALOAD, V_O);
Label _ifNull = new Label();
eq.visitJumpInsn(IFNULL, _ifNull);
eq.visitVarInsn(ALOAD, V_THIS);
Label _ifClass = new Label();
eq.visitMethodInsn(INVOKEVIRTUAL, I_OBJECT, "getClass", "()L" + I_CLASS + ";", false);
eq.visitVarInsn(ALOAD, V_O);
eq.visitMethodInsn(INVOKEVIRTUAL, I_OBJECT, "getClass", "()L" + I_CLASS + ";", false);
eq.visitJumpInsn(IF_ACMPEQ, _ifClass);
eq.visitLabel(_ifNull);
eq.visitFrame(F_SAME, 0, null, 0, null);
eq.visitInsn(ICONST_0);
eq.visitInsn(IRETURN);
eq.visitLabel(_ifClass);

// ConfigMappingClass that = (ConfigMappingClass) o;
eq.visitVarInsn(ALOAD, V_O);
eq.visitTypeInsn(CHECKCAST, mapping.getClassInternalName());
eq.visitVarInsn(ASTORE, V_THAT);

// this.primitive() == that.primitive() && this.object().equals(that.object()) ...
Label _ifTrue = new Label();
Label _ifFalse = new Label();
for (Property property : mapping.getProperties()) {
// unwrap Kotlin default methods
if (property.isDefaultMethod()) {
property = property.asDefaultMethod().getDefaultProperty();
}

String member = property.getMethod().getName();
Class<?> returnType = property.getMethod().getReturnType();

eq.visitVarInsn(ALOAD, V_THIS);
eq.visitMethodInsn(INVOKEVIRTUAL, mapping.getClassInternalName(), member, "()" + getDescriptor(returnType), false);
eq.visitVarInsn(ALOAD, V_THAT);
eq.visitMethodInsn(INVOKEVIRTUAL, mapping.getClassInternalName(), member, "()" + getDescriptor(returnType), false);
if (property.isPrimitive()) {
PrimitiveProperty primitiveProperty = property.asPrimitive();
if (primitiveProperty.getPrimitiveType() == float.class) {
eq.visitInsn(FCMPL);
eq.visitJumpInsn(IFNE, _ifFalse);
} else if (primitiveProperty.getPrimitiveType() == double.class) {
eq.visitInsn(DCMPL);
eq.visitJumpInsn(IFNE, _ifFalse);
} else if (primitiveProperty.getPrimitiveType() == long.class) {
eq.visitInsn(LCMP);
eq.visitJumpInsn(IFNE, _ifFalse);
} else {
eq.visitJumpInsn(IF_ICMPNE, _ifFalse);
}
} else {
eq.visitMethodInsn(INVOKESTATIC, "java/util/Objects", "equals", "(Ljava/lang/Object;Ljava/lang/Object;)Z",
false);
eq.visitJumpInsn(IFEQ, _ifFalse);
}
}

// return
eq.visitInsn(ICONST_1);
eq.visitJumpInsn(GOTO, _ifTrue);
eq.visitLabel(_ifFalse);
eq.visitInsn(ICONST_0);
eq.visitLabel(_ifTrue);
eq.visitInsn(IRETURN);

eq.visitEnd();
eq.visitMaxs(0, 0);
}

private static void generateHashCode(final ClassVisitor visitor, final ConfigMappingInterface mapping) {
MethodVisitor hc = visitor.visitMethod(ACC_PUBLIC, "hashCode", "()I", null, null);
hc.visitCode();
Property[] properties = mapping.getProperties();

hc.visitIntInsn(BIPUSH, properties.length);
hc.visitTypeInsn(ANEWARRAY, I_OBJECT);
hc.visitInsn(DUP);

for (int i = 0; i < properties.length; i++) {
Property property = properties[i];
// unwrap Kotlin default methods
if (property.isDefaultMethod()) {
property = property.asDefaultMethod().getDefaultProperty();
}

String member = property.getMethod().getName();
Class<?> returnType = property.getMethod().getReturnType();

hc.visitIntInsn(BIPUSH, i);
hc.visitVarInsn(ALOAD, V_THIS);
hc.visitFieldInsn(GETFIELD, mapping.getClassInternalName(), member, getDescriptor(returnType));
if (property.isPrimitive()) {
PrimitiveProperty primitiveProperty = property.asPrimitive();
hc.visitMethodInsn(INVOKESTATIC, getInternalName(primitiveProperty.getBoxType()), "valueOf",
"(" + getDescriptor(primitiveProperty.getPrimitiveType()) + ")"
+ getDescriptor(primitiveProperty.getBoxType()),
false);
}
hc.visitInsn(AASTORE);
hc.visitInsn(DUP);
}

hc.visitMethodInsn(INVOKESTATIC, "java/util/Objects", "hash", "([Ljava/lang/Object;)I", false);
hc.visitInsn(IRETURN);

hc.visitMaxs(0, 0);
hc.visitEnd();
}

private static void generateNames(final ClassVisitor classVisitor, final ConfigMappingInterface mapping) {
MethodVisitor mv = classVisitor.visitMethod(ACC_PUBLIC | ACC_STATIC, "getNames", "()Ljava/util/Map;",
"()Ljava/util/Map<Ljava/lang/String;Ljava/util/Map<Ljava/lang/String;Ljava/util/Set<Ljava/lang/String;>;>;>;",
Expand Down Expand Up @@ -971,7 +1117,7 @@ static final class ClassVisitorImpl extends ClassVisitor {
}

ClassVisitorImpl(final ClassWriter cw) {
super(Opcodes.ASM7, cw);
super(ASM7, cw);
sourceFile = getCaller().getFileName();
}

Expand Down
Loading